
Support of AOT compilation (refine #6992) #7581

Open · wants to merge 6 commits into base: master

Conversation
Conversation

zpcore (Collaborator) commented on Jun 26, 2024

This is a follow-up PR to refine #6992.

In this PR, I created PjRtCompilationClient to serve ahead-of-time (AOT) compilation. This way, we no longer need to create CompileOnlyPjRtClient, CompileOnlyPjRtDevice, etc., which makes OpenXLA pin updates easier.

Instructions on how to run AOT compilation have been updated. Two extra flags need to be specified when running on a CPU device, PJRT_DEVICE (the compilation target) and XLA_PERSISTENT_CACHE_PATH (where the compiled program is cached), as follows:
----------------------- ON CPU -----------------------

PJRT_DEVICE=TPU   XLA_PERSISTENT_CACHE_PATH=./  python aot_encode.py

aot_encode.py:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla._XLAC._xla_set_virtual_topology("v4:2x2x1")  # <----- define virtual topology here. Must be specified before any device declaration.

a = torch.rand([2,3])
b = torch.rand([2,3])
device = xm.xla_device()
a = a.to(device)
b = b.to(device)
f = torch.hstack([a,b])

torch_xla._XLAC._xla_warm_up_cache([f],[])  # <----- call this to avoid real computation.

This will generate a cache file named with the graph hash, e.g. 229013763457648799216243727807636414712, which can be deserialized by running the same graph code on a TPU:
----------------------- ON TPU v4-8 -----------------------

PJRT_DEVICE=TPU XLA_PERSISTENT_CACHE_PATH=./  python aot_decode.py

aot_decode.py:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

a = torch.rand([2,3])
b = torch.rand([2,3])
device = xm.xla_device()
a = a.to(device)
b = b.to(device)
f = torch.hstack([a,b])

print(f)

zpcore added the AOT ahead of time cross device compilation label on Jun 26, 2024
zpcore marked this pull request as ready for review on Jun 26, 2024
zpcore changed the title from "refining the AOT compilation" to "Support of AOT compilation (refine #6992)" on Jun 26, 2024
torch_xla/csrc/runtime/pjrt_compile_only.h (outdated review thread, resolved)
@@ -12,6 +13,8 @@ namespace runtime {

std::atomic<bool> g_computation_client_initialized(false);

std::string aot_topology = "";
Collaborator commented:

Since this state is only ever used by PjRtCompilationClient, would it make sense to make this a static field on PjRtCompilationClient?

I think it's probably even better to just imperatively initialize this client when set_virtual_topology() happens... but that also would introduce some extra edge cases for initialization. For the POC, a global var is okay, but I think a static field makes more sense to make it clear what uses that state.
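
For illustration, here is a minimal, self-contained sketch of the static-field alternative the comment suggests. The names and signatures are hypothetical, not the PR's actual interface:

// Sketch only: the topology string lives on the client that consumes it,
// rather than being a file-level global in runtime.cc.
#include <string>
#include <utility>

class PjRtCompilationClient {
 public:
  // Record the virtual topology (e.g. "v4:2x2x1"), as set via
  // _xla_set_virtual_topology(); must happen before any device is created.
  static void SetVirtualTopology(std::string topology) {
    virtual_topology_ = std::move(topology);
  }
  static const std::string& virtual_topology() { return virtual_topology_; }

 private:
  static std::string virtual_topology_;
};

// Static data members need exactly one out-of-line definition.
std::string PjRtCompilationClient::virtual_topology_;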

zpcore (Collaborator, Author) replied:

Moving aot_topology to pjrt_compilation_client.h may cause a circular dependency issue, because setting aot_topology needs to check g_computation_client_initialized in the current setup.

Let's use the global var for now, since this makes it clear in runtime.cc that we have three kinds of clients altogether.
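
As a hedged sketch of that point, the selection among the three client kinds could look roughly like this (simplified, with stub types; the real logic in runtime.cc differs in detail):

#include <memory>
#include <string>

// Stub types standing in for the real client classes in torch_xla/csrc/runtime.
struct ComputationClient { virtual ~ComputationClient() = default; };
struct PjRtComputationClient : ComputationClient {};
struct IfrtComputationClient : ComputationClient {};
struct PjRtCompilationClient : ComputationClient {};

std::string aot_topology = "";  // set by _xla_set_virtual_topology()

// With the global var, all three client kinds are visible in one place:
// a non-empty aot_topology selects the compile-only (AOT) client.
std::unique_ptr<ComputationClient> CreateClient(bool use_ifrt) {
  if (!aot_topology.empty()) {
    return std::make_unique<PjRtCompilationClient>();
  }
  if (use_ifrt) {
    return std::make_unique<IfrtComputationClient>();
  }
  return std::make_unique<PjRtComputationClient>();
}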


// Builds a map from the device's global ordinal to its index in the `devices`
// array.
std::unordered_map<int, int> build_index_map(
Collaborator commented:

I see some duplication of helpers here and in the PJRT/IFRT client implementations. It would be cleaner to factor it out rather than have 3 copies, but both AOT and IFRT are just in a concept stage...

I'll let @JackCaoG have the final word on readability so I don't have to decide.
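
For what factoring it out might look like, here is a self-contained sketch of a shared helper (the location is hypothetical, e.g. a common util header used by the PJRT, IFRT, and AOT clients; Device is a stand-in for the real PJRT device type):

#include <unordered_map>
#include <vector>

// Stand-in for xla::PjRtDevice, which exposes a global ordinal.
struct Device {
  int global_ordinal;
};

// Builds a map from a device's global ordinal to its index in `devices`,
// shared by all client implementations instead of being copied three times.
std::unordered_map<int, int> build_index_map(
    const std::vector<Device*>& devices) {
  std::unordered_map<int, int> index;
  index.reserve(devices.size());
  for (int i = 0; i < static_cast<int>(devices.size()); ++i) {
    index[devices[i]->global_ordinal] = i;
  }
  return index;
}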

torch_xla/csrc/runtime/pjrt_compilation_client.cc (outdated review thread, resolved)