
Commit 438e327

Enable dequant fp8 weights quantized per-channel with compressed-tensor method
Signed-off-by: mandy-li <[email protected]>
1 parent e18a075 · commit 438e327

File tree

1 file changed: +13, -0 lines


vllm_gaudi/ops/hpu_compressed_tensors.py

Lines changed: 13 additions & 0 deletions
@@ -88,6 +88,19 @@ def get_hpu_scheme(self, layer: torch.nn.Module):
             raise ValueError(f"{scheme_classname} compressed format is not supported on HPU")
         return hpu_scheme
 
+    def dequant_fp8_weight(self, layer: torch.nn.Module) -> torch.Tensor:
+        if layer.scheme.strategy == QuantizationStrategy.CHANNEL.value:  # weights were quantized per-channel
+            dequant_weight = layer.weight.to(
+                layer.weight_scale.dtype) * layer.weight_scale.squeeze()
+            return dequant_weight.to(torch.bfloat16).t()
+        else:
+            raise NotImplementedError(
+                "Implemented per-channel dequantization only")
+
+    def get_dequant_weights_func(
+            self, ) -> Optional[Callable[[torch.nn.Module], torch.Tensor]]:
+        return self.dequant_fp8_weight
+
 
 @CustomOp.register_oot(name='CompressedTensorsW8A16Fp8')
 class HPUCompressedTensorsW8A8Fp8(CompressedTensorsScheme):
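
To make the per-channel dequantization concrete, here is a minimal, self-contained sketch (not part of the commit) of the math dequant_fp8_weight() performs. The tensor names and shapes are illustrative assumptions: the fp8 weight is taken as (in_features, out_features) with one scale per output channel, which is what the broadcast against weight_scale.squeeze() and the trailing .t() in the diff suggest.

import torch  # requires a torch build with float8_e4m3fn support

# Hypothetical shapes standing in for a real layer; not taken from the commit.
in_features, out_features = 8, 4

# Stand-ins for layer.weight (fp8) and layer.weight_scale (per-output-channel scales).
weight_fp8 = torch.randn(in_features, out_features).to(torch.float8_e4m3fn)
weight_scale = torch.rand(out_features, 1, dtype=torch.float32)

# Same steps as dequant_fp8_weight(): upcast the fp8 weight to the scale dtype,
# broadcast-multiply by the squeezed per-channel scale, then cast to bf16 and transpose.
dequant_weight = weight_fp8.to(weight_scale.dtype) * weight_scale.squeeze()
dequant_bf16 = dequant_weight.to(torch.bfloat16).t()

print(dequant_bf16.shape, dequant_bf16.dtype)  # torch.Size([4, 8]) torch.bfloat16

Presumably, the bfloat16 cast and transpose put the dequantized weight into the (out_features, in_features) layout and dtype the downstream HPU path expects; the NotImplementedError branch makes explicit that only the per-channel (CHANNEL) strategy is handled by this method for now.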
