Distributed training on multi-host v4/v5 TPU Pods is too slow #8020
Comments
We never tried training Gemma 2, and I vaguely remember it uses sliding window attention, so I am not sure if that's the bottleneck. Maybe try running with
Hi @JackCaoG, I'm working with Hanna and I've collected some information that will hopefully help troubleshoot the issue. It's quite a bit, so I will start from what you asked, i.e. running with what you suggested.

For starters, I'm using a TPU v4-32 pod. The environment was set up with a script like this:

```bash
#!/bin/bash
source ./hf_env/bin/activate
# Step 1: install torch, torch-xla, libtpu
pip install torch~=2.4.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch_xla[tpu]~=2.4.0 -f https://storage.googleapis.com/libtpu-releases/index.html
# Step 2: install HF
pip3 install datasets accelerate evaluate scikit-learn
git clone https://github.com/huggingface/transformers -b v4.44.2 transformers-4.44.2
cd transformers-4.44.2
pip install -e .
```

So I'm using the stock `run_clm.py` example, with the following patch:

```diff
diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index 15ab7c4c4..18ef59831 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -53,6 +53,12 @@ from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
+import torch_xla
+import torch_xla.debug.profiler as xp
+import torch_xla.runtime as xr
+
+xr.use_spmd()
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.44.0")
@@ -295,6 +301,9 @@ def main():
# Set seed before initializing model.
set_seed(training_args.seed)
+ server = xp.start_server(9012)
+ logger.info(f'Profiling server started: {str(server)}')
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
```

and the following patch to `trainer.py`:

```diff
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 68ba7babf..7cd948e5a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -34,6 +34,7 @@ import time
import warnings
from collections.abc import Mapping
from pathlib import Path
+from threading import Thread
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -188,6 +189,7 @@ if is_torch_xla_available():
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla import __version__ as XLA_VERSION
+ import torch_xla.debug.profiler as xp
IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
if IS_XLA_FSDPV2_POST_2_2:
@@ -702,6 +704,7 @@ class Trainer:
# Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
# Tensor axis is just a placeholder where it will not be used in FSDPv2.
num_devices = xr.global_runtime_device_count()
+ logger.info(f"Global device count: {num_devices}")
xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
def _activate_neftune(self, model):
@@ -2206,6 +2209,7 @@ class Trainer:
total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
+ epoch_start_time = time.time()
epoch_iterator = train_dataloader
if hasattr(epoch_iterator, "set_epoch"):
epoch_iterator.set_epoch(epoch)
@@ -2233,6 +2237,11 @@ class Trainer:
rng_to_sync = True
step = -1
+ profile_step = int(os.environ.get('PROFILE_STEP', -1))
+ profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
+ profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
+ profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
+ logger.info("Profiling will start at step {} and epoch {}".format(profile_step, profile_epoch))
for step, inputs in enumerate(epoch_iterator):
total_batched_samples += 1
@@ -2274,9 +2283,14 @@ class Trainer:
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ if step == profile_step and epoch == profile_epoch:
+ trace = lambda: xp.trace('127.0.0.1:9012', profile_logdir, profile_duration)
+ Thread(target=trace).start()
with self.accelerator.accumulate(model):
- tr_loss_step = self.training_step(model, inputs)
+ with xp.StepTrace("Training_step", step_num=step):
+ tr_loss_step = self.training_step(model, inputs)
if (
args.logging_nan_inf_filter
@@ -2384,6 +2398,14 @@ class Trainer:
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected."
)
+ epoch_speed_metrics = speed_metrics(
+ "epoch",
+ epoch_start_time,
+ num_samples=num_train_samples / args.num_train_epochs,
+ num_tokens=num_train_tokens / args.num_train_epochs if num_train_tokens is not None else None,
+ )
+ self.log(epoch_speed_metrics)
+
if self.control.should_training_stop:
break
```

Then I used this script to launch the training:

```bash
#!/bin/bash
. ~/hf_env/bin/activate
export PJRT_DEVICE=TPU
export PROFILE_STEP=20
export PROFILE_EPOCH=0
export PROFILE_DURATION_MS=300000
export PROFILE_LOGDIR=/home/giuliano/xla-exp/profiles-run-clm
export PT_XLA_DEBUG=1
export XLA_FLAGS="--xla_dump_to=/home/giuliano/xla-exp/dumps"
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
# Run
cd ~/transformers-4.44.2
examples/pytorch/language-modeling/run_clm.py \
--log_level info \
--debug tpu_metrics_debug \
--include_tokens_per_second \
--dataloader_drop_last \
--token MY_TOKEN \
--model_name_or_path google/gemma-2-9b \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size 64 \
--num_train_epochs 4 \
--block_size 512 \
--optim adamw_torch \
--do_train \
--save_strategy no \
--logging_strategy no \
--fsdp full_shard \
--fsdp_config fsdp_config.json \
--output_dir /home/giuliano/xla-exp \
--overwrite_output_dir
```

where `fsdp_config.json` is:

```json
{
"fsdp_transformer_layer_cls_to_wrap": [
"Gemma2DecoderLayer"
],
"min_num_params": 0,
"xla": true,
"xla_fsdp_v2": true,
"xla_fsdp_grad_ckpt": true
}
```

The training was launched on all the workers, as is typical for a TPU pod:

```bash
gcloud compute tpus tpu-vm ssh $TPU_NAME --worker=all --zone=$TPU_ZONE --command="bash transformers-4.44.2/train_gemma2.sh"
```

Unfortunately, that goes OOM by a few MBs; here is the full output:
Now, the OOM is one problem (after replicating the same training loop without all the bells and whistles of the HF Trainer I am able to make it work, but I'll talk about that later), but the part that I want to highlight is this:
Compilation is extremely slow for this model. You can find the dumps attached. Hope you can give us a hand. I'll follow up with some more info about profiling and performance issues later. Let me know if you have any questions and/or if you need more info to help troubleshoot the problem.
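As a side note, the trace captured via `xp.trace` into `PROFILE_LOGDIR` can be inspected with TensorBoard and its profile plugin; a minimal sketch (the log directory is just the path from the launch script above):

```bash
pip install tensorboard tensorboard-plugin-profile
tensorboard --logdir /home/giuliano/xla-exp/profiles-run-clm --port 6006
# Open http://localhost:6006 and switch to the "Profile" tab.
```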
@JackCaoG, one more question regarding the OOM: is there a way to somehow get the list of allocated tensors when the OOM occurs? I could only find this.
The slow compilation is kind of expected. It will be better after @tengyifei lands the layer scan, which will make compilation for LLMs much faster. In the meantime, try the persistent cache: https://github.com/pytorch/xla/blob/master/API_GUIDE.md#compilation-caching. If you enable this, the same graph only needs to be compiled once (unless you update the torch_xla version or the model config).
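For reference, a minimal sketch of enabling the persistent cache (assuming torch_xla 2.4+, where `xr.initialize_cache` is available; the cache directory below is arbitrary):

```python
import torch_xla.runtime as xr

# Call this early, before any graph is traced or compiled.
# Compiled programs are serialized into this directory and reused on later runs,
# as long as the graph, torch_xla version and model config are unchanged.
xr.initialize_cache('/tmp/xla_compile_cache', readonly=False)
```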
OK, I see. In fact I was already using compilation caching, but I disabled it to run with IR debug enabled - I'll try that again with caching enabled, thanks. In the meantime, do you have any input regarding what I asked above about memory profiling?
Closest thing might be the
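One API that does exist for checking device memory usage at runtime is `xm.get_memory_info`; a small sketch (this is a general pointer, not necessarily the tool referred to above):

```python
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Returns a dict of device memory statistics (used / total); the exact keys
# vary between torch_xla versions, so print the whole dict.
print(xm.get_memory_info(device))
```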
❓ Questions and Help
Hi!
We are trying to train Gemma-2-9B on v4-64 and v5-128 Pods, as mentioned in this comment. We use an FSDP+SPMD setup on torch_xla 2.4.0 and transformers 4.44.2. We have tested our own code with the default HuggingFace Trainer and an xmp.spawn launch, and also the HuggingFace example run_clm.py following this GCP documentation - everything results in very slow training.
We observe that the batch size per device is equal to the total batch size:
[screenshot of the training log output]
and it is therefore marked as False in the output above. Is the total train batch size supposed to be equal to the batch size per device in this case? We also see `train_samples_per_second` being around 0.606. Does this mean that the training is not parallel/distributed? We can't seem to figure out what could cause such slow performance on TPUs - we use a default example and it is not working for us. We would be thankful for any help here!
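Worth noting: under SPMD the Trainer runs as a single process, so a reported total batch size that equals the per-device value is not by itself a sign that only one chip is being used. A quick way to confirm that the runtime actually sees every chip in the slice is to check the global device count; a minimal sketch, assuming torch_xla 2.4 with SPMD enabled:

```python
import torch_xla.runtime as xr

xr.use_spmd()

# On a v4-64 slice this should typically report 32 devices (one per chip);
# if it reports fewer, the hosts are not forming a single mesh.
print("global runtime devices:", xr.global_runtime_device_count())

# Under SPMD the process world size is typically 1, which is why the Trainer
# reports the same number for per-device and total batch size.
print("world size:", xr.world_size())
```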