[Fix] Refactor ZeRO Directory Structure (#211)
## Title

- [Fix] Refactor ZeRO Directory Structure

## Description

- This PR restructures the `zero` directory under
`oslo/torch/nn/parallel/data_parallel/zero` to improve code organization
and readability. The changes align the implementation with the project's
architecture, providing a clearer separation between the different
components and functionalities.

- Organized the heterogeneous components (inspired by PatrickStar) into the
`hetero` subdirectory, centralizing related code and improving
maintainability.

- Update to the ZeRO optimizer wrapper interface:

> The existing ZeRO optimizer was not actually sharding the optimizer
state, so the wrapper interface has been updated accordingly
(`SHARD_OP` → `SHARD_PARAM`, `SHARD_GRAD_OP` → `SHARD_GRAD_PARAM`).
My sincere apologies for any confusion or inconvenience this change may
cause; I ask reviewers to assess this modification to ensure it aligns
with our project's requirements.

- Renaming FULL_SHARD (the PatrickStar algorithm) to HETERO_SHARD:

> Please note that the strategy previously named FULL_SHARD was, in
fact, an implementation of the PatrickStar algorithm. PatrickStar is a novel
approach to parallel training of pre-trained models via chunk-based
memory management, leveraging the CPU-GPU heterogeneous memory space. It has
demonstrated significant advantages in model scaling and execution
speed.
> 
> However, I felt that the name "PatrickStar" did not adequately convey
the specific characteristics of this approach. Therefore, I have taken
the liberty of renaming it "hetero" (exposed as HETERO_SHARD in the wrapper
interface), reflecting its heterogeneous memory utilization; a minimal usage
sketch follows below. I genuinely value the reviewers' opinions on this
naming choice and kindly ask for your feedback. If a more suitable name
is agreed upon, I will happily update it accordingly.
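
To make the renamed interface concrete for reviewers, here is a minimal usage
sketch. It follows the updated `DistributedDataParallel` wrapper in the diff
below; the toy model, the optimizer choice, and the `parallel_context` argument
are illustrative assumptions rather than part of this PR.

```python
import torch
import torch.nn as nn

# Import path follows the file changed in this PR
# (oslo/torch/nn/parallel/data_parallel/data_parallel.py).
from oslo.torch.nn.parallel.data_parallel.data_parallel import (
    DistributedDataParallel,
    ShardingStrategy,
)


def wrap_for_hetero_training(model: nn.Module, parallel_context):
    """Wrap a model with the renamed HETERO_SHARD (formerly FULL_SHARD) strategy."""
    # Sharded strategies require an optimizer, per the updated docstring.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Per the strategy functions in the diff, sharded strategies return the
    # wrapped module together with the wrapped optimizer.
    wrapped_model, wrapped_optimizer = DistributedDataParallel(
        model,
        parallel_context=parallel_context,  # assumed: created by your OSLO setup
        sharding_strategy=ShardingStrategy.HETERO_SHARD,  # formerly FULL_SHARD
        optimizer=optimizer,
    )
    return wrapped_model, wrapped_optimizer
```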

## Linked Issues

- N/A
yhna940 committed Sep 9, 2023
1 parent 21ef4a1 commit d7c4e32
Showing 41 changed files with 200 additions and 162 deletions.
32 changes: 16 additions & 16 deletions oslo/torch/nn/parallel/data_parallel/data_parallel.py
@@ -23,9 +23,9 @@


class ShardingStrategy(Enum):
SHARD_OP = auto()
SHARD_GRAD_OP = auto()
FULL_SHARD = auto()
SHARD_PARAM = auto()
SHARD_GRAD_PARAM = auto()
HETERO_SHARD = auto()


def DistributedDataParallel(
@@ -45,18 +45,18 @@ def DistributedDataParallel(
Supported sharding strategies are:
- None: No sharding is used. This is the default strategy, where each GPU maintains a full replica of the model.
- SHARD_OP: The optimizer states are sharded across GPUs. Each GPU maintains only a portion of the optimizer state.
- SHARD_GRAD_OP: In addition to sharding the optimizer states, the gradients are also sharded across GPUs.
- FULL_SHARD: The model parameters, optimizer states, and gradients are all sharded across GPUs.
- SHARD_PARAM: Shards the model parameters across GPUs.
- SHARD_GRAD_PARAM: Shards the gradient as well as the model parameters across GPUs.
- HETERO_SHARD: Uses the CPU-GPU heterogeneous memory space to store the model data, inspired by PatrickStar.
For the SHARD_OP, SHARD_GRAD_OP, and FULL_SHARD strategies, it is mandatory to provide an optimizer.
For the SHARD_PARAM, SHARD_GRAD_PARAM, and HETERO_SHARD strategies, it is mandatory to provide an optimizer.
Args:
module (nn.Module): PyTorch module object to be wrapped.
parallel_context (ParallelContext): Process group object for distributed training.
model_wrapper_config (Optional[Dict[str, Any]]): Additional configuration parameters for the model wrapper.
optimizer_wrapper_config (Optional[Dict[str, Any]]): Additional configuration parameters for the optimizer wrapper.
sharding_strategy (Optional[ShardingStrategy]): The strategy for sharding. Options include None, SHARD_OP, SHARD_GRAD_OP, and FULL_SHARD.
sharding_strategy (Optional[ShardingStrategy]): The strategy for sharding. Options include None, SHARD_PARAM, SHARD_GRAD_PARAM, and HETERO_SHARD.
optimizer (Optional[torch.optim.Optimizer]): PyTorch optimizer object to be wrapped if a sharding strategy is specified.
Returns:
@@ -86,7 +86,7 @@ def default_strategy():
)
return module

def SHARD_OP_strategy():
def shard_param_strategy():
optimizer_wrapper_config.pop("partition_grad", None)
return module, zero.ZeroRedundancyOptimizer(
optimizer,
@@ -95,7 +95,7 @@ def SHARD_OP_strategy():
**optimizer_wrapper_config,
)

def shard_grad_op_strategy():
def shard_grad_param_strategy():
optimizer_wrapper_config.pop("partition_grad", None)
return module, zero.ZeroRedundancyOptimizer(
optimizer,
@@ -104,15 +104,15 @@ def shard_grad_op_strategy():
**optimizer_wrapper_config,
)

def full_shard_strategy():
fsdp = zero._FullyShardedDataParallel(
def hetero_shard_strategy():
fsdp = zero._HeteroDataParallel(
module=module,
device=torch.device("cuda"),
parallel_context=parallel_context,
force_outputs_fp32=True,
**model_wrapper_config,
)
opt = zero._HeterogeneousZeroOptimizer(
opt = zero._HeteroOptimizer(
optimizer,
module=fsdp,
**optimizer_wrapper_config,
@@ -127,9 +127,9 @@ def full_shard_strategy():

strategy_map = {
None: default_strategy,
ShardingStrategy.SHARD_OP: SHARD_OP_strategy,
ShardingStrategy.SHARD_GRAD_OP: shard_grad_op_strategy,
ShardingStrategy.FULL_SHARD: full_shard_strategy,
ShardingStrategy.SHARD_PARAM: shard_param_strategy,
ShardingStrategy.SHARD_GRAD_PARAM: shard_grad_param_strategy,
ShardingStrategy.HETERO_SHARD: hetero_shard_strategy,
}

strategy = strategy_map.get(sharding_strategy)
14 changes: 7 additions & 7 deletions oslo/torch/nn/parallel/data_parallel/zero/__init__.py
@@ -1,15 +1,15 @@
from oslo.torch.nn.parallel.data_parallel.zero.sharded_optim.sharded_optim import (
from oslo.torch.nn.parallel.data_parallel.zero.optim.optim import (
ZeroRedundancyOptimizer,
)
from oslo.torch.nn.parallel.data_parallel.zero.fully_sharded_data_parallel import (
_FullyShardedDataParallel,
from oslo.torch.nn.parallel.data_parallel.zero.hetero.data_parallel import (
_HeteroDataParallel,
)
from oslo.torch.nn.parallel.data_parallel.zero.sharded_optim.heterogeneous_optim import (
_HeterogeneousZeroOptimizer,
from oslo.torch.nn.parallel.data_parallel.zero.hetero.optim import (
_HeteroOptimizer,
)

__ALL__ = [
"ZeroRedundancyOptimizer",
"_FullyShardedDataParallel",
"_HeterogeneousZeroOptimizer",
"_HeteroDataParallel",
"_HeteroOptimizer",
]
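
For downstream code, the package-level import surface stays the same apart from
the renamed symbols. A before/after sketch of the imports, taken directly from
this `__init__.py` diff:

```python
# Before this PR:
# from oslo.torch.nn.parallel.data_parallel.zero import (
#     ZeroRedundancyOptimizer,
#     _FullyShardedDataParallel,
#     _HeterogeneousZeroOptimizer,
# )

# After this PR:
from oslo.torch.nn.parallel.data_parallel.zero import (
    ZeroRedundancyOptimizer,
    _HeteroDataParallel,
    _HeteroOptimizer,
)
```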
9 changes: 0 additions & 9 deletions oslo/torch/nn/parallel/data_parallel/zero/chunk/__init__.py

This file was deleted.

6 changes: 6 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/zero/hetero/__init__.py
@@ -0,0 +1,6 @@
from oslo.torch.nn.parallel.data_parallel.zero.hetero.data_parallel import (
_HeteroDataParallel,
)
from oslo.torch.nn.parallel.data_parallel.zero.hetero.optim import _HeteroOptimizer

__ALL__ = ["_HeteroDataParallel", "_HeteroOptimizer"]
17 changes: 17 additions & 0 deletions oslo/torch/nn/parallel/data_parallel/zero/hetero/chunk/__init__.py
@@ -0,0 +1,17 @@
from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk.chunk import (
Chunk,
TensorState,
ChunkFullError,
)
from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk.manager import ChunkManager
from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk.utils import (
init_chunk_manager,
)

__ALL__ = [
"Chunk",
"TensorState",
"ChunkFullError",
"ChunkManager",
"init_chunk_manager",
]
@@ -24,7 +24,7 @@
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.distributed.parallel_context import ParallelContext

from oslo.torch.nn.parallel.data_parallel.zero.utils import get_current_device
from oslo.torch.nn.parallel.data_parallel.zero.hetero.utils import get_current_device


class TensorState(Enum):
@@ -20,7 +20,7 @@
import torch

from .chunk import Chunk, ChunkFullError, TensorState
from oslo.torch.nn.parallel.data_parallel.zero.utils import get_current_device
from oslo.torch.nn.parallel.data_parallel.zero.hetero.utils import get_current_device

from oslo.torch.distributed.parallel_context import ParallelContext

@@ -20,7 +20,7 @@
import torch.distributed as dist
from torch import nn

from oslo.torch.nn.parallel.data_parallel.zero.chunk.manager import (
from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk.manager import (
ChunkManager,
)

@@ -34,7 +34,7 @@
from oslo.torch.nn.parallel.data_parallel._utils import (
is_ddp_ignored,
)
from oslo.torch.nn.parallel.data_parallel.zero.memory_tracer import (
from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_tracer import (
MemStats,
OrderedParamGenerator,
)
@@ -28,29 +28,29 @@
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel.data_parallel._utils import is_ddp_ignored
from oslo.torch.nn.parallel.data_parallel.data_parallel import _DistributedDataParallel
from oslo.torch.nn.parallel.data_parallel.zero.chunk import (
from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk import (
Chunk,
ChunkManager,
TensorState,
)
from oslo.torch.nn.parallel.data_parallel.zero.heterogeneous_manager import (
HeterogeneousMemoryManager,
from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_manager import (
HeteroMemoryManager,
)
from oslo.torch.nn.parallel.data_parallel.zero.memory_tracer.param_runtime_order import (
from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_tracer.param_runtime_order import (
OrderedParamGenerator,
)
from oslo.torch.nn.parallel.data_parallel.zero.utils import (
from oslo.torch.nn.parallel.data_parallel.zero.hetero.utils import (
get_current_device,
get_temp_total_chunk_on_cuda,
)
from oslo.torch.nn.parallel.data_parallel.zero.heterogeneous_hook import (
HeterogeneousZeROHook,
from oslo.torch.nn.parallel.data_parallel.zero.hetero.hook import (
HeteroHook,
)

from oslo.torch.nn.parallel.data_parallel.zero.memory_tracer import (
from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_tracer import (
MemStats,
)
from oslo.torch.nn.parallel.data_parallel.zero.chunk import (
from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk import (
init_chunk_manager,
)

@@ -70,22 +70,38 @@ def _cast_float(args, dtype: torch.dtype):
return args


class _FullyShardedDataParallel(_DistributedDataParallel):
"""Fully sharded data parallel.
Warning: Nested FullyShardedDataParallel is not supported now.
It is designed to be used with ChunkManager and HeterogeneousMemoryManager.
For more details, see the API reference of ``ChunkManager`` and ``HeterogeneousMemoryManager``.
class _HeteroDataParallel(_DistributedDataParallel):
"""Heterogeneous sharded data parallel.
Inspired by the PatrickStar system introduced in "PatrickStar: Parallel
Training of Pre-trained Models via Chunk-based Dynamic Memory Management"
by Jiarui Fang, Zilin Zhu, et al. from Tencent Inc:
- PatrickStar uses a CPU-GPU heterogeneous memory space to store model data,
organized in memory chunks.
- Chunks are dynamically distributed across the heterogeneous memory,
guided by runtime memory statistics from a warm-up iteration.
- This approach reduces CPU-GPU data transmission volume and optimizes
bandwidth utilization.
- In tandem with the Zero Redundancy Optimizer, PatrickStar can efficiently
scale to multiple GPUs across multiple nodes.
Note:
Nested HeteroDataParallel is not currently supported. It is designed to be
used with ChunkManager and HeteroMemoryManager. For more details,
see the API reference of ``ChunkManager`` and ``HeteroMemoryManager``.
Args:
module (torch.nn.Module): Module to apply ZeRO-DP.
device (torch.device): Device to place the module.
parallel_context (ParallelContext): process group object.
placement_policy (str): Placement policy for the chunks.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
Defaults to False.
force_outputs_fp32 (bool): If set to True, outputs will be fp32.
Otherwise, outputs will be fp16. Defaults to False.
search_range_mb (int): Search range for the chunk size. Defaults to 32.
hidden_dim (int): Hidden dimension for the chunk size search. Defaults to None.
hidden_dim (int): Hidden dimension for the chunk size search.
Defaults to None.
min_chunk_size_mb (int): Minimum chunk size in MB. Defaults to 32.
memstats (MemStats): Memory statistics. Defaults to None.
"""
@@ -111,11 +127,11 @@ def __init__(
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb,
)
self.heterogeneous_manager = HeterogeneousMemoryManager(
self.hetero_memory_manager = HeteroMemoryManager(
placement_policy, self.chunk_manager, memstats
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = HeterogeneousZeROHook(self.heterogeneous_manager)
self.param_op_hook = HeteroHook(self.hetero_memory_manager)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[torch.Tensor] = list()
self.overflow_counter = 0
@@ -126,9 +142,9 @@ def __init__(
self._cast_buffers()
self._logger = DistributedLogger.get_instance(__name__)

if self.heterogeneous_manager._premade_memstats_:
if self.hetero_memory_manager._premade_memstats_:
# build chunk in param runtime visited order.
param_order = self.heterogeneous_manager.memstats()._param_runtime_order
param_order = self.hetero_memory_manager.memstats()._param_runtime_order
else:
# build chunk in param initialized order.
# Note: in this way, it can not get filter unused params during runtime.
@@ -138,7 +154,7 @@ def __init__(

self._init_chunks(
param_order=param_order,
cpu_offload=self.heterogeneous_manager.policy_name != "cuda",
cpu_offload=self.hetero_memory_manager.policy_name != "cuda",
pin_memory=pin_memory,
)

@@ -163,20 +179,20 @@ def _post_forward(self):
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
assert self.chunk_manager.accessed_mem == 0
# reset all recorded attributes
self.heterogeneous_manager.reset_attributes()
self.hetero_memory_manager.reset_attributes()

def forward(self, *args, **kwargs):
# check whether we are in a inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert (
not self.heterogeneous_manager.need_warmup
or not self.heterogeneous_manager.is_warmup()
not self.hetero_memory_manager.need_warmup
or not self.hetero_memory_manager.is_warmup()
), "You should run a completed iteration as your warmup iter"

args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)

self.heterogeneous_manager.pre_iter(*args)
self.hetero_memory_manager.pre_iter(*args)
self.param_op_hook.pre_forward(self.fp16_params)
outputs = super().forward(*args, **kwargs)
self.param_op_hook.post_forward(self.fp16_params)
@@ -225,9 +241,9 @@ def _post_backward(self):
)
self._setup_grads_ptr()
self._logger.debug(
f"comp cuda demand time: {self.heterogeneous_manager._comp_cuda_demand_time}, layout time: {self.heterogeneous_manager._layout_time}, evict time: {self.heterogeneous_manager._evict_time}, CPU->CUDA vol: {self.heterogeneous_manager._h2d_volume}B, CUDA->CPU vol: {self.heterogeneous_manager._d2h_volume}"
f"comp cuda demand time: {self.hetero_memory_manager._comp_cuda_demand_time}, layout time: {self.hetero_memory_manager._layout_time}, evict time: {self.hetero_memory_manager._evict_time}, CPU->CUDA vol: {self.hetero_memory_manager._h2d_volume}B, CUDA->CPU vol: {self.hetero_memory_manager._d2h_volume}"
)
self.heterogeneous_manager.post_iter()
self.hetero_memory_manager.post_iter()

def grad_handle(self, p, grad):
self.param_op_hook.post_backward([p])
@@ -645,7 +661,7 @@ def _init_chunks(self, param_order, cpu_offload: bool, pin_memory: bool):

self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.heterogeneous_manager.default_device
self.grads_device[p] = self.hetero_memory_manager.default_device

self.chunk_manager.close_all_groups()

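
To tie the `_HeteroDataParallel` hunks above together, here is a hedged,
conceptual sketch of one training iteration. The comments mirror the methods
shown in this diff; the driver function itself, the `.loss` attribute, and the
optimizer-driven `backward()` call are assumptions for illustration, not APIs
confirmed by this PR.

```python
def run_iteration(hetero_model, hetero_optimizer, batch):
    # forward(): inputs are cast to fp16, HeteroMemoryManager.pre_iter() runs,
    # and HeteroHook.pre_op()/post_op() wrap each parameter op, marking the
    # needed tensors COMPUTE, adjusting the CPU/GPU chunk layout (guided by
    # warm-up memory statistics), and materializing the required chunks.
    outputs = hetero_model(**batch)
    loss = outputs.loss  # assumption: the wrapped module returns an object with .loss

    # backward: gradients flow through grad_handle()/the hook; _post_backward()
    # then releases chunks, sets up gradient pointers, logs layout/eviction
    # timings, and calls HeteroMemoryManager.post_iter().
    hetero_optimizer.backward(loss)  # assumption: loss scaling is driven by _HeteroOptimizer
    hetero_optimizer.step()
    hetero_optimizer.zero_grad()
    return loss
```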
@@ -19,11 +19,11 @@

import torch

from oslo.torch.nn.parallel.data_parallel.zero.chunk import (
from oslo.torch.nn.parallel.data_parallel.zero.hetero.chunk import (
TensorState,
)
from oslo.torch.nn.parallel.data_parallel.zero.heterogeneous_manager import (
HeterogeneousMemoryManager,
from oslo.torch.nn.parallel.data_parallel.zero.hetero.memory_manager import (
HeteroMemoryManager,
)
from oslo.torch.nn.parallel.data_parallel._utils import is_ddp_ignored

@@ -33,25 +33,25 @@ class TrainingPhase(Enum):
BACKWARD = 1


class HeterogeneousZeROHook:
def __init__(self, heterogeneous_manager: HeterogeneousMemoryManager) -> None:
class HeteroHook:
def __init__(self, hetero_memory_manager: HeteroMemoryManager) -> None:
super().__init__()
self._heterogeneous_manager = heterogeneous_manager
self._chunk_manager = heterogeneous_manager.chunk_manager
self._hetero_memory_manager = hetero_memory_manager
self._chunk_manager = hetero_memory_manager.chunk_manager
self._training_phase = TrainingPhase.FORWARD

def pre_op(self, params):
params = [p for p in params if not is_ddp_ignored(p)]
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._heterogeneous_manager.sample_overall_data()
self._heterogeneous_manager.adjust_layout(chunks)
self._hetero_memory_manager.sample_overall_data()
self._hetero_memory_manager.adjust_layout(chunks)
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)

# record cuda model data of the current OP
self._heterogeneous_manager.record_model_data_volume()
self._hetero_memory_manager.record_model_data_volume()

def post_op(self, params):
params = [p for p in params if not is_ddp_ignored(p)]
(Diffs for the remaining changed files are not shown.)
