Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/load model with torch dtype auto #663

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

mreso
Copy link
Contributor

@mreso mreso commented Sep 6, 2024

What does this PR do?

This PR loads a model with torch_dtype=auto instead of bfloat16 when we do not specify train_config.use_fp16.
For llama models this will not make a difference as their default dtype is bfloat16

Fixes # (issue)
#656 (kind of)

Feature/Issue validation/testing

Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.

  • torchrun --nnodes 1 --nproc_per_node 4 ./recipes/quickstart/finetuning/finetuning.py --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --enable_fsdp --max_train_step=2 --batch_size_training 1 --batching_strategy packing --dataset samsum_dataset --save_model False --context_length 4096 --fsdp_config.pure _bf16 True --fsdp_config.optimizer anyprecision --samsum_dataset.trust_remote_code 1
    Logs
W0906 15:44:05.090000 140566387028992 torch/distributed/run.py:779]
W0906 15:44:05.090000 140566387028992 torch/distributed/run.py:779] *****************************************
W0906 15:44:05.090000 140566387028992 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0906 15:44:05.090000 140566387028992 torch/distributed/run.py:779] *****************************************
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
/home/mreso/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead
  from torch.distributed._shard.checkpoint import (
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.03it/s]
--> Model meta-llama/Meta-Llama-3.1-8B-Instruct

--> meta-llama/Meta-Llama-3.1-8B-Instruct has 8030.261248 Million params

bFloat16 enabled for mixed precision - using bfSixteen policy
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.03it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.13it/s$
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  [0/14738]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.13it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00,  4.14it/s]
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
Preprocessing dataset:   9%|██████████████████████████▍                                                                                                                                                                                                                                                                            | 1318/14732 [00:00<00:04, 3194.86it/s]--> Training Set Length = 14732
Preprocessing dataset:  51%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                                                               | 7556/14732 [00:02<00:02, 3374.50it/s]--> Validation Set Length = 818
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3272.90it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3350.07it/s]
--> Num of Validation Set Batches loaded = 8
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3324.63it/s]
Preprocessing dataset:  86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                         | 12661/14732 [00:03<00:00, 3174.45it/s]/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3365.68it/s]
--> Num of Validation Set Batches loaded = 8
Preprocessing dataset:  59%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                                                        | 8707/14732 [00:02<00:01, 3237.27it/s]/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3149.61it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3361.27it/s]
--> Num of Validation Set Batches loaded = 8
Preprocessing dataset:  74%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                           | 10974/14732 [00:03<00:01, 3205.69it/s]/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14732/14732 [00:04<00:00, 3251.47it/s]
Preprocessing dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 818/818 [00:00<00:00, 3252.59it/s]
--> Num of Validation Set Batches loaded = 8
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                                                                                                                                                                                                                          | 0/159 [00:00<?, ?it/s]NCCL version 2.20.5+cuda12.4
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
Training Epoch: 1/3, step 1/159 completed (loss: 21.161834716796875):   1%|███▏                                                                                                                                                                                                                                                           | 2/159 [00:10<11:53,  4.55s/it]max training steps reached, stopping training, total train steps finished:  2
Training Epoch: 1/3, step 1/159 completed (loss: 21.161834716796875):   1%|███▏                                                                                                                                                                                                                                                           | 2/159 [00:10<13:33,  5.18s/it]
Training Epoch: 1/3, step 1/159 completed (loss: 21.4289608001709):   1%|███▏                                                                                                                                                                                                                                                             | 2/159 [00:10<14:09,  5.41s/it]
Training Epoch: 1/3, step 1/159 completed (loss: 21.154579162597656):   1%|███▏                                                                                                                                                                                                                                                           | 2/159 [00:09<12:36,  4.82s/it]
Training Epoch: 1/3, step 1/159 completed (loss: 21.540555953979492):   1%|███▏                                                                                                                                                                                                                                                           | 2/159 [00:07<10:26,  3.99s/it]
Max CUDA memory allocated was 20 GB
Max CUDA memory reserved was 29 GB
Peak active CUDA memory was 21 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 7 GB
evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.22it/s]
evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.22it/s]
evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.16it/s]
evaluating Epoch: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:03<00:00,  2.20it/s]
 eval_ppl=tensor(88679.3359, device='cuda:0') eval_epoch_loss=tensor(11.3928, device='cuda:0')
best eval loss on epoch 1 is 11.392782211303711
Epoch 1: train_perplexity=1.1542, train_epoch_loss=0.1434, epoch time 8.7118956502527s
training params are saved in /home/mreso/llama-recipes/PATH/to/save/FSDP/model/fine-tuned-meta-llama/Meta-Llama-3.1-8B-Instruct/train_params.yaml
Key: avg_train_prep, Value: 1.1542373895645142
Key: avg_train_loss, Value: 0.14343982934951782
Key: avg_eval_prep, Value: 88679.3359375
Key: avg_eval_loss, Value: 11.392782211303711
Key: avg_epoch_time, Value: 8.7118956502527
Key: avg_checkpoint_time, Value: 1.8780119717121124e-06

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue? Please add a link
    to it if that's the case.

Thanks for contributing 🎉!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants