Enable dequant fp8 weights quantized per-channel with compressed-tensor method #621
base: main
Conversation
✅ CI Passed. All checks passed successfully against the following vllm commit:
def get_dequant_weights_func(self) -> Optional[Callable[[torch.nn.Module], torch.Tensor]]:
    return self.dequant_fp8_weight
It would be better to assign get_dequant_weights_func to the layer to stay consistent with the existing implementation, and no changes are required on the INC side.
vllm-gaudi/vllm_gaudi/extension/ops.py
Lines 787 to 789 in e18a075
else:
    # For INC path, we attach the dequant func to the layer
    layer.get_dequant_weights_func = types.MethodType(get_dequant_weights_func, layer)
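For illustration, a minimal sketch of that binding pattern as a standalone snippet; the dequant body follows the diff quoted in this PR, while the helper names here are hypothetical:

```python
import types
from typing import Callable, Optional

import torch


def _dequant_fp8_weight(layer: torch.nn.Module) -> torch.Tensor:
    # Per-channel dequant as in this PR's diff: upcast the fp8 weight, apply
    # the per-output-channel scale, cast to bf16, and transpose back.
    dequant = layer.weight.to(layer.weight_scale.dtype) * layer.weight_scale.squeeze()
    return dequant.to(torch.bfloat16).t()


def get_dequant_weights_func(
        self) -> Optional[Callable[[torch.nn.Module], torch.Tensor]]:
    # INC looks up this accessor on the layer and calls it to obtain the
    # function that recovers high-precision weights.
    return _dequant_fp8_weight


def bind_dequant_func(layer: torch.nn.Module) -> None:
    # Attach the accessor as a bound method of this specific layer instance,
    # mirroring the existing INC path quoted above from ops.py.
    layer.get_dequant_weights_func = types.MethodType(get_dequant_weights_func, layer)
```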
✅ CI Passed. All checks passed successfully against the following vllm commit:
@yiliu30, addressed your comment by binding the dequant function to the linear layer after loading the weights. Please review.
yiliu30 left a comment:
LGTM
@xuechendi Please be aware of this change, thanks!
Force-pushed from 73d8ed6 to b91f94c.
✅ CI Passed. All checks passed successfully against the following vllm commit:
@skavulya @lkk12014402, please help cross-review, since you're working on compressed-tensor.
Commit: Enable dequant fp8 weights quantized per-channel with compressed-tensor method. Signed-off-by: mandy-li <[email protected]>
# bind dequant function to layer for per-channel quantization
if layer.scheme.strategy == QuantizationStrategy.CHANNEL:
    hpu_ops.bind_dequant_func(layer)
If the PR is only for INC dynamic, we should not bind the dequant function for every per-channel layer here, right?
What is the scope of this PR?
For INC. I can check whether the QUANT_CONFIG env var is set, if you think that's necessary.
Please add the check; we can't hijack the non-INC path.
Seems the change will also impact #552.
Please also add a check for the dynamic scheme.
No, this should apply to static quant as well.
If we rename bind_dequant_func() to something like fp8_perchannel_linear_postprocess_weights, to be consistent with fp8_block_linear_postprocess_weights (which is not INC-specific), do I still need to check for the INC path?
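To make the proposed scoping concrete, here is a hedged sketch of the guard being discussed; the QUANT_CONFIG check, the helper name, and the import alias are assumptions for illustration, not the merged implementation:

```python
import os

from compressed_tensors.quantization import QuantizationStrategy

import vllm_gaudi.extension.ops as hpu_ops  # import path assumed from the snippet quoted above


def maybe_bind_dequant_func(layer) -> None:
    # Only act on the INC path: QUANT_CONFIG is the environment variable INC
    # reads for its quantization config, so if it is unset we leave the layer
    # (and therefore the non-INC path) untouched.
    if not os.environ.get("QUANT_CONFIG"):
        return
    # Bind only for per-channel (CHANNEL) weight quantization; per the
    # discussion above this covers both static and dynamic activation schemes.
    if layer.scheme.strategy == QuantizationStrategy.CHANNEL:
        hpu_ops.bind_dequant_func(layer)
```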
def dequant_fp8_weight(self, layer: torch.nn.Module) -> torch.Tensor:
    if layer.scheme.strategy == QuantizationStrategy.CHANNEL:  # weights were quantized per-channel
        dequant_weight = layer.weight.to(layer.weight_scale.dtype) * layer.weight_scale.squeeze()
        return dequant_weight.to(torch.bfloat16).t()
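A quick synthetic example of what this per-channel dequant does, assuming the weight is stored as [in_features, out_features] with one scale per output channel (which is why the result is transposed back with .t()); shapes, dtypes, and values here are made up for illustration:

```python
import torch

in_features, out_features = 4, 3

# Synthetic fp8 weight stored [in, out] and one scale per output channel.
weight = torch.randn(in_features, out_features).to(torch.float8_e4m3fn)
weight_scale = torch.rand(out_features, 1, dtype=torch.bfloat16) + 0.5

# Upcast, apply the per-channel scale, cast to bf16, transpose to [out, in].
dequant = (weight.to(weight_scale.dtype) * weight_scale.squeeze()).to(torch.bfloat16).t()
print(dequant.shape)  # torch.Size([3, 4])
```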
Does this work for Gaudi2? Will it get NaN, since the scale might be out of range?
No, this is for Gaudi3 (G3).
Oh, I checked CI, and it seems Gaudi2 is not getting NaN, which is quite unexpected.
@yiliu30, are there any recent changes that fix the Gaudi2 scale issue? Or is it because "scale_method": "ACT_MAXABS_PCS_POW2_WEIGHT_MAXABS_PTS_POW2_HW" keeps the range under 244?

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Never mind, I realized this is handled at create_weights
@yiliu30, please help review. This PR enables INC dynamic for compressed_tensor; I would like to know whether it meets your initial design.
✅ CI Passed. All checks passed successfully against the following vllm commit:
LGTM
    return wrapper


def bind_dequant_func(layer):
Suggest following the same naming pattern as the rest: fp8_perchannel_linear_postprocess_weights.
Yes, that is aligned with what we did for block-wise scaling.
This PR enables dequantization of fp8 weights that were quantized channel-wise with the compressed-tensor method.