
Commit

up
Signed-off-by: Ubuntu <[email protected]>
Ubuntu committed Sep 17, 2024
1 parent f32727d commit 9e13938
Showing 5 changed files with 47 additions and 13 deletions.
36 changes: 25 additions & 11 deletions python/ray/dag/compiled_dag_node.py
@@ -109,8 +109,14 @@ def do_exec_tasks(
"""
try:
import cupy as cp
import torch
from ray.air._internal import torch_utils

self.exec_stream = cp.cuda.Stream(null=False, non_blocking=True)
device = torch_utils.get_devices()[0]
exec_stream = torch.cuda.Stream()
self.exec_stream = cp.cuda.ExternalStream(
exec_stream.cuda_stream, device_id=device.index
)

for task in tasks:
task.prepare()
@@ -445,6 +451,7 @@ def set_intermediate_buffer(self, op: _DAGNodeOperation, data: Any):
else:
assert False, f"Invalid operation type: {op.type}"
self._intermediate_buffer[key] = data
# logger.info(f"Setting {key} to {data}", stack_info=True)

def reset_intermediate_buffer(self, op: _DAGNodeOperation) -> Any:
"""
@@ -454,6 +461,7 @@ def reset_intermediate_buffer(self, op: _DAGNodeOperation) -> Any:
Returns:
The intermediate result of a READ or COMPUTE operation.
"""
logger.info(f"{self._intermediate_buffer=}")
return self._intermediate_buffer.pop(op)

def _read(self, op: _DAGNodeOperation) -> bool:
@@ -463,7 +471,7 @@ def _read(self, op: _DAGNodeOperation) -> bool:
Returns:
True if system error occurs and exit the loop; otherwise, False.
"""
assert self._intermediate_buffer is None
# assert self._intermediate_buffer is None
exit = False
try:
input_data = self.input_reader.read()
@@ -501,9 +509,13 @@ def _compute(self, op: _DAGNodeOperation, class_handle) -> bool:
return False

channel_results = []
for channel_result, event in input_data:
if event:
event.synchronize()
for entry in input_data:
if isinstance(entry, tuple):
channel_result, event = entry
if event:
event.synchronize()
else:
channel_result = entry
channel_results.append(channel_result)

resolved_inputs = []
@@ -513,12 +525,14 @@ def _compute(self, op: _DAGNodeOperation, class_handle) -> bool:
import cupy as cp

exec_event = cp.cuda.Event()
with self.exec_stream:
try:
output_val = method(*resolved_inputs, **self.resolved_kwargs)
except Exception as exc:
output_val = _wrap_exception(exc)
exec_event.record()
# TODO: run on exec_stream
# with self.exec_stream:
logger.info(f"{resolved_inputs=}")
try:
output_val = method(*resolved_inputs, **self.resolved_kwargs)
except Exception as exc:
output_val = _wrap_exception(exc)
exec_event.record()

self.set_intermediate_buffer(op, (output_val, exec_event))
return False
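Note on the change above: the execution stream is no longer a plain CuPy stream. A dedicated torch.cuda.Stream is created and its raw CUDA stream is wrapped in cp.cuda.ExternalStream, so PyTorch and CuPy submit work to the same underlying stream, and a cp.cuda.Event is recorded after compute so downstream operations can wait on it. A minimal standalone sketch of that pattern (illustrative only, not the commit's code; assumes torch and cupy on a CUDA-capable machine):

# Sketch: share one CUDA stream between PyTorch and CuPy and order work
# across streams with an event.
import cupy as cp
import torch

device_index = torch.cuda.current_device()

# Dedicated PyTorch stream, exposed to CuPy as an ExternalStream so both
# libraries enqueue kernels onto the same underlying CUDA stream.
torch_stream = torch.cuda.Stream()
exec_stream = cp.cuda.ExternalStream(
    torch_stream.cuda_stream, device_id=device_index
)

exec_event = cp.cuda.Event()
with torch.cuda.stream(torch_stream):
    out = torch.ones(4, device=f"cuda:{device_index}") * 2  # compute on the shared stream
# Record the event on the same underlying stream; anything that waits on the
# event is ordered after the compute above, without a full device sync.
exec_event.record(exec_stream)
exec_event.synchronize()  # host-side wait only for this piece of work
print(out)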
10 changes: 10 additions & 0 deletions python/ray/dag/dag_node_operation.py
@@ -54,6 +54,16 @@ def __repr__(self):
def __str__(self):
return f"([{self.exec_task_idx}] {self.func_name} {self.type})"

def __hash__(self):
return hash(str(self))

def __eq__(self, other):
return (
self.exec_task_idx == other.exec_task_idx
and self.type == other.type
and self.func_name == other.func_name
)


@total_ordering
class _DAGOperationGraphNode:
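Note on the change above: _DAGNodeOperation gains __hash__ and __eq__, presumably so operations can serve as dictionary keys (the intermediate buffer in compiled_dag_node.py is now indexed by the operation, and defining __eq__ without __hash__ would leave instances unhashable). A small illustrative sketch of the same idea (not the commit's code; the commit hashes the string form, while this hashes the same fields that __eq__ compares):

# Sketch: value-based equality plus a consistent hash lets objects act as
# dict keys.
class Op:
    def __init__(self, exec_task_idx, op_type, func_name):
        self.exec_task_idx = exec_task_idx
        self.type = op_type
        self.func_name = func_name

    def __eq__(self, other):
        return (
            self.exec_task_idx == other.exec_task_idx
            and self.type == other.type
            and self.func_name == other.func_name
        )

    def __hash__(self):
        # Hash the same fields __eq__ compares so equal ops hash alike.
        return hash((self.exec_task_idx, self.type, self.func_name))

buffer = {}
buffer[Op(0, "COMPUTE", "recv")] = "result"
assert buffer[Op(0, "COMPUTE", "recv")] == "result"  # an equal key finds the same entry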
1 change: 1 addition & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -68,6 +68,7 @@ def send_with_tuple_args(self, args):
return torch.ones(shape, dtype=dtype, device=self.device) * value

def recv(self, tensor):
print(f"{tensor=}")
# Check that tensor got loaded to the correct device.
assert tensor.device == self.device
return (tensor[0].item(), tensor.shape, tensor.dtype)
12 changes: 10 additions & 2 deletions python/ray/experimental/channel/nccl_group.py
@@ -95,6 +95,7 @@ def __init__(
assert rank is not None, "NCCL actor has no rank assigned"

import cupy as cp
import torch

from ray.air._internal import torch_utils

@@ -104,8 +105,15 @@ def __init__(
cuda_stream, device_id=device.index
)

self._send_stream = cp.cuda.Stream(null=False, non_blocking=True)
self._recv_stream = cp.cuda.Stream(null=False, non_blocking=True)
send_stream = torch.cuda.Stream()
recv_stream = torch.cuda.Stream()
self._send_stream = cp.cuda.ExternalStream(
send_stream.cuda_stream, device_id=device.index
)
self._recv_stream = cp.cuda.ExternalStream(
recv_stream.cuda_stream, device_id=device.index
)
print("Inited send and recv streams")

self._closed = False

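Note on the change above: the NCCL group's send and receive streams are likewise backed by dedicated torch.cuda.Stream objects wrapped as CuPy ExternalStreams. With separate streams per direction, communication can be ordered against compute through events rather than host-side synchronization. A minimal CuPy-only sketch of that ordering pattern (illustrative only, not the commit's code; assumes cupy on a CUDA-capable machine):

# Sketch: make one stream wait on another via an event, without blocking
# the host.
import cupy as cp

compute_stream = cp.cuda.Stream(non_blocking=True)
send_stream = cp.cuda.Stream(non_blocking=True)

done = cp.cuda.Event()
with compute_stream:
    x = cp.ones((1024,)) * 2  # producer work enqueued on the compute stream
    done.record()             # marks the point where the producer finished

# The send stream starts its work only after `done` completes, but the host
# is never blocked; both streams keep running asynchronously.
send_stream.wait_event(done)
with send_stream:
    y = x + 1                 # stands in for the NCCL send of `x`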
@@ -446,6 +446,7 @@ def _do_init_nccl_group(
), "Actors participating in NCCL group must have at least one GPU assigned"

ctx = ChannelContext.get_current()
print(f"_do_init_nccl_group() with {custom_nccl_group=}")
if custom_nccl_group is not None:
custom_nccl_group.initialize(rank)
ctx.nccl_groups[group_id] = custom_nccl_group