
Commit 8c441ad

Enable dequantizing fp8 weights quantized per-channel with the compressed-tensors method
Signed-off-by: mandy-li <[email protected]>
1 parent: f71c4ca

File tree

1 file changed: +10 −0 lines


vllm_gaudi/ops/hpu_compressed_tensors.py

Lines changed: 10 additions & 0 deletions
@@ -88,6 +88,16 @@ def get_hpu_scheme(self, layer: torch.nn.Module):
             raise ValueError(f"{scheme_classname} compressed format is not supported on HPU")
         return hpu_scheme
 
+    def dequant_fp8_weight(self, layer: torch.nn.Module) -> torch.Tensor:
+        if layer.scheme.strategy == QuantizationStrategy.CHANNEL.value:  # weights were quantized per-channel
+            dequant_weight = layer.weight.to(layer.weight_scale.dtype) * layer.weight_scale.squeeze()
+            return dequant_weight.to(torch.bfloat16).t()
+        else:
+            raise NotImplementedError("Implemented per-channel dequantization only")
+
+    def get_dequant_weights_func(self, ) -> Optional[Callable[[torch.nn.Module], torch.Tensor]]:
+        return self.dequant_fp8_weight
+
 
 @CustomOp.register_oot(name='CompressedTensorsW8A16Fp8')
 class HPUCompressedTensorsW8A8Fp8(CompressedTensorsScheme):
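For context, below is a minimal standalone sketch of the per-channel dequantization the new dequant_fp8_weight method performs. The tensor shapes, the assumption that the HPU path stores the weight as [in_features, out_features] (so that the trailing dimension indexes output channels and the final .t() restores [out_features, in_features]), and the use of torch.float8_e4m3fn are illustrative assumptions, not details taken from the commit or the layer objects in vllm_gaudi:

import torch

# Illustrative shapes (assumption): a linear layer with 8 input features and 4 output channels,
# with the weight laid out as [in_features, out_features] so each column is one output channel.
in_features, out_features = 8, 4

# Simulated FP8-quantized weight (the commit reads this from layer.weight).
weight_fp8 = torch.randn(in_features, out_features).to(torch.float8_e4m3fn)

# One scale per output channel, stored as a column vector [out_features, 1]
# (the commit reads this from layer.weight_scale).
weight_scale = torch.rand(out_features, 1, dtype=torch.float32) + 0.5

# Per-channel dequantization, mirroring the added method: upcast the FP8 weight to the
# scale's dtype, then multiply by the squeezed scale, which broadcasts one scale value
# across each output-channel column.
dequant = weight_fp8.to(weight_scale.dtype) * weight_scale.squeeze()

# Cast to bfloat16 and transpose, as dequant_fp8_weight does before returning.
dequant_bf16 = dequant.to(torch.bfloat16).t()

print(dequant_bf16.shape)  # torch.Size([4, 8]) -> [out_features, in_features]

get_dequant_weights_func then exposes the method as an Optional[Callable[[torch.nn.Module], torch.Tensor]], so a caller can check whether a dequantization path is available before asking the scheme for the bf16 weight.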
