
Commit 5a6a1f5

Enable dequant of fp8 weights quantized per-channel with the compressed-tensors method
Signed-off-by: mandy-li <[email protected]>
1 parent e38c8e9 commit 5a6a1f5

4 files changed: 35 additions, 0 deletions
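For orientation, here is a minimal sketch (illustrative only, not code from this commit) of what per-channel FP8 weight dequantization means: each output channel of a linear weight carries its own scale, so dequantization is just an upcast followed by an elementwise multiply with that per-channel scale. The tensor names and shapes below are assumptions made for the example.

import torch  # requires a PyTorch build with float8 support

out_features, in_features = 4, 8

# FP8 weight in the conventional [out_features, in_features] layout and one
# scale per output channel (per row), shape [out_features, 1].
weight_fp8 = torch.randn(out_features, in_features).to(torch.float8_e4m3fn)
weight_scale = torch.rand(out_features, 1)

# Dequantize: upcast to the scale dtype, multiply each row by its scale,
# then cast to bfloat16 for the compute path.
dequant_weight = (weight_fp8.to(weight_scale.dtype) * weight_scale).to(torch.bfloat16)
print(dequant_weight.shape)  # torch.Size([4, 8])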

tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 10 additions & 0 deletions
@@ -162,6 +162,15 @@ run_compressed_w4a16_moe_gidx_test() {
     echo "✅ Test with compressed w4a16 MoE with g_idx passed."
 }
 
+# Llama-3.3-70B-Instruct-FP8-dynamic + INC dynamic quant
+run_llama3_70b_inc_dynamic_quant_test() {
+    echo "➡️ Testing Llama-3.3-70B-Instruct-FP8-dynamic + inc dynamic quant in torch.compile mode ..."
+    QUANT_CONFIG="${VLLM_GAUDI_PREFIX}/tests/models/language/generation/inc_maxabs_dynamic_quant.json" \
+    HABANA_VISIBLE_DEVICES=all RUNTIME_SCALE_PATCHING=0 VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=0 \
+    python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic --max-model-len 2048
+    echo "✅ Test with Llama-3.3-70B-Instruct-FP8-dynamic + inc dynamic quant in torch.compile mode passed."
+}
+
 # GSM8K on granite-8b
 run_gsm8k_granite_test() {
     echo "➡️ Testing GSM8K on granite-8b..."
@@ -304,6 +313,7 @@ launch_all_tests() {
     run_spec_decode_ngram_test
     run_spec_decode_eagle3_test
     run_spec_decode_eagle3_num_spec_2_test
+    run_llama3_70b_inc_dynamic_quant_test
     #run_embedding_model_test
     echo "🎉 All test suites passed successfully!"
 }
tests/models/language/generation/inc_maxabs_dynamic_quant.json

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+{
+    "mode": "QUANTIZE",
+    "observer": "maxabs",
+    "scale_method": "ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW",
+    "dynamic_quantization": "True",
+    "scale_format": "CONST",
+    "dump_stats_path": ""
+}

vllm_gaudi/extension/ops.py

Lines changed: 6 additions & 0 deletions
@@ -767,6 +767,12 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def bind_dequant_func(layer):
+    # For INC path, we attach the dequant func to the layer
+    layer.get_dequant_weights_func = types.MethodType(get_dequant_weights_func, layer)
+    return layer
+
+
 def fp8_block_linear_postprocess_weights(layer, force_channel_fp8=False):
     weight, orig_M, orig_N = pad_block_fp8_weight_naive(layer.weight.data, layer.weight_scale_inv.data,
                                                         layer.quant_config.weight_block_size)
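The new bind_dequant_func helper uses types.MethodType to attach a function to an existing layer object as a bound method, so the INC path can later call layer.get_dequant_weights_func(). A minimal, self-contained sketch of that binding pattern (the real get_dequant_weights_func lives elsewhere in ops.py and is not shown in this diff; the class and return value below are placeholders):

import types

class DummyLayer:
    pass

def get_dequant_weights_func(self):
    # In the real module this returns the scheme-specific dequant routine;
    # here it only demonstrates that the function is bound to the instance.
    return f"bound to {type(self).__name__}"

layer = DummyLayer()
# Attach the free function to this particular instance as a bound method.
layer.get_dequant_weights_func = types.MethodType(get_dequant_weights_func, layer)
print(layer.get_dequant_weights_func())  # bound to DummyLayer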

vllm_gaudi/ops/hpu_compressed_tensors.py

Lines changed: 11 additions & 0 deletions
@@ -88,6 +88,13 @@ def get_hpu_scheme(self, layer: torch.nn.Module):
             raise ValueError(f"{scheme_classname} compressed format is not supported on HPU")
         return hpu_scheme
 
+    def dequant_fp8_weight(self, layer: torch.nn.Module) -> torch.Tensor:
+        if layer.scheme.strategy == QuantizationStrategy.CHANNEL:  # weights were quantized per-channel
+            dequant_weight = layer.weight.to(layer.weight_scale.dtype) * layer.weight_scale.squeeze()
+            return dequant_weight.to(torch.bfloat16).t()
+        else:
+            raise NotImplementedError("Implemented per-channel dequantization only")
+
 
 @CustomOp.register_oot(name='CompressedTensorsW8A16Fp8')
 class HPUCompressedTensorsW8A8Fp8(CompressedTensorsScheme):
@@ -115,6 +122,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # required by torch.compile to be torch.nn.Parameter
         layer.input_scale = torch.nn.Parameter(layer.input_scale.data, requires_grad=False)
 
+        # bind dequant function to layer for per-channel quantization
+        if layer.scheme.strategy == QuantizationStrategy.CHANNEL:
+            hpu_ops.bind_dequant_func(layer)
+
     def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int],
                        input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs):
         """
