diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index d93a702723e..b100f9ffe0f 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -217,6 +217,7 @@ def _init_xla_lazy_backend():
 import torch._functorch.config
 
 torch._functorch.config.view_replay_for_aliased_outputs = True
+torch._functorch.config.skip_view_reconstruction = True
 
 import importlib.metadata
 import warnings
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index ec80934c320..b3afeb2df49 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -14,11 +14,12 @@
 import torch
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
 from torch.fx.passes.utils.fuser_utils import topo_sort
-
-import torch._inductor
-from torch._inductor.fx_passes.post_grad import ConstructorMoverPass
-
 from torch.utils import _pytree as pytree
+from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta, OutputAliasInfo, OutputType
+from torch._functorch._aot_autograd.runtime_wrappers import make_output_handler
+from torch._functorch._aot_autograd.utils import normalize_as_list
+from torch._guards import TracingContext
+from torch._inductor.fx_passes.post_grad import ConstructorMoverPass
 
 import torch_xla
 import torch_xla.core.xla_model as xm
@@ -30,6 +31,7 @@
 
 dynamo_debug = int(os.environ.get('XLA_DYNAMO_DEBUG', '0')) == 1
 ptxla_debug = int(os.environ.get('PT_XLA_DEBUG', '0')) == 1
+enable_skip_handler = int(os.environ.get('XLA_ENABLE_SKIP_HANDLER', '1')) == 1
 
 
 @contextmanager
@@ -710,11 +712,70 @@ def allow_cpu_device(self, node: torch.fx.Node):
     return (device is not None and device.type == self.target)
 
 
-def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
+class MaybeReconstructOutputs:
+
+  def __init__(self, graph: torch.fx.GraphModule, metadata: ViewAndMutationMeta) -> None:
+    self.graph = graph
+    self.metadata = metadata
+
+  def __call__(self, *args):
+    outputs = normalize_as_list(self.graph(*args))
+
+    def maybe_skip_handler(o: torch.Tensor, info: OutputAliasInfo) -> torch.Tensor:
+      handler = make_output_handler(info, self.metadata, trace_joint=False)
+      if (
+          is_xla_tensor(o)
+          and info.output_type in (
+              OutputType.alias_of_input,
+              OutputType.alias_of_intermediate,
+              OutputType.alias_of_intermediate_save_as_output,
+              OutputType.alias_of_intermediate_base_is_user_output,
+          )
+          and not torch.is_grad_enabled()
+          and info.functional_tensor is not None
+          and enable_skip_handler
+      ):
+        return torch_xla._XLAC._fresh_functional_tensor_from(
+            o,
+            handler.base(args, outputs),
+            info.functional_tensor.tensor,
+        )
+      else:
+        return handler(args, outputs, o)
+
+    # AOTAutograd outputs are composed of, in order:
+    # 1. Tokens
+    # 2. Mutated inputs
+    # 3. Actual outputs
+    # 4. Intermediate bases
+    #
+    # However, we only have output aliasing information, e.g. OutputAliasInfo, for (3).
+    # Therefore, we must run the handlers only for (3), and combine them appropriately.
+    num_tokens = len(self.metadata.tokens)
+    num_mutations_to_apply = self.metadata.num_mutated_inp_runtime_indices
+    num_intermediate_bases = self.metadata.num_intermediate_bases
+    num_prefix_outputs_without_handlers = num_tokens + num_mutations_to_apply
+
+    # Filter (1) and (2).
+    outputs_with_handlers = outputs[num_prefix_outputs_without_handlers:]
+    # Filter (4), and call the handlers on the outputs.
+    handled_outputs = [maybe_skip_handler(o, info) for o, info in zip(outputs_with_handlers, self.metadata.output_info)]
+
+    # Combine (1), (2) and the possibly modified (3).
+    combined_outputs = outputs[:num_prefix_outputs_without_handlers] + handled_outputs
+    # Combine the above with (4), if any.
+    if num_intermediate_bases > 0:
+      combined_outputs = combined_outputs + outputs[-num_intermediate_bases:]
+
+    return combined_outputs
+
+
+def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args, metadata: ViewAndMutationMeta):
   torch_xla._XLAC._xla_increment_counter('DynamoExtractCompiledGraph', 1)
 
   with torch_xla.experimental.eager_mode_context(False):
-    return extract_compiled_graph_helper(xla_model, xla_args)
+    g = extract_compiled_graph_helper(xla_model, xla_args)
+    return MaybeReconstructOutputs(g, metadata)
 
 
 def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 9b11eadf797..32c9496069b 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -9,6 +9,7 @@
 #include 
 #include 
 #include 
+#include <ATen/FunctionalTensorWrapper.h>
 #include 
 #include 
 
@@ -2673,6 +2674,22 @@ void InitXlaModuleBindings(py::module m) {
            "without a data handle or an IR.";
   });
 
+  m.def("_fresh_functional_tensor_from",
+        [](const at::Tensor& value, const at::Tensor& storage_source,
+           const at::Tensor& view_metas_source) {
+          TORCH_LAZY_COUNTER("fresh_functional_tensor_from", 1);
+          XLA_CHECK(at::functionalization::impl::isFunctionalTensor(value));
+          at::Tensor unwrapped = torch::lazy::maybe_unwrap_functional(value);
+          at::FunctionalTensorWrapper* storage_source_ =
+              at::functionalization::impl::unsafeGetFunctionalWrapper(
+                  storage_source);
+          at::FunctionalTensorWrapper* view_metas_source_ =
+              at::functionalization::impl::unsafeGetFunctionalWrapper(
+                  view_metas_source);
+          return at::detail::make_tensor<at::FunctionalTensorWrapper>(
+              unwrapped, storage_source_, view_metas_source_);
+        });
+
   // from an XLA tensor to a PyCapsule.
   // When consuming the PyCapsule, we should synchronize
   // (waits for all kernels in all streams on a CUDA device to complete) if
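
Reviewer notes (not part of the patch):

The output recombination in `MaybeReconstructOutputs.__call__` is easiest to see in isolation. Below is a minimal, runnable sketch of the same slicing arithmetic; `split_aot_outputs` and its integer parameters are illustrative stand-ins for the real `ViewAndMutationMeta` fields (`len(metadata.tokens)`, `metadata.num_mutated_inp_runtime_indices`, `metadata.num_intermediate_bases`), not APIs from this patch.

```python
# Illustrative sketch only: mirrors the slicing done in __call__ above.
def split_aot_outputs(outputs, num_tokens, num_mutated_inputs,
                      num_intermediate_bases):
  """Split the flat AOTAutograd output list into its four segments."""
  prefix = num_tokens + num_mutated_inputs
  head = outputs[:prefix]  # (1) tokens and (2) mutated inputs
  if num_intermediate_bases > 0:
    user_outputs = outputs[prefix:-num_intermediate_bases]  # (3)
    bases = outputs[-num_intermediate_bases:]  # (4)
  else:
    user_outputs = outputs[prefix:]
    bases = []
  return head, user_outputs, bases


# Example: 1 token, 2 mutated inputs, 3 user outputs, 1 intermediate base.
head, user, bases = split_aot_outputs(list(range(7)), 1, 2, 1)
assert head == [0, 1, 2] and user == [3, 4, 5] and bases == [6]
```

In the patch itself, step (4) is filtered implicitly rather than by slicing: `zip` stops at the shorter `metadata.output_info` list, so the trailing intermediate bases never reach a handler.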
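One usage note: the fast path is gated by `XLA_ENABLE_SKIP_HANDLER` (default on), and `dynamo_bridge` reads the flag once at module level, so it must be set before the module is imported. A sketch:

```python
import os

# Opt out of the skip-handler fast path and fall back to the regular
# AOTAutograd output handlers. Must happen before
# torch_xla.core.dynamo_bridge is imported, since the module reads the
# flag once at import time.
os.environ['XLA_ENABLE_SKIP_HANDLER'] = '0'

import torch_xla.core.dynamo_bridge  # noqa: E402
```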