[Feature] Enhance Distributed Data Parallel wrapper with ZeRO (#207)
## Title

[Feature] Enhance Distributed Data Parallel wrapper with ZeRO

## Description

This PR enhances the Distributed Data Parallel (DDP) wrapper in the OSLO library by adding support for different sharding strategies during distributed training, giving users finer control over the memory footprint across multiple GPUs or machines. Key changes:

1. **Sharding Strategy Enum**: A new enum, `ShardingStrategy`, is introduced with the options `SHARD_OP`, `SHARD_GRAD_OP`, and `FULL_SHARD` (sketched below). These determine how model parameters and optimizer states are partitioned across GPUs.
2. **DistributedDataParallel Changes**: The signature of the `DistributedDataParallel` function is updated to accept the new sharding strategies, along with additional configuration parameters for the model and optimizer wrappers. The function now returns a tuple of the wrapped module and optimizer (if applicable); see the usage sketch after the diff.

Together, these changes let users pick the sharding strategy that best fits their memory budget during distributed training.

**Note**: For the `SHARD_OP`, `SHARD_GRAD_OP`, and `FULL_SHARD` strategies, an optimizer must be provided.

## Linked Issues

- N/A
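As a point of reference, here is a minimal sketch of a `ShardingStrategy` enum with the members named above. Only the member names come from this PR; the values and the ZeRO-stage comments are assumptions about the intended semantics, not the actual OSLO definition.

```python
from enum import Enum


class ShardingStrategy(Enum):
    # Assumed: partition optimizer states only (ZeRO stage 1 style).
    SHARD_OP = "SHARD_OP"
    # Assumed: partition gradients and optimizer states (ZeRO stage 2 style).
    SHARD_GRAD_OP = "SHARD_GRAD_OP"
    # Assumed: partition parameters, gradients, and optimizer states (ZeRO stage 3 style).
    FULL_SHARD = "FULL_SHARD"
```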
Showing 4 changed files with 128 additions and 26 deletions.
```diff
@@ -1,8 +1,14 @@
 from oslo.torch.nn.parallel.data_parallel.data_parallel import (
     DistributedDataParallel,
+    ShardingStrategy,
 )
 from oslo.torch.nn.parallel.data_parallel.zero import ZeroRedundancyOptimizer
 
 from oslo.torch.nn.parallel.data_parallel._utils import set_params_to_ignore
 
-__ALL__ = ["DistributedDataParallel", "ZeroRedundancyOptimizer", "set_params_to_ignore"]
+__ALL__ = [
+    "DistributedDataParallel",
+    "ZeroRedundancyOptimizer",
+    "set_params_to_ignore",
+    "ShardingStrategy",
+]
```
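To show how the newly exported names fit together, here is a usage sketch. The keyword arguments (`parallel_context`, `optimizer`, `sharding_strategy`) and the `ParallelContext.from_torch` setup are assumptions based on the description above, not the verified signature introduced by this PR.

```python
import torch

from oslo.torch.distributed import ParallelContext  # assumed import path
from oslo.torch.nn.parallel.data_parallel import (
    DistributedDataParallel,
    ShardingStrategy,
)

# Assumed setup: a parallel context describing the data-parallel group.
parallel_context = ParallelContext.from_torch(data_parallel_size=2)

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# The ZeRO strategies (SHARD_OP, SHARD_GRAD_OP, FULL_SHARD) require an
# optimizer, and the wrapper returns the wrapped module and optimizer
# as a tuple.
model, optimizer = DistributedDataParallel(
    model,
    optimizer=optimizer,
    parallel_context=parallel_context,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
```

With no sharding strategy selected, the description implies the optimizer argument stays optional and only the wrapped module needs to be taken from the returned tuple.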
(Diffs for the remaining three changed files are not shown in this view.)