From 58903e5096250dc9c3e97ba54d31b58b0e8debe7 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Thu, 21 Mar 2024 17:28:03 +0000
Subject: [PATCH] fix use_exllama logic error

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
---
 server/text_generation_server/utils/weights.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py
index b4e1590e..4f2444e0 100644
--- a/server/text_generation_server/utils/weights.py
+++ b/server/text_generation_server/utils/weights.py
@@ -219,12 +219,13 @@ def get_multi_weights_row(self, prefix: str, quantize: str, row_perm=None, nosha
         if quantize == "gptq":
             bits, groupsize = self._get_gptq_params()
 
-            from text_generation_server.utils.layers import HAS_GPTQ_CUDA
+            from text_generation_server.utils.layers import HAS_GPTQ_CUDA, IS_TP_AWARE_GPTQ
 
             is_preshuffle = (row_perm != None)
             is_masked_matmul = noshard
             assert (is_preshuffle != is_masked_matmul or not (is_preshuffle or is_masked_matmul)), f"TP-aware optimization can't both be enabled at the same time {is_preshuffle=}, {is_masked_matmul=}"
-            use_gptq_cuda = (bits == 4) and HAS_GPTQ_CUDA or (is_preshuffle or is_masked_matmul)
+
+            use_gptq_cuda = (bits == 4) and HAS_GPTQ_CUDA and (IS_TP_AWARE_GPTQ and (is_preshuffle or is_masked_matmul))
             if self.process_group.rank == 0:
                 if use_gptq_cuda:
                     logger.info(f"Using GPTQ cuda kernels for row {prefix}")
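
The logic error being fixed is an operator-precedence bug: in Python, `and` binds tighter than `or`, so the old expression parsed as `((bits == 4) and HAS_GPTQ_CUDA) or (is_preshuffle or is_masked_matmul)` and could select the GPTQ CUDA kernel path even when `HAS_GPTQ_CUDA` was false. A minimal standalone sketch of the before/after behavior follows; the boolean values are illustrative stand-ins, not the server's actual runtime state.

    # Minimal sketch of the precedence bug fixed by this patch; the
    # values below are illustrative stand-ins, not real server state.
    bits = 4
    HAS_GPTQ_CUDA = False       # CUDA GPTQ kernels unavailable
    IS_TP_AWARE_GPTQ = True
    is_preshuffle = True        # a TP-aware optimization was requested
    is_masked_matmul = False

    # Old expression: `and` binds tighter than `or`, so this parses as
    # ((bits == 4) and HAS_GPTQ_CUDA) or (is_preshuffle or is_masked_matmul)
    old = (bits == 4) and HAS_GPTQ_CUDA or (is_preshuffle or is_masked_matmul)

    # Fixed expression: the kernel path now also requires HAS_GPTQ_CUDA.
    new = (bits == 4) and HAS_GPTQ_CUDA and (IS_TP_AWARE_GPTQ and (is_preshuffle or is_masked_matmul))

    print(old)  # True  -- would pick the CUDA kernel path with no kernels present
    print(new)  # False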