# [Feature] Enhance Distributed Data Parallel wrapper with ZeRO (#207)

## Title

- [Feature] Enhance Distributed Data Parallel wrapper with ZeRO

## Description

This PR introduces a significant enhancement to the Distributed Data
Parallel (DDP) wrapper in the OSLO library. It adds support for
different sharding strategies during distributed training, giving finer
control over the memory footprint across multiple GPUs or machines.

Here are the key changes:

1. **Sharding Strategy Enum**: A new enum `ShardingStrategy` is
introduced, which includes `SHARD_OP`, `SHARD_GRAD_OP`, and `FULL_SHARD`
options. These determine how model parameters and optimizer states are
partitioned across GPUs.

2. **DistributedDataParallel Changes**: The signature of the
`DistributedDataParallel` function is updated to incorporate these new
sharding strategies, along with additional configuration parameters for
the model and optimizer wrappers. This function now returns a tuple of
the wrapped module and optimizer (if applicable).

These changes let users choose the sharding strategy that best fits
their distributed training workload, providing finer control over
memory usage across GPUs.

**Note**: For `SHARD_OP`, `SHARD_GRAD_OP`, and `FULL_SHARD` strategies,
it is mandatory to provide an optimizer.
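
For illustration, here is a minimal usage sketch of the updated wrapper (not part of the PR itself). The `DistributedDataParallel` call follows the new signature described above; the `ParallelContext` import path and the `from_torch(...)` setup are assumptions about the surrounding OSLO API and may need to be adapted.

```python
import torch
import torch.nn as nn

# Assumed import path for ParallelContext; adjust to your OSLO version.
from oslo.torch.distributed import ParallelContext
from oslo.torch.nn.parallel.data_parallel import (
    DistributedDataParallel,
    ShardingStrategy,
)

# Placeholder model and optimizer; any nn.Module / torch.optim.Optimizer will do.
model = nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Assumed context construction (e.g. launched with torchrun on 2 processes).
parallel_context = ParallelContext.from_torch(data_parallel_size=2)

# SHARD_GRAD_OP shards optimizer states and gradients across data-parallel ranks.
# An optimizer is mandatory for any non-None sharding strategy.
model, optimizer = DistributedDataParallel(
    model,
    parallel_context=parallel_context,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    optimizer=optimizer,
)
```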

## Linked Issues

- N/A

yhna940 committed Jul 12, 2023
1 parent 5e14dd6 commit f16c73b

Showing 4 changed files with 128 additions and 26 deletions.

### `oslo/torch/nn/parallel/data_parallel/__init__.py` (7 additions & 1 deletion)

```diff
@@ -1,8 +1,14 @@
 from oslo.torch.nn.parallel.data_parallel.data_parallel import (
     DistributedDataParallel,
+    ShardingStrategy,
 )
 from oslo.torch.nn.parallel.data_parallel.zero import ZeroRedundancyOptimizer
 
 from oslo.torch.nn.parallel.data_parallel._utils import set_params_to_ignore
 
-__ALL__ = ["DistributedDataParallel", "ZeroRedundancyOptimizer", "set_params_to_ignore"]
+__ALL__ = [
+    "DistributedDataParallel",
+    "ZeroRedundancyOptimizer",
+    "set_params_to_ignore",
+    "ShardingStrategy",
+]
```

### `oslo/torch/nn/parallel/data_parallel/data_parallel.py` (113 additions & 17 deletions)

```diff
@@ -18,27 +18,123 @@
     DistributedBackwardFunction,
 )
 
+from enum import Enum, auto
+from typing import Optional, Tuple, Dict, Any
+
+
+class ShardingStrategy(Enum):
+    SHARD_OP = auto()
+    SHARD_GRAD_OP = auto()
+    FULL_SHARD = auto()
+
 
 def DistributedDataParallel(
     module: nn.Module,
     parallel_context: ParallelContext,
-    bucket_cap_mb: int = 25,
-    rebuild_bucket: bool = True,
-):
-    ddp = _DistributedDataParallel(
-        module=module,
-        parallel_context=parallel_context,
-        bucket_cap_mb=bucket_cap_mb,
-        rebuild_bucket=rebuild_bucket,
-    )
-
-    add_wrapper(
-        module,
-        mode=ParallelMode.DATA,
-        wrapper=ddp,
-        parallel_context=parallel_context,
-    )
-    return module
+    model_wrapper_config: Optional[Dict[str, Any]] = None,
+    optimizer_wrapper_config: Optional[Dict[str, Any]] = None,
+    sharding_strategy: Optional[ShardingStrategy] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+) -> Tuple[nn.Module, Optional[torch.optim.Optimizer]]:
+    """
+    This function wraps a PyTorch module with a distributed data parallel wrapper for OSLO.
+
+    This wrapper allows the module to be trained across multiple GPUs or machines, with optional sharding strategies
+    to reduce memory footprint. The function supports different sharding strategies that determine how the model
+    parameters and optimizer states are partitioned across the GPUs.
+
+    Supported sharding strategies are:
+    - None: No sharding is used. This is the default strategy, where each GPU maintains a full replica of the model.
+    - SHARD_OP: The optimizer states are sharded across GPUs. Each GPU maintains only a portion of the optimizer state.
+    - SHARD_GRAD_OP: In addition to sharding the optimizer states, the gradients are also sharded across GPUs.
+    - FULL_SHARD: The model parameters, optimizer states, and gradients are all sharded across GPUs.
+
+    For the SHARD_OP, SHARD_GRAD_OP, and FULL_SHARD strategies, it is mandatory to provide an optimizer.
+
+    Args:
+        module (nn.Module): PyTorch module object to be wrapped.
+        parallel_context (ParallelContext): Process group object for distributed training.
+        model_wrapper_config (Optional[Dict[str, Any]]): Additional configuration parameters for the model wrapper.
+        optimizer_wrapper_config (Optional[Dict[str, Any]]): Additional configuration parameters for the optimizer wrapper.
+        sharding_strategy (Optional[ShardingStrategy]): The strategy for sharding. Options include None, SHARD_OP, SHARD_GRAD_OP, and FULL_SHARD.
+        optimizer (Optional[torch.optim.Optimizer]): PyTorch optimizer object to be wrapped if a sharding strategy is specified.
+
+    Returns:
+        Tuple[nn.Module, Optional[torch.optim.Optimizer]]: The wrapped module and optimizer (if applicable).
+
+    Raises:
+        AssertionError: If a sharding strategy other than None is selected, but no optimizer is provided.
+    """
+    if sharding_strategy is not None:
+        assert (
+            optimizer is not None
+        ), "optimizer must be provided when sharding_strategy is not None"
+    from oslo.torch.nn.parallel.data_parallel import zero
+
+    model_wrapper_config = model_wrapper_config or {}
+    optimizer_wrapper_config = optimizer_wrapper_config or {}
+
+    def default_strategy():
+        ddp = _DistributedDataParallel(
+            module=module, parallel_context=parallel_context, **model_wrapper_config
+        )
+        add_wrapper(
+            module,
+            mode=ParallelMode.DATA,
+            wrapper=ddp,
+            parallel_context=parallel_context,
+        )
+        return module
+
+    def SHARD_OP_strategy():
+        optimizer_wrapper_config.pop("partition_grad", None)
+        return module, zero.ZeroRedundancyOptimizer(
+            optimizer,
+            parallel_context=parallel_context,
+            partition_grad=False,
+            **optimizer_wrapper_config,
+        )
+
+    def shard_grad_op_strategy():
+        optimizer_wrapper_config.pop("partition_grad", None)
+        return module, zero.ZeroRedundancyOptimizer(
+            optimizer,
+            parallel_context=parallel_context,
+            partition_grad=True,
+            **optimizer_wrapper_config,
+        )
+
+    def full_shard_strategy():
+        fsdp = zero._FullyShardedDataParallel(
+            module=module,
+            device=torch.device("cuda"),
+            parallel_context=parallel_context,
+            force_outputs_fp32=True,
+            **model_wrapper_config,
+        )
+        opt = zero._HeterogeneousZeroOptimizer(
+            optimizer,
+            module=fsdp,
+            **optimizer_wrapper_config,
+        )
+        add_wrapper(
+            module,
+            mode=ParallelMode.DATA,
+            wrapper=fsdp,
+            parallel_context=parallel_context,
+        )
+        return module, opt
+
+    strategy_map = {
+        None: default_strategy,
+        ShardingStrategy.SHARD_OP: SHARD_OP_strategy,
+        ShardingStrategy.SHARD_GRAD_OP: shard_grad_op_strategy,
+        ShardingStrategy.FULL_SHARD: full_shard_strategy,
+    }
+
+    strategy = strategy_map.get(sharding_strategy)
+
+    return strategy()
 
 
 class _DistributedDataParallel(OsloParallelWrapper):
```
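
As a further sketch (also not from the PR), the two config dictionaries are simply forwarded to the underlying wrappers: with `FULL_SHARD`, `model_wrapper_config` is passed to `zero._FullyShardedDataParallel` and `optimizer_wrapper_config` to `zero._HeterogeneousZeroOptimizer`. The keys below are taken from parameters visible elsewhere in this commit; treat them as examples rather than a complete option list.

```python
# Alternative call to the sketch above, reusing the placeholder model, optimizer,
# and parallel_context before wrapping. FULL_SHARD shards parameters, gradients,
# and optimizer states; the config dicts are forwarded as **kwargs.
model, optimizer = DistributedDataParallel(
    model,
    parallel_context=parallel_context,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    optimizer=optimizer,
    # gpu_margin_mem_ratio and clipping_norm appear in _HeterogeneousZeroOptimizer.__init__
    # later in this commit.
    optimizer_wrapper_config={"gpu_margin_mem_ratio": 0.0, "clipping_norm": 1.0},
)
```
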
### Third changed file

```diff
@@ -97,7 +97,7 @@ def __init__(
         parallel_context: ParallelContext = None,
         placement_policy: str = "cuda",
         pin_memory: bool = False,
-        force_outputs_fp32: bool = False,
+        force_outputs_fp32: bool = True,
         search_range_mb: int = 32,
         hidden_dim: Optional[int] = None,
         min_chunk_size_mb: float = 32,
```

### Fourth changed file

```diff
@@ -16,8 +16,7 @@
 
 import math
 import warnings
-from enum import Enum
-from typing import Any, Dict, Set, Tuple
+from typing import Any, Dict, Set, Tuple, TYPE_CHECKING
 
 import torch
 import torch.distributed as dist
@@ -38,9 +37,11 @@
 )
 
 from oslo.torch.nn.parallel.data_parallel.zero.chunk import Chunk, ChunkManager
-from oslo.torch.nn.parallel.data_parallel.zero.fully_sharded_data_parallel import (
-    _FullyShardedDataParallel,
-)
+
+if TYPE_CHECKING:
+    from oslo.torch.nn.parallel.data_parallel.zero.fully_sharded_data_parallel import (
+        _FullyShardedDataParallel,
+    )
 
 import functools
 
@@ -87,7 +88,7 @@ class _HeterogeneousZeroOptimizer(BaseOptimizerWrapper):
     def __init__(
         self,
         optim: Optimizer,
-        module: _FullyShardedDataParallel,
+        module: "_FullyShardedDataParallel",
         gpu_margin_mem_ratio: float = 0.0,
         clipping_norm: float = 0.0,
         norm_type: float = 2.0,
@@ -96,7 +97,6 @@ def __init__(
         **kwargs: Any,
     ):
         super().__init__(optim)
-        assert isinstance(module, _FullyShardedDataParallel)
         self.module = module
         self.heterogeneous_manager = module.heterogeneous_manager
         self.chunk_manager: ChunkManager = self.heterogeneous_manager.chunk_manager
```
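
These last hunks move the `_FullyShardedDataParallel` import behind `typing.TYPE_CHECKING`, quote the annotation, and drop the runtime `isinstance` check. This keeps the type hint for static checkers without importing the module at runtime, a pattern typically used to break an import cycle. A generic sketch of the pattern with hypothetical names (`some_package.heavy_module`, `HeavyClass`), not OSLO code:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers (e.g. mypy); never executed at
    # runtime, so it cannot trigger a circular import.
    from some_package.heavy_module import HeavyClass  # hypothetical module


def use(obj: "HeavyClass") -> None:
    # The quoted ("forward reference") annotation is resolved lazily, so
    # HeavyClass does not need to be imported when this function is defined.
    obj.run()
```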
