-
Notifications
You must be signed in to change notification settings - Fork 168
GPTQ quantizer and example of using it to quantize chatglm-6b #1100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
chenbohua3
wants to merge
8
commits into
alibaba:main
Choose a base branch
from
chenbohua3:gptq_quantizer
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 5 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
9313215
gptq quantizer
chenbohua3 05f2b19
gptq quantizer
chenbohua3 8636ef8
refine
chenbohua3 5f6ccbd
refine
chenbohua3 4864fbf
refine
chenbohua3 8ac54cf
full version of gptq and chatglm example
chenbohua3 ded680a
fix
chenbohua3 98f1d2f
fix
chenbohua3 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| # Introduction | ||
|
|
||
| This directory contains some advanced quantization algorithms. These quantization algorithms are | ||
| difficult to implement under the original torch-quant framework (or not necessary, for example, | ||
| if you do weight-only quantization, fx graph may be not necessary). So under the premise of ensuring | ||
| that the interface is consistent with that in torch-quant, we implement the corresponding quantizer | ||
| for each advanced quantization algorithm alone. | ||
|
|
||
|
|
||
| # Supported algorithms | ||
|
|
||
| ### GPTQ | ||
| The official [GPTQ codes](https://github.com/IST-DASLab/gptq) are referenced. | ||
|
|
||
| NOTE: There is one small difference between the official GPTQ implementation and the one here. | ||
| In the official implementation, the inputs used to calculate the H matrix of a specific layer, | ||
| are obtained after the weight of the previous layers are quantized, which is not easy to be achieved | ||
| (The model should be calibrated many times). We relex this condition, that is, the inputs used to | ||
| calculate the H matrix are obtained when the weight of the previous layers are NOT quantized. So | ||
| the calibration data only needs to be executed once on the model, and all H matrices can be calculated. | ||
| This relaxation only results in a slight loss of performance for 4 bit quantization. Of course, we are | ||
| also figuring out how to enable this in a user-friendly way. | ||
|
|
||
| ``` | ||
| @article{frantar-gptq, | ||
| title={{GPTQ}: Accurate Post-training Compression for Generative Pretrained Transformers}, | ||
| author={Elias Frantar and Saleh Ashkboos and Torsten Hoefler and Dan Alistarh}, | ||
| year={2022}, | ||
| journal={arXiv preprint arXiv:2210.17323} | ||
| } | ||
| ``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| # Copyright 2023 The BladeDISC Authors. All rights reserved. | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from .gptq import GPTQuantizer |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,245 @@ | ||
| # Copyright 2023 The BladeDISC Authors. All rights reserved. | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import logging | ||
| import math | ||
| from typing import Callable, List, Optional | ||
|
|
||
| import torch | ||
| from torch import nn | ||
| from torch_quant.module import ModuleFilter | ||
| from torch_quant.observer import Observer | ||
| from torch_quant.quantizer import DEFAULT_W_OB_CTR, Backend, Device, get_default_ctr | ||
|
|
||
| LOGGER = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| try: | ||
| import transformers | ||
| is_transformers_avail = True | ||
| except ModuleNotFoundError: | ||
| LOGGER.warning("transformers is not installed, so that gptq can not be applied to transformers.Conv1D") | ||
| is_transformers_avail = False | ||
|
|
||
|
|
||
| # TODO: For models that can not run the gptq process within single GPU | ||
| # we should support cpu offload. | ||
|
|
||
| QUANT_LAYERS = [nn.Linear, nn.Conv2d] | ||
| if is_transformers_avail: | ||
| QUANT_LAYERS.append(transformers.Conv1D) | ||
|
|
||
|
|
||
| def is_transformer_conv1d(layer): | ||
| return is_transformers_avail and isinstance(layer, transformers.Conv1D) | ||
|
|
||
|
|
||
| class GPTQObserver: | ||
| def __init__(self, observer: Observer): | ||
| self.observer = observer | ||
| self.scales = [] | ||
| self.zero_points = [] | ||
|
|
||
| def find_quant_info(self, x): | ||
| self.observer.set_mode(observe=True, fake_quant=False) | ||
| self.observer(x) | ||
| self.scales.append(self.observer.scale) | ||
| self.zero_points.append(self.observer.zero_point) | ||
| # todo: reset the observer | ||
|
|
||
| def fake_quant(self, x): | ||
| self.observer.set_mode(observe=False, fake_quant=True) | ||
| x = self.observer(x) | ||
| return x | ||
|
|
||
|
|
||
| class GPTQLayerWrapper: | ||
| def __init__(self, layer, observer): | ||
| super().__init__() | ||
| self.layer = layer | ||
| self.gptq_observer = GPTQObserver(observer) | ||
| self.device = layer.weight.device | ||
| columns = layer.weight.shape[1] | ||
| self.columns = columns | ||
| self.H = torch.zeros((columns, columns), device=self.device) | ||
| self.nsamples = 0 | ||
|
|
||
| def record_h(self, x): | ||
| x = x.detach().clone() | ||
| if len(x.shape) == 2: | ||
| x = x.unsqueeze(0) | ||
| batch = x.shape[0] | ||
| if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer): | ||
| if len(x.shape) == 3: | ||
| x = x.reshape((-1, x.shape[-1])) | ||
| x = x.t() | ||
|
|
||
| if isinstance(self.layer, nn.Conv2d): | ||
| unfold = nn.Unfold( | ||
| self.layer.kernel_size, | ||
| dilation=self.layer.dilation, | ||
| padding=self.layer.padding, | ||
| stride=self.layer.stride | ||
| ) | ||
| x = unfold(x) | ||
| x = x.permute([1, 0, 2]) | ||
| x = x.flatten(1) | ||
|
|
||
| self.H *= self.nsamples / (self.nsamples + batch) | ||
| self.nsamples += batch | ||
| x = math.sqrt(2 / self.nsamples) * x.float() | ||
| self.H += x.matmul(x.t()) | ||
|
|
||
| def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1): | ||
| weight = self.layer.weight.data.clone() | ||
| if isinstance(self.layer, nn.Conv2d): | ||
| weight = weight.flatten(1) | ||
| if is_transformer_conv1d(self.layer): | ||
| weight = weight.t() | ||
| weight = weight.float() | ||
|
|
||
| if groupsize == -1: | ||
| self.gptq_observer.find_quant_info(weight) | ||
|
|
||
| H = self.H | ||
| # del self.H | ||
| dead = torch.diag(H) == 0 | ||
| H[dead, dead] = 1 | ||
| weight[:, dead] = 0 | ||
|
|
||
| losses = torch.zeros_like(weight) | ||
| Q = torch.zeros_like(weight) | ||
|
|
||
| damp = percdamp * torch.mean(torch.diag(H)) | ||
| diag = torch.arange(self.columns, device=self.device) | ||
| H[diag, diag] += damp | ||
| try: | ||
| H = torch.linalg.cholesky(H) | ||
| H = torch.cholesky_inverse(H) | ||
| H = torch.linalg.cholesky(H, upper=True) | ||
| except Exception: | ||
| # TODO: should handle this situation | ||
| logging.warning("Warning: cannot do compression for inverse error") | ||
|
|
||
| if H.isnan().any(): | ||
| # TODO: should handle this situation | ||
| logging.warning("Warning: cannot do compression for inverse error") | ||
|
|
||
| hinv = H | ||
|
|
||
| for i1 in range(0, self.columns, blocksize): | ||
| i2 = min(i1 + blocksize, self.columns) | ||
| count = i2 - i1 | ||
|
|
||
| w1 = weight[:, i1:i2].clone() | ||
| q1 = torch.zeros_like(w1) | ||
| total_err = torch.zeros_like(w1) | ||
| losses1 = torch.zeros_like(w1) | ||
| hinv1 = hinv[i1:i2, i1:i2] | ||
|
|
||
| for i in range(count): | ||
| w = w1[:, i] | ||
| d = hinv1[i, i] | ||
|
|
||
| if groupsize != -1: | ||
| if (i1 + i) % groupsize == 0: | ||
| self.gptq_observer.find_quant_info(weight[:, (i1 + i):(i1 + i + groupsize)]) | ||
| q = self.gptq_observer.fake_quant(w.unsqueeze(1)).flatten() | ||
|
|
||
| q1[:, i] = q | ||
| losses1[:, i] = (w - q) ** 2 / d ** 2 | ||
| err = (w - q) / d | ||
| w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0)) | ||
| total_err[:, i] = err | ||
|
|
||
| Q[:, i1:i2] = q1 | ||
| losses[:, i1:i2] = losses1 / 2 | ||
|
|
||
| weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:]) | ||
|
|
||
| # torch.cuda.synchronize() | ||
|
|
||
| if is_transformer_conv1d(self.layer): | ||
| Q = Q.t() | ||
| self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) | ||
|
|
||
|
|
||
| class GPTQModuleWrapper: | ||
| def __init__(self, module: nn.Module, w_ob_ctr): | ||
| self.all_layers = {} | ||
| self.all_handles = [] | ||
|
|
||
| def get_hook(layer_name): | ||
| def record_hook(_, x): | ||
| self.all_layers[layer_name].record_h(x[0]) | ||
| return record_hook | ||
|
|
||
| for name, layer in module.named_modules(): | ||
| if isinstance(layer, tuple(QUANT_LAYERS)): | ||
| self.all_layers[name] = GPTQLayerWrapper(layer, w_ob_ctr()) | ||
| handle = layer.register_forward_pre_hook(get_hook(name)) | ||
| self.all_handles.append(handle) | ||
|
|
||
| def quant_module(self): | ||
| for _, wrapper in self.all_layers.items(): | ||
| wrapper.quant_weight() | ||
|
|
||
| for h in self.all_handles: | ||
| h.remove() | ||
|
|
||
|
|
||
| class GPTQuantizer: | ||
| def __init__(self, module_filter: Optional[ModuleFilter] = None, | ||
| backend: Backend = Backend.DISC, | ||
| device: Device = Device.GPU, | ||
| block: Optional[List[nn.Module]] = None | ||
| ) -> None: | ||
| self.module_filter = module_filter | ||
| self.backend = backend | ||
| self.device = device | ||
| self.all_module_wrappers = {} | ||
| self.block = block or QUANT_LAYERS | ||
|
|
||
| def calib(self, model: nn.Module, | ||
| w_ob_ctr: Optional[Callable[..., Observer]] = None): | ||
| default_w_ob_ctr = get_default_ctr(DEFAULT_W_OB_CTR, self.device, self.backend) | ||
| w_ob_ctr = w_ob_ctr or default_w_ob_ctr | ||
| # GPTQ only quantize the weight of LLMs, so there is no need to | ||
| # convert it to fx module | ||
| for name, module in model.named_modules(): | ||
| if isinstance(module, tuple(self.block)): | ||
| self.all_module_wrappers[name] = GPTQModuleWrapper(module, w_ob_ctr) | ||
|
|
||
| return model | ||
|
|
||
| def quantize(self, model: nn.Module): | ||
| for _, module_wrapper in self.all_module_wrappers.items(): | ||
| module_wrapper.quant_module() | ||
|
|
||
| return model | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| class MyModel(nn.Module): | ||
| def __init__(self): | ||
| super().__init__() | ||
| self.linear = nn.Linear(2048, 2048) | ||
|
|
||
| def forward(self, x): | ||
| return self.linear(x) | ||
|
|
||
| model = MyModel() | ||
| ss = dict(model.named_modules()) | ||
| quantizer = GPTQuantizer() | ||
| calib_model = quantizer.calib(model) | ||
| dummy = torch.randn(1, 2048, 2048) | ||
| calib_model(dummy) | ||
| quant_model = quantizer.quantize(model) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
calib_model,相对于model加了一个hook,calib_model is model是True吧?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,只是为了保持接口和常规量化一致