From 9313215ba794e04d29dbc97b44090edef4fa347b Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Fri, 17 Mar 2023 15:45:49 +0800 Subject: [PATCH 1/8] gptq quantizer --- .../torch_quant/experiment/gptq.py | 219 ++++++++++++++++++ 1 file changed, 219 insertions(+) create mode 100644 tools/torch_quant/torch_quant/experiment/gptq.py diff --git a/tools/torch_quant/torch_quant/experiment/gptq.py b/tools/torch_quant/torch_quant/experiment/gptq.py new file mode 100644 index 00000000000..cd51746f599 --- /dev/null +++ b/tools/torch_quant/torch_quant/experiment/gptq.py @@ -0,0 +1,219 @@ +import logging +import math +from typing import Callable, List, Optional + +import torch +from torch import nn +from torch_quant.module import ModuleFilter +from torch_quant.observer import Observer +from torch_quant.quantizer import DEFAULT_W_OB_CTR, Backend, Device, get_default_ctr + +LOGGER = logging.getLogger(__name__) + + +try: + import transformers + is_transformers_avail = True +except ModuleNotFoundError: + LOGGER.warning("transformers is not installed, so that gptq can not be applied to transformers.Conv1D") + is_transformers_avail = False + + +# TODO: For models that can not run the gptq process within single GPU +# we should support cpu offload. + +QUANT_LAYERS = [nn.Linear, nn.Conv2d] +if is_transformers_avail: + QUANT_LAYERS.append(transformers.Conv1D) + + +class GPTQObserver: + def __init__(self, observer: Observer): + self.observer = observer + self.scales = [] + self.zero_points = [] + + def find_quant_info(self, x): + self.observer.set_mode(observe=True, fake_quant=False) + self.observer(x) + self.scales.append(self.observer.scale) + self.zero_points.append(self.observer.zero_point) + # todo: reset the observer + + +class GPTQLayerWrapper: + def __init__(self, layer, observer): + super().__init__() + self.layer = layer + self.gptq_observer = GPTQObserver(observer) + self.device = layer.weight.device + columns = layer.weight.shape[1] + self.columns = columns + self.H = torch.zeros((columns, columns), device=self.device) + self.nsamples = 0 + + def record(self, x): + x = x.detach().clone() + if len(x.shape) == 2: + x = x.unsqueeze(0) + batch = x.shape[0] + if isinstance(self.layer, nn.Linear) or (is_transformers_avail and isinstance(self.layer, transformers.Conv1D)): + if len(x.shape) == 3: + x = x.reshape((-1, x.shape[-1])) + x = x.t() + + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride + ) + x = unfold(x) + x = x.permute([1, 0, 2]) + x = x.flatten(1) + + self.H *= self.nsamples / (self.nsamples + batch) + self.nsamples += batch + x = math.sqrt(2 / self.nsamples) * x.float() + self.H += x.matmul(x.t()) + + def quant(self, blocksize=128, percdamp=.01, groupsize=-1): + weight = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + weight = weight.flatten(1) + if is_transformers_avail and isinstance(self.layer, transformers.Conv1D): + weight = weight.t() + weight = weight.float() + + if groupsize == -1: + self.gptq_observer.find_quant_info(weight) + + H = self.H + # del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + weight[:, dead] = 0 + + losses = torch.zeros_like(weight) + Q = torch.zeros_like(weight) + mask = torch.ones_like(weight) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.device) + H[diag, diag] += damp + try: + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + except Exception: + logging.warning(f"Warning: cannot do compression for inverse error") + + if H.isnan().any(): + logging.warning(f"Warning: cannot do compression for inverse error") + + hinv = H + hinv_diag = torch.diag(hinv) + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + w1 = weight[:, i1:i2].clone() + q1 = torch.zeros_like(w1) + err1 = torch.zeros_like(w1) + losses1 = torch.zeros_like(w1) + hinv1 = hinv[i1:i2, i1:i2] + + for i in range(count): + w = w1[:, i] + d = hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.gptq_observer.find_quant_info(weight[:, (i1 + i):(i1 + i + groupsize)]) + + q = quantize( + w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq + ).flatten() + + + q1[:, i] = q + losses1[:, i] = (w - q) ** 2 / d ** 2 + err1 = (w - q) / d # 此处量化和稀疏不太一样啊,damo (w - q)**2 / d + w1[:, i:] -= err1.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0)) + err1[:, i] = err1 + + Q[:, i1:i2] = q1 + losses[:, i1:i2] = losses1 / 2 + + weight[:, i2:] -= err1.matmul(hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + + +class GPTQModuleWrapper: + def __init__(self, module: nn.Module, w_ob_ctr): + self.all_layers = {} + self.all_handles = [] + def get_hook(name): + def record(_, x): + self.all_layers[name].record(x[0]) + return record + + for name, layer in module.named_modules(): + if isinstance(layer, tuple(QUANT_LAYERS)): + self.all_layers[name] = GPTQLayerWrapper(layer, w_ob_ctr()) + handle = layer.register_forward_pre_hook(get_hook(name)) + self.all_handles.append(handle) + + +class GPTQuantizer: + def __init__(self, module_filter: Optional[ModuleFilter] = None, + backend: Backend = Backend.DISC, + device: Device = Device.GPU, + block: Optional[List[nn.Module]] = None + ) -> None: + self.module_filter = module_filter + self.backend = backend + self.device = device + self.all_module_wrappers = {} + self.block = block or QUANT_LAYERS + + def calib(self, model: nn.Module, + w_ob_ctr: Optional[Callable[..., Observer]] = None): + default_w_ob_ctr = get_default_ctr(DEFAULT_W_OB_CTR, self.device, self.backend) + w_ob_ctr = w_ob_ctr or default_w_ob_ctr + # GPTQ only quantize the weight of LLMs, so there is no need to + # convert it to fx module + for name, module in model.named_modules(): + if isinstance(module, tuple(self.block)): + self.all_module_wrappers[name] = GPTQModuleWrapper(module, w_ob_ctr) + + return model + + def quantize(self, model: nn.Module): + pass + + + +if __name__ == "__main__": + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 4) + + def forward(self, x): + return self.linear(x) + + model = MyModel() + ss = dict(model.named_modules()) + quantizer = GPTQuantizer() + calib_model = quantizer.calib(model) + dummy = torch.randn(1, 3) + calib_model(dummy) + quant_model = quantizer.quantize(model) From 05f2b191235dd9c8cff3f6ae7c6458eb617ab488 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 20 Mar 2023 11:48:38 +0800 Subject: [PATCH 2/8] gptq quantizer --- tools/torch_quant/torch_quant/experiment/gptq.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tools/torch_quant/torch_quant/experiment/gptq.py b/tools/torch_quant/torch_quant/experiment/gptq.py index cd51746f599..efe0340da76 100644 --- a/tools/torch_quant/torch_quant/experiment/gptq.py +++ b/tools/torch_quant/torch_quant/experiment/gptq.py @@ -40,6 +40,11 @@ def find_quant_info(self, x): self.zero_points.append(self.observer.zero_point) # todo: reset the observer + def fake_quant(self, x): + self.observer.set_mode(observe=False, fake_quant=True) + x = self.observer(x) + return x + class GPTQLayerWrapper: def __init__(self, layer, observer): @@ -107,9 +112,11 @@ def quant(self, blocksize=128, percdamp=.01, groupsize=-1): H = torch.cholesky_inverse(H) H = torch.linalg.cholesky(H, upper=True) except Exception: + # TODO: should handle this situation logging.warning(f"Warning: cannot do compression for inverse error") if H.isnan().any(): + # TODO: should handle this situation logging.warning(f"Warning: cannot do compression for inverse error") hinv = H @@ -132,15 +139,11 @@ def quant(self, blocksize=128, percdamp=.01, groupsize=-1): if groupsize != -1: if (i1 + i) % groupsize == 0: self.gptq_observer.find_quant_info(weight[:, (i1 + i):(i1 + i + groupsize)]) - - q = quantize( - w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq - ).flatten() - + q = self.gptq_observer.fake_quant(w.unsqueeze(1)).flatten() q1[:, i] = q losses1[:, i] = (w - q) ** 2 / d ** 2 - err1 = (w - q) / d # 此处量化和稀疏不太一样啊,damo (w - q)**2 / d + err1 = (w - q) / d w1[:, i:] -= err1.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0)) err1[:, i] = err1 From 8636ef81d0f925860e746f421594df3fd37ca947 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Mon, 20 Mar 2023 16:01:37 +0800 Subject: [PATCH 3/8] refine --- tools/torch_quant/torch_quant/experiment/gptq.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tools/torch_quant/torch_quant/experiment/gptq.py b/tools/torch_quant/torch_quant/experiment/gptq.py index efe0340da76..a9a05432ac1 100644 --- a/tools/torch_quant/torch_quant/experiment/gptq.py +++ b/tools/torch_quant/torch_quant/experiment/gptq.py @@ -27,6 +27,10 @@ QUANT_LAYERS.append(transformers.Conv1D) +def is_transformer_conv1d(layer): + return is_transformers_avail and isinstance(layer, transformers.Conv1D) + + class GPTQObserver: def __init__(self, observer: Observer): self.observer = observer @@ -62,7 +66,7 @@ def record(self, x): if len(x.shape) == 2: x = x.unsqueeze(0) batch = x.shape[0] - if isinstance(self.layer, nn.Linear) or (is_transformers_avail and isinstance(self.layer, transformers.Conv1D)): + if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer): if len(x.shape) == 3: x = x.reshape((-1, x.shape[-1])) x = x.t() @@ -87,7 +91,7 @@ def quant(self, blocksize=128, percdamp=.01, groupsize=-1): weight = self.layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): weight = weight.flatten(1) - if is_transformers_avail and isinstance(self.layer, transformers.Conv1D): + if is_transformer_conv1d(self.layer): weight = weight.t() weight = weight.float() @@ -154,7 +158,7 @@ def quant(self, blocksize=128, percdamp=.01, groupsize=-1): torch.cuda.synchronize() - if isinstance(self.layer, transformers.Conv1D): + if is_transformer_conv1d(self.layer): Q = Q.t() self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) From 5f6ccbd618ddcdf2c85c97fdbc2e0e58e500f345 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 28 Mar 2023 14:25:31 +0800 Subject: [PATCH 4/8] refine --- .../torch_quant/experiment/README.md | 31 +++++++++++++ .../torch_quant/experiment/__init__.py | 0 .../torch_quant/experiment/gptq.py | 46 +++++++++++-------- 3 files changed, 58 insertions(+), 19 deletions(-) create mode 100644 tools/torch_quant/torch_quant/experiment/README.md create mode 100644 tools/torch_quant/torch_quant/experiment/__init__.py diff --git a/tools/torch_quant/torch_quant/experiment/README.md b/tools/torch_quant/torch_quant/experiment/README.md new file mode 100644 index 00000000000..d0bc5c75dec --- /dev/null +++ b/tools/torch_quant/torch_quant/experiment/README.md @@ -0,0 +1,31 @@ +# Introduction + +This directory contains some advanced quantization algorithms. These quantization algorithms are +difficult to implement under the original torch-quant framework (or not necessary, for example, +if you do weight-only quantization, fx graph may be not necessary). So under the premise of ensuring +that the interface is consistent with that in torch-quant, we implement the corresponding quantizer +for each advanced quantization algorithm alone. + + +# Supported algorithms + +### GPTQ +The official [GPTQ codes](https://github.com/IST-DASLab/gptq) are referenced. + +NOTE: There is one small difference between the official GPTQ implementation and the one here. +In the official implementation, the inputs used to calculate the H matrix of a specific layer, +are obtained after the weight of the previous layers are quantized, which is not easy to be achieved +(The model should be calibrated many times). We relex this condition, that is, the inputs used to +calculate the H matrix are obtained when the weight of the previous layers are NOT quantized. So +the calibration data only needs to be executed once on the model, and all H matrices can be calculated. +This relaxation only results in a slight loss of performance for 4 bit quantization. Of course, we are +also figuring out how to enable this in a user-friendly way. + +``` +@article{frantar-gptq, + title={{GPTQ}: Accurate Post-training Compression for Generative Pretrained Transformers}, + author={Elias Frantar and Saleh Ashkboos and Torsten Hoefler and Dan Alistarh}, + year={2022}, + journal={arXiv preprint arXiv:2210.17323} +} +``` diff --git a/tools/torch_quant/torch_quant/experiment/__init__.py b/tools/torch_quant/torch_quant/experiment/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tools/torch_quant/torch_quant/experiment/gptq.py b/tools/torch_quant/torch_quant/experiment/gptq.py index a9a05432ac1..79a4adfb219 100644 --- a/tools/torch_quant/torch_quant/experiment/gptq.py +++ b/tools/torch_quant/torch_quant/experiment/gptq.py @@ -61,7 +61,7 @@ def __init__(self, layer, observer): self.H = torch.zeros((columns, columns), device=self.device) self.nsamples = 0 - def record(self, x): + def record_h(self, x): x = x.detach().clone() if len(x.shape) == 2: x = x.unsqueeze(0) @@ -87,7 +87,7 @@ def record(self, x): x = math.sqrt(2 / self.nsamples) * x.float() self.H += x.matmul(x.t()) - def quant(self, blocksize=128, percdamp=.01, groupsize=-1): + def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1): weight = self.layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): weight = weight.flatten(1) @@ -106,7 +106,6 @@ def quant(self, blocksize=128, percdamp=.01, groupsize=-1): losses = torch.zeros_like(weight) Q = torch.zeros_like(weight) - mask = torch.ones_like(weight) damp = percdamp * torch.mean(torch.diag(H)) diag = torch.arange(self.columns, device=self.device) @@ -117,14 +116,13 @@ def quant(self, blocksize=128, percdamp=.01, groupsize=-1): H = torch.linalg.cholesky(H, upper=True) except Exception: # TODO: should handle this situation - logging.warning(f"Warning: cannot do compression for inverse error") + logging.warning("Warning: cannot do compression for inverse error") if H.isnan().any(): # TODO: should handle this situation - logging.warning(f"Warning: cannot do compression for inverse error") + logging.warning("Warning: cannot do compression for inverse error") hinv = H - hinv_diag = torch.diag(hinv) for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) @@ -132,7 +130,7 @@ def quant(self, blocksize=128, percdamp=.01, groupsize=-1): w1 = weight[:, i1:i2].clone() q1 = torch.zeros_like(w1) - err1 = torch.zeros_like(w1) + total_err = torch.zeros_like(w1) losses1 = torch.zeros_like(w1) hinv1 = hinv[i1:i2, i1:i2] @@ -147,16 +145,16 @@ def quant(self, blocksize=128, percdamp=.01, groupsize=-1): q1[:, i] = q losses1[:, i] = (w - q) ** 2 / d ** 2 - err1 = (w - q) / d - w1[:, i:] -= err1.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0)) - err1[:, i] = err1 + err = (w - q) / d + w1[:, i:] -= err.unsqueeze(1).matmul(hinv1[i, i:].unsqueeze(0)) + total_err[:, i] = err Q[:, i1:i2] = q1 losses[:, i1:i2] = losses1 / 2 - weight[:, i2:] -= err1.matmul(hinv[i1:i2, i2:]) + weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:]) - torch.cuda.synchronize() + # torch.cuda.synchronize() if is_transformer_conv1d(self.layer): Q = Q.t() @@ -167,10 +165,11 @@ class GPTQModuleWrapper: def __init__(self, module: nn.Module, w_ob_ctr): self.all_layers = {} self.all_handles = [] - def get_hook(name): - def record(_, x): - self.all_layers[name].record(x[0]) - return record + + def get_hook(layer_name): + def record_hook(_, x): + self.all_layers[layer_name].record_h(x[0]) + return record_hook for name, layer in module.named_modules(): if isinstance(layer, tuple(QUANT_LAYERS)): @@ -178,6 +177,13 @@ def record(_, x): handle = layer.register_forward_pre_hook(get_hook(name)) self.all_handles.append(handle) + def quant_module(self): + for _, wrapper in self.all_layers.items(): + wrapper.quant_weight() + + for h in self.all_handles: + h.remove() + class GPTQuantizer: def __init__(self, module_filter: Optional[ModuleFilter] = None, @@ -204,15 +210,17 @@ def calib(self, model: nn.Module, return model def quantize(self, model: nn.Module): - pass + for _, module_wrapper in self.all_module_wrappers.items(): + module_wrapper.quant_module() + return model if __name__ == "__main__": class MyModel(nn.Module): def __init__(self): super().__init__() - self.linear = nn.Linear(3, 4) + self.linear = nn.Linear(2048, 2048) def forward(self, x): return self.linear(x) @@ -221,6 +229,6 @@ def forward(self, x): ss = dict(model.named_modules()) quantizer = GPTQuantizer() calib_model = quantizer.calib(model) - dummy = torch.randn(1, 3) + dummy = torch.randn(1, 2048, 2048) calib_model(dummy) quant_model = quantizer.quantize(model) From 4864fbfcc223c03dad1f34ebe409c51fcf21f649 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Tue, 28 Mar 2023 14:29:03 +0800 Subject: [PATCH 5/8] refine --- tools/torch_quant/torch_quant/experiment/__init__.py | 12 ++++++++++++ tools/torch_quant/torch_quant/experiment/gptq.py | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/tools/torch_quant/torch_quant/experiment/__init__.py b/tools/torch_quant/torch_quant/experiment/__init__.py index e69de29bb2d..1bd5a3016b8 100644 --- a/tools/torch_quant/torch_quant/experiment/__init__.py +++ b/tools/torch_quant/torch_quant/experiment/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2023 The BladeDISC Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gptq import GPTQuantizer diff --git a/tools/torch_quant/torch_quant/experiment/gptq.py b/tools/torch_quant/torch_quant/experiment/gptq.py index 79a4adfb219..5146983b1f8 100644 --- a/tools/torch_quant/torch_quant/experiment/gptq.py +++ b/tools/torch_quant/torch_quant/experiment/gptq.py @@ -1,3 +1,14 @@ +# Copyright 2023 The BladeDISC Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import math from typing import Callable, List, Optional From 8ac54cf7640ffd2467fcbdb2c6f525b85ee6c0d7 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Wed, 12 Apr 2023 17:45:42 +0800 Subject: [PATCH 6/8] full version of gptq and chatglm example --- tools/torch_quant/examples/chat_glm_gptq.py | 55 +++++ .../torch_quant/experiment/README.md | 14 +- .../torch_quant/experiment/gptq.py | 196 ++++++++++++------ 3 files changed, 190 insertions(+), 75 deletions(-) create mode 100644 tools/torch_quant/examples/chat_glm_gptq.py diff --git a/tools/torch_quant/examples/chat_glm_gptq.py b/tools/torch_quant/examples/chat_glm_gptq.py new file mode 100644 index 00000000000..63dc1573fec --- /dev/null +++ b/tools/torch_quant/examples/chat_glm_gptq.py @@ -0,0 +1,55 @@ + +import logging + +from torch import nn +from torch_quant.experiment import GPTQuantizer +from tqdm import tqdm +from transformers import AutoModel, AutoTokenizer, set_seed + +logging.basicConfig(level=logging.DEBUG) + +from transformers.dynamic_module_utils import get_class_from_dynamic_module + + +def do_inference_with_fixed_seed(model, tokenizer, prompt): + set_seed(42) + response, _ = model.chat(tokenizer, prompt, history=[]) + return response + + +prompt = "晚上睡不着应该怎么办" + +# If use local mode, change it to the folder that contains model files +target_model = "chatglm_6b" + +# The basic block within which all layers are calibrated together +target_block = get_class_from_dynamic_module(target_model, "modeling_chatglm.py", "GLMBlock") + +tokenizer = AutoTokenizer.from_pretrained(target_model, trust_remote_code=True) +model = AutoModel.from_pretrained(target_model, trust_remote_code=True, resume_download=True).half().cuda() + +# Get the output of the original model +response = do_inference_with_fixed_seed(model, tokenizer, prompt) +print(response) + +# Get the GPTQuantizer, the block means that the GLMBlock is calibrated one-by-one +# and the last lm_head (of type nn.Linear) is calibrated alone +quantizer = GPTQuantizer(block=[target_block, nn.Linear]) + +# prepare the model for the quantization process +calib_model = quantizer.calib(model) + +# Since we do not get the graph of the model (e.g. torchscript, fx graph), we must +# do inference on the model once and record the block order +with quantizer.record_order(): + do_inference_with_fixed_seed(model, tokenizer, prompt) + +# Do calibration on the model. In each iter, one block will be quantized using GPTQ and +# you can use other prompts. +for i in tqdm(range(quantizer.calibration_iters)): + with quantizer.start_calib_iter(i): + response, history = model.chat(tokenizer, prompt, history=[]) + +# Get the result of the weight fake-quantized model +response = do_inference_with_fixed_seed(model, tokenizer, prompt) +print(response) diff --git a/tools/torch_quant/torch_quant/experiment/README.md b/tools/torch_quant/torch_quant/experiment/README.md index d0bc5c75dec..e537e915458 100644 --- a/tools/torch_quant/torch_quant/experiment/README.md +++ b/tools/torch_quant/torch_quant/experiment/README.md @@ -2,25 +2,13 @@ This directory contains some advanced quantization algorithms. These quantization algorithms are difficult to implement under the original torch-quant framework (or not necessary, for example, -if you do weight-only quantization, fx graph may be not necessary). So under the premise of ensuring -that the interface is consistent with that in torch-quant, we implement the corresponding quantizer -for each advanced quantization algorithm alone. +if you do weight-only quantization, fx graph may be not necessary). # Supported algorithms ### GPTQ The official [GPTQ codes](https://github.com/IST-DASLab/gptq) are referenced. - -NOTE: There is one small difference between the official GPTQ implementation and the one here. -In the official implementation, the inputs used to calculate the H matrix of a specific layer, -are obtained after the weight of the previous layers are quantized, which is not easy to be achieved -(The model should be calibrated many times). We relex this condition, that is, the inputs used to -calculate the H matrix are obtained when the weight of the previous layers are NOT quantized. So -the calibration data only needs to be executed once on the model, and all H matrices can be calculated. -This relaxation only results in a slight loss of performance for 4 bit quantization. Of course, we are -also figuring out how to enable this in a user-friendly way. - ``` @article{frantar-gptq, title={{GPTQ}: Accurate Post-training Compression for Generative Pretrained Transformers}, diff --git a/tools/torch_quant/torch_quant/experiment/gptq.py b/tools/torch_quant/torch_quant/experiment/gptq.py index 5146983b1f8..e73c8ae20c6 100644 --- a/tools/torch_quant/torch_quant/experiment/gptq.py +++ b/tools/torch_quant/torch_quant/experiment/gptq.py @@ -9,15 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import logging import math from typing import Callable, List, Optional import torch from torch import nn -from torch_quant.module import ModuleFilter from torch_quant.observer import Observer -from torch_quant.quantizer import DEFAULT_W_OB_CTR, Backend, Device, get_default_ctr +from torch_quant.quantizer import (DEFAULT_W_OB_CTR, Backend, Device, + get_default_ctr) LOGGER = logging.getLogger(__name__) @@ -26,7 +27,8 @@ import transformers is_transformers_avail = True except ModuleNotFoundError: - LOGGER.warning("transformers is not installed, so that gptq can not be applied to transformers.Conv1D") + LOGGER.warning("transformers is not installed, " + "so that gptq can not be applied to transformers.Conv1D") is_transformers_avail = False @@ -62,41 +64,44 @@ def fake_quant(self, x): class GPTQLayerWrapper: - def __init__(self, layer, observer): + def __init__(self, layer_name, layer, observer_ctr): super().__init__() + self.layer_name = layer_name self.layer = layer - self.gptq_observer = GPTQObserver(observer) self.device = layer.weight.device + self.gptq_observer = GPTQObserver(observer_ctr().to(self.device)) columns = layer.weight.shape[1] self.columns = columns self.H = torch.zeros((columns, columns), device=self.device) self.nsamples = 0 + self.is_record = True def record_h(self, x): - x = x.detach().clone() - if len(x.shape) == 2: - x = x.unsqueeze(0) - batch = x.shape[0] - if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer): - if len(x.shape) == 3: - x = x.reshape((-1, x.shape[-1])) - x = x.t() - - if isinstance(self.layer, nn.Conv2d): - unfold = nn.Unfold( - self.layer.kernel_size, - dilation=self.layer.dilation, - padding=self.layer.padding, - stride=self.layer.stride - ) - x = unfold(x) - x = x.permute([1, 0, 2]) - x = x.flatten(1) - - self.H *= self.nsamples / (self.nsamples + batch) - self.nsamples += batch - x = math.sqrt(2 / self.nsamples) * x.float() - self.H += x.matmul(x.t()) + if self.is_record: + x = x.detach().clone() + if len(x.shape) == 2: + x = x.unsqueeze(0) + batch = x.shape[0] + if isinstance(self.layer, nn.Linear) or is_transformer_conv1d(self.layer): + if len(x.shape) == 3: + x = x.reshape((-1, x.shape[-1])) + x = x.t() + + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride + ) + x = unfold(x) + x = x.permute([1, 0, 2]) + x = x.flatten(1) + + self.H *= self.nsamples / (self.nsamples + batch) + self.nsamples += batch + x = math.sqrt(2 / self.nsamples) * x.float() + self.H += x.matmul(x.t()) def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1): weight = self.layer.weight.data.clone() @@ -110,7 +115,6 @@ def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1): self.gptq_observer.find_quant_info(weight) H = self.H - # del self.H dead = torch.diag(H) == 0 H[dead, dead] = 1 weight[:, dead] = 0 @@ -126,12 +130,12 @@ def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1): H = torch.cholesky_inverse(H) H = torch.linalg.cholesky(H, upper=True) except Exception: - # TODO: should handle this situation - logging.warning("Warning: cannot do compression for inverse error") + logging.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error") + return if H.isnan().any(): - # TODO: should handle this situation - logging.warning("Warning: cannot do compression for inverse error") + logging.warning(f"Warning: cannot do compression on layer {self.layer_name} because of inverse error") + return hinv = H @@ -165,27 +169,36 @@ def quant_weight(self, blocksize=128, percdamp=.01, groupsize=-1): weight[:, i2:] -= total_err.matmul(hinv[i1:i2, i2:]) - # torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() if is_transformer_conv1d(self.layer): Q = Q.t() self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) + del self.H + del self.gptq_observer + if torch.cuda.is_available(): + torch.cuda.empty_cache() class GPTQModuleWrapper: - def __init__(self, module: nn.Module, w_ob_ctr): + def __init__(self, module_name: str, module: nn.Module, w_ob_ctr): self.all_layers = {} self.all_handles = [] + # module order in the whole network + self.order = 0 + self.module_name = module_name def get_hook(layer_name): def record_hook(_, x): self.all_layers[layer_name].record_h(x[0]) return record_hook - for name, layer in module.named_modules(): + for layer_name, layer in module.named_modules(): if isinstance(layer, tuple(QUANT_LAYERS)): - self.all_layers[name] = GPTQLayerWrapper(layer, w_ob_ctr()) - handle = layer.register_forward_pre_hook(get_hook(name)) + full_layer_name = f"{module_name}.{layer_name}" if layer_name else f"{module_name}" + self.all_layers[full_layer_name] = GPTQLayerWrapper(full_layer_name, layer, w_ob_ctr) + handle = layer.register_forward_pre_hook(get_hook(full_layer_name)) self.all_handles.append(handle) def quant_module(self): @@ -195,14 +208,27 @@ def quant_module(self): for h in self.all_handles: h.remove() + def set_order(self, idx): + self.order = idx + + def get_order(self): + return self.order + + def enable(self): + for n, l in self.all_layers.items(): + l.is_record = True + + def disable(self): + for n, l in self.all_layers.items(): + l.is_record = False + class GPTQuantizer: - def __init__(self, module_filter: Optional[ModuleFilter] = None, + def __init__(self, backend: Backend = Backend.DISC, device: Device = Device.GPU, - block: Optional[List[nn.Module]] = None + block: Optional[List[type]] = None ) -> None: - self.module_filter = module_filter self.backend = backend self.device = device self.all_module_wrappers = {} @@ -212,12 +238,19 @@ def calib(self, model: nn.Module, w_ob_ctr: Optional[Callable[..., Observer]] = None): default_w_ob_ctr = get_default_ctr(DEFAULT_W_OB_CTR, self.device, self.backend) w_ob_ctr = w_ob_ctr or default_w_ob_ctr + # GPTQ only quantize the weight of LLMs, so there is no need to # convert it to fx module - for name, module in model.named_modules(): - if isinstance(module, tuple(self.block)): - self.all_module_wrappers[name] = GPTQModuleWrapper(module, w_ob_ctr) - + def wrap_target_module(m, prefix=""): + for name, child in m.named_children(): + new_prefix = f"{prefix}.{name}" if prefix else name + if isinstance(child, tuple(self.block)): + self.all_module_wrappers[name] = GPTQModuleWrapper(new_prefix, child, w_ob_ctr) + LOGGER.debug(f"Calibrate module {new_prefix} as a whole block in GPTQ") + else: + wrap_target_module(child, new_prefix) + + wrap_target_module(model) return model def quantize(self, model: nn.Module): @@ -226,20 +259,59 @@ def quantize(self, model: nn.Module): return model + @property + def calibration_iters(self): + return len(self.all_module_wrappers) -if __name__ == "__main__": - class MyModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(2048, 2048) - - def forward(self, x): - return self.linear(x) - - model = MyModel() - ss = dict(model.named_modules()) - quantizer = GPTQuantizer() - calib_model = quantizer.calib(model) - dummy = torch.randn(1, 2048, 2048) - calib_model(dummy) - quant_model = quantizer.quantize(model) + @contextlib.contextmanager + def record_order(self): + counter = 0 + record_handles = [] + orders = {} + try: + def get_record_order_hook(module_name): + def record_hook(*args, **kwargs): + nonlocal counter + if module_name not in orders: + orders[module_name] = counter + counter += 1 + return record_hook + + for module_name, module_wrapper in self.all_module_wrappers.items(): + # disable the record + for _, layer_wrapper in module_wrapper.all_layers.items(): + layer_wrapper.is_record = False + + one_layer_wrapper_in_module = list(module_wrapper.all_layers.values())[0] + handles = one_layer_wrapper_in_module.layer.register_forward_pre_hook(get_record_order_hook(module_name)) + record_handles.append(handles) + yield + except Exception as e: + logging.warning(e) + finally: + for module_name, order in orders.items(): + self.all_module_wrappers[module_name].set_order(order) + + for h in record_handles: + h.remove() + + for module_name, module_wrapper in self.all_module_wrappers.items(): + # disable the record + for _, layer_wrapper in module_wrapper.all_layers.items(): + layer_wrapper.is_record = True + + + @contextlib.contextmanager + def start_calib_iter(self, i): + assert i < len(self.all_module_wrappers) + target_module_wrapper = None + try: + for _, module_wrapper in self.all_module_wrappers.items(): + if module_wrapper.get_order() == i: + module_wrapper.enable() + target_module_wrapper = module_wrapper + else: + module_wrapper.disable() + yield + finally: + target_module_wrapper.quant_module() From ded680afa6c05b1e880a98a28ce0708bf9ad493b Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Wed, 12 Apr 2023 17:47:25 +0800 Subject: [PATCH 7/8] fix --- tools/torch_quant/examples/chat_glm_gptq.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tools/torch_quant/examples/chat_glm_gptq.py b/tools/torch_quant/examples/chat_glm_gptq.py index 63dc1573fec..993efc6b8a5 100644 --- a/tools/torch_quant/examples/chat_glm_gptq.py +++ b/tools/torch_quant/examples/chat_glm_gptq.py @@ -1,3 +1,13 @@ +# Copyright 2023 The BladeDISC Authors. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import logging From 98f1d2f127cfc35c69398c61134a2b8c6d9fad63 Mon Sep 17 00:00:00 2001 From: chenbohua3 Date: Wed, 12 Apr 2023 17:49:59 +0800 Subject: [PATCH 8/8] fix --- tools/torch_quant/examples/chat_glm_gptq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/torch_quant/examples/chat_glm_gptq.py b/tools/torch_quant/examples/chat_glm_gptq.py index 993efc6b8a5..d9fbf48694b 100644 --- a/tools/torch_quant/examples/chat_glm_gptq.py +++ b/tools/torch_quant/examples/chat_glm_gptq.py @@ -29,8 +29,8 @@ def do_inference_with_fixed_seed(model, tokenizer, prompt): prompt = "晚上睡不着应该怎么办" -# If use local mode, change it to the folder that contains model files -target_model = "chatglm_6b" +# If local mode is used, change it to the folder that contains model files +target_model = "THUDM/chatglm-6b" # The basic block within which all layers are calibrated together target_block = get_class_from_dynamic_module(target_model, "modeling_chatglm.py", "GLMBlock")