Skip to content

Commit

Permalink
fix use_exllama logic error
Browse files — browse the repository at this point in the history
Signed-off-by: Chih-Chieh-Yang <[email protected]>
  • Loading branch information
cyang49 committed Mar 25, 2024
1 parent c3efc60 commit 58903e5
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions server/text_generation_server/utils/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,13 @@ def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, nosha
if quantize == "gptq":
bits, groupsize = self._get_gptq_params()

from text_generation_server.utils.layers import HAS_GPTQ_CUDA
from text_generation_server.utils.layers import HAS_GPTQ_CUDA, IS_TP_AWARE_GPTQ
is_preshuffle = (row_perm != None)
is_masked_matmul = noshard
assert (is_preshuffle != is_masked_matmul
or not (is_preshuffle or is_masked_matmul)), f"TP-aware optimization can't both be enabled at the same time {is_preshuffle=}, {is_masked_matmul=}"
use_gptq_cuda = (bits == 4) and HAS_GPTQ_CUDA or (is_preshuffle or is_masked_matmul)

use_exllama = (bits == 4) and HAS_GPTQ_CUDA and (IS_TP_AWARE_GPTQ and (is_preshuffle or is_masked_matmul))
if self.process_group.rank == 0:
if use_gptq_cuda:
logger.info(f"Using GPTQ cuda kernels for row {prefix}")
Expand Down

0 comments on commit 58903e5

Please sign in to comment.