Performance Optimizations for TP-Aware GPTQ #67

Draft
wants to merge 5 commits into main
Conversation

@cyang49 (Contributor) commented Mar 22, 2024

This is a draft. Please do not merge.

Motivation

The current tgis-native stack provides GPTQ support for llama and starcoder models by utilizing the fast exllamav2 kernel (and also Marlin, once #66 is merged). This works well in single-GPU deployments. However, in multi-GPU TP deployments, performance is known to be poor for GPTQ checkpoints that require activation reordering (desc_act=True in the quantization config), which includes many publicly available GPTQ checkpoints.

The reason for the poor performance is that the fast exllamav2 (or Marlin) kernels cannot be used in row-parallel layers: their weight-matrix row shuffling would require an extra all-gather to globally reorder the input activations of row-parallel layers under TP, and that all-gather can be prohibitively expensive. As a result, TGIS falls back to the much slower Triton matmul_248 kernel, which does not require shuffling, for those layers. This mix of CUDA and Triton kernels across the QuantLinear layers works, but it is too slow to be a practical solution. vLLM uses a similar approach, except with an alternative GPTQ CUDA kernel instead of the Triton kernel, and it still suffers from suboptimal performance.
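To make the constraint concrete, the sketch below (PyTorch, with toy shapes and an arbitrary permutation standing in for the desc_act reordering; not code from this PR) shows why a reordered weight shard on one TP rank ends up needing input features that live on other ranks, which is what forces the all-gather:

```python
# Illustrative sketch only: toy shapes and a made-up permutation in place of the
# real desc_act activation reordering.
import torch

tp_size, in_features = 2, 8
perm = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7])  # hypothetical activation-reordering permutation
shard = in_features // tp_size

for rank in range(tp_size):
    owned = list(range(rank * shard, (rank + 1) * shard))             # input features this rank holds locally
    needed = sorted(perm[rank * shard:(rank + 1) * shard].tolist())   # features its reordered weight shard consumes
    print(f"rank {rank}: owns {owned}, needs {needed}")
# rank 0: owns [0, 1, 2, 3], needs [0, 2, 4, 6]
# rank 1: owns [4, 5, 6, 7], needs [1, 3, 5, 7]
# Each rank needs features held by the other rank, hence the all-gather.
```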

In this PR, we implement TP-aware GPTQ inference optimizations. These include the MLP-layer technique introduced in the arXiv paper we published previously, combined with a newer technique, masked matmul, for optimizing the attention layers.
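The MLP part of the idea can be illustrated with plain dense matrices: because the reordering is a fixed permutation, it can be folded offline into the output columns of the preceding column-parallel weight and the rows of the row-parallel weight, so no runtime activation reorder (and no all-gather) is needed. The sketch below is a minimal numeric check with hypothetical shapes, ignoring quantization and the elementwise activation between the two layers (which commutes with a permutation):

```python
# Minimal sketch (PyTorch, hypothetical shapes) of the communication-avoiding idea for the MLP:
# fold the row-parallel layer's input permutation into the weights offline so that
# (x @ W_up[:, perm]) @ W_down[perm, :] == (x @ W_up) @ W_down with no runtime reorder.
import torch

torch.manual_seed(0)
x = torch.randn(1, 16)
W_up = torch.randn(16, 32)      # stands in for a column-parallel layer (e.g. up_proj)
W_down = torch.randn(32, 16)    # stands in for a row-parallel layer (e.g. down_proj)
perm = torch.randperm(32)       # stands in for the desc_act reordering of W_down's rows

ref = (x @ W_up) @ W_down
opt = (x @ W_up[:, perm]) @ W_down[perm, :]   # permutation folded into the weights offline
print(torch.allclose(ref, opt, atol=1e-5))    # True
```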

Preliminary results using exllamav2 show that our techniques enable deploying Llama-70b GPTQ on L40Sx2 at 24.67 tokens/s, a ~30% throughput improvement over deploying the FP16 model on A100-80GBx2 (19 tokens/s), providing a cost-saving alternative for deploying llama-70b. We expect even better results with Marlin.

Modifications

The code changes are primarily control-path adjustments that manipulate how weight tensors are loaded, plus environment-variable flags to toggle the different modes.
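As a rough illustration of what such a toggle could look like (the flag name below is hypothetical, not the actual variable used in this PR):

```python
# Hypothetical sketch of an environment-variable toggle for the kernel path;
# the flag name TP_AWARE_GPTQ is made up for illustration.
import os

TP_AWARE_GPTQ = os.getenv("TP_AWARE_GPTQ", "0") == "1"

def row_parallel_gptq_kernel() -> str:
    # With the TP-aware weight layout, the fast exllamav2 kernel can also serve
    # row-parallel layers; otherwise fall back to the Triton matmul_248 kernel.
    return "exllamav2" if TP_AWARE_GPTQ else "triton_matmul_248"
```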

Known issues:

  • The weight shuffling can slow down model loading significantly; the pack/unpack functions should be moved to the GPU (see the sketch after this list).
  • The code should be thoroughly tested for unsupported cases, as the control path is heavily modified.
  • I welcome suggestions to make the control path modifications cleaner.
  • Santacoder support is not implemented yet.
  • The desc_act=False path may not have been sufficiently tested.
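For the pack/unpack point above, a GPU-side unpack could look roughly like the following (a sketch assuming the standard GPTQ 4-bit packing of eight values per int32; the function name and shapes are illustrative, not this PR's code):

```python
# Sketch of unpacking 4-bit GPTQ weights on the GPU with tensor bit ops instead of a CPU loop.
import torch

def unpack_4bit(qweight: torch.Tensor) -> torch.Tensor:
    """qweight: (in_features // 8, out_features) int32 -> (in_features, out_features) uint8."""
    shifts = torch.arange(0, 32, 4, device=qweight.device, dtype=torch.int32)  # 8 nibbles per int32
    # Broadcast (rows, 1, cols) >> (1, 8, 1) -> (rows, 8, cols), mask each nibble, then flatten rows.
    unpacked = (qweight.unsqueeze(1) >> shifts.view(1, 8, 1)) & 0xF
    return unpacked.reshape(-1, qweight.shape[1]).to(torch.uint8)

# Usage: move qweight to the GPU first, e.g. unpack_4bit(qweight.cuda()), so the
# shuffling at load time runs entirely on the device.
```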

Result

| Configuration | Prefill | Token latency | Throughput |
| --- | --- | --- | --- |
| FP16: L40Sx4 | 1.96s | 62.33ms | 16.04 tokens/s |
| GPTQ, TP-aware: L40Sx2 | 2.11s | 40.55ms | 24.67 tokens/s |
| GPTQ, original: L40Sx2 | 3.48s | 84.21ms | 11.88 tokens/s |
  • GPTQ, TP-aware: uses the communication-avoiding techniques and exllamav2 for dequantization + GEMM
  • GPTQ, original: uses exllamav2 for column-parallel layers and the Triton kernel for row-parallel layers (also avoids the all-gather)

We plan to update the results when the Marlin PR is merged.

Related Issues

#66 needs to be merged to enable Marlin support.

Signed-off-by: Chih-Chieh-Yang <[email protected]>