compressed-tensors-nightly 0.3.3.20240514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. compressed_tensors/__init__.py +21 -0
  2. compressed_tensors/base.py +17 -0
  3. compressed_tensors/compressors/__init__.py +22 -0
  4. compressed_tensors/compressors/base.py +59 -0
  5. compressed_tensors/compressors/dense.py +34 -0
  6. compressed_tensors/compressors/helpers.py +137 -0
  7. compressed_tensors/compressors/int_quantized.py +95 -0
  8. compressed_tensors/compressors/model_compressor.py +264 -0
  9. compressed_tensors/compressors/sparse_bitmask.py +239 -0
  10. compressed_tensors/config/__init__.py +18 -0
  11. compressed_tensors/config/base.py +43 -0
  12. compressed_tensors/config/dense.py +36 -0
  13. compressed_tensors/config/sparse_bitmask.py +36 -0
  14. compressed_tensors/quantization/__init__.py +21 -0
  15. compressed_tensors/quantization/lifecycle/__init__.py +23 -0
  16. compressed_tensors/quantization/lifecycle/apply.py +196 -0
  17. compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  18. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  19. compressed_tensors/quantization/lifecycle/forward.py +333 -0
  20. compressed_tensors/quantization/lifecycle/frozen.py +50 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +99 -0
  22. compressed_tensors/quantization/observers/__init__.py +21 -0
  23. compressed_tensors/quantization/observers/base.py +130 -0
  24. compressed_tensors/quantization/observers/helpers.py +54 -0
  25. compressed_tensors/quantization/observers/memoryless.py +48 -0
  26. compressed_tensors/quantization/observers/min_max.py +80 -0
  27. compressed_tensors/quantization/quant_args.py +125 -0
  28. compressed_tensors/quantization/quant_config.py +210 -0
  29. compressed_tensors/quantization/quant_scheme.py +39 -0
  30. compressed_tensors/quantization/utils/__init__.py +16 -0
  31. compressed_tensors/quantization/utils/helpers.py +131 -0
  32. compressed_tensors/registry/__init__.py +17 -0
  33. compressed_tensors/registry/registry.py +360 -0
  34. compressed_tensors/utils/__init__.py +16 -0
  35. compressed_tensors/utils/helpers.py +45 -0
  36. compressed_tensors/utils/safetensors_load.py +237 -0
  37. compressed_tensors/version.py +50 -0
  38. compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
  39. compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
  40. compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
  41. compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
  42. compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
@@ -0,0 +1,69 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+
+ import torch
+ from compressed_tensors.quantization.lifecycle.forward import quantize
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from torch.nn import Module
+
+
+ __all__ = [
+     "compress_quantized_weights",
+ ]
+
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def compress_quantized_weights(module: Module):
+     """
+     Quantizes the module weight representation to use fewer bits in memory
+
+     apply to full model with `model.apply(compress_quantized_weights)`
+
+     :param module: module to compress to quantized representation
+     """
+     scheme = getattr(module, "quantization_scheme", None)
+     if not scheme or not scheme.weights:
+         # no quantization scheme or weights not quantized, nothing to do
+         return
+
+     if module.quantization_status == QuantizationStatus.COMPRESSED:
+         # module is already compressed, nothing to do
+         return
+
+     weight = getattr(module, "weight", None)
+     scale = getattr(module, "weight_scale", None)
+     zero_point = getattr(module, "weight_zero_point", None)
+
+     if weight is None or scale is None or zero_point is None:
+         # no weight, scale, or zero point, nothing to do
+
+         # mark as compressed here to maintain consistent status throughout the model
+         module.quantization_status = QuantizationStatus.COMPRESSED
+         return
+
+     module.weight.requires_grad = False  # cannot use autograd after compression
+     module.weight.data = quantize(
+         x=weight,
+         scale=scale,
+         zero_point=zero_point,
+         args=scheme.weights,
+         dtype=torch.int8,
+     )
+
+     module.quantization_status = QuantizationStatus.COMPRESSED
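The docstring above describes the intended entry point: `compress_quantized_weights` is applied module-by-module across a whole model. A minimal usage sketch, assuming a hypothetical `model` (any torch.nn.Module) that has already been taken through the rest of the quantization lifecycle so its targeted layers carry `quantization_scheme`, `weight_scale`, and `weight_zero_point`:

import torch
from torch.nn import Module
from compressed_tensors.quantization.lifecycle.compressed import (
    compress_quantized_weights,
)

def compress_model_weights(model: Module) -> None:
    # walk every submodule; layers without a quantization scheme are no-ops
    model.apply(compress_quantized_weights)

    # compressed layers now hold int8 weight data and report COMPRESSED status
    for name, module in model.named_modules():
        if getattr(module, "quantization_status", None) is not None and hasattr(
            module, "weight"
        ):
            print(name, module.weight.dtype)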
@@ -0,0 +1,333 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import wraps
+ from math import ceil
+ from typing import Optional
+
+ import torch
+ from compressed_tensors.quantization.quant_args import (
+     QuantizationArgs,
+     QuantizationStrategy,
+ )
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from torch.nn import Module
+
+
+ __all__ = [
+     "quantize",
+     "dequantize",
+     "fake_quantize",
+     "wrap_module_forward_quantized",
+     "maybe_calibrate_or_quantize",
+ ]
+
+
+ @torch.no_grad()
+ def quantize(
+     x: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     args: QuantizationArgs,
+     dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
+     """
+     Quantize the input tensor x using the QuantizationStrategy specified in args.
+     Quantization can be done per tensor, channel, token or group. For group
+     quantization, the column count must be divisible by group_size. The input scale
+     and zero_points are reshaped to support vectorization (assumes dimension 1 is
+     the channel dimension)
+
+     :param x: input tensor
+     :param scale: scale tensor
+     :param zero_point: zero point tensor
+     :param args: quantization args dictating how to quantize x
+     :param dtype: optional dtype to cast the quantized output to
+     :return: quantized tensor
+     """
+     return _process_quantization(
+         x=x,
+         scale=scale,
+         zero_point=zero_point,
+         args=args,
+         dtype=dtype,
+         do_quantize=True,
+         do_dequantize=False,
+     )
+
+
+ @torch.no_grad()
+ def dequantize(
+     x_q: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     args: Optional[QuantizationArgs] = None,
+ ) -> torch.Tensor:
+     """
+     Dequantize a quantized input tensor x_q based on the strategy specified in args.
+     If args is not provided, the strategy will be inferred from the scale shape.
+
+     :param x_q: quantized input tensor
+     :param scale: scale tensor
+     :param zero_point: zero point tensor
+     :param args: quantization args used to quantize x_q
+     :return: dequantized float tensor
+     """
+     if args is None:
+         if scale.ndim == 0:
+             args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+         elif scale.ndim == 2:
+             args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+         elif scale.ndim == 3:
+             group_size = int(x_q.shape[1] / scale.shape[1])
+             args = QuantizationArgs(
+                 strategy=QuantizationStrategy.GROUP, group_size=group_size
+             )
+     return _process_quantization(
+         x=x_q,
+         scale=scale,
+         zero_point=zero_point,
+         args=args,
+         do_quantize=False,
+         do_dequantize=True,
+     )
+
+
+ @torch.no_grad()
+ def fake_quantize(
+     x: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     args: QuantizationArgs,
+ ) -> torch.Tensor:
+     """
+     Fake quantize the input tensor x by quantizing then dequantizing with
+     the QuantizationStrategy specified in args. Quantization can be done per tensor,
+     channel, token or group. For group quantization, the column count must be
+     divisible by group_size. The input scale and zero_points are reshaped to
+     support vectorization (assumes dimension 1 is the channel dimension)
+
+     :param x: input tensor
+     :param scale: scale tensor
+     :param zero_point: zero point tensor
+     :param args: quantization args dictating how to quantize x
+     :return: fake quantized tensor
+     """
+     return _process_quantization(
+         x=x,
+         scale=scale,
+         zero_point=zero_point,
+         args=args,
+         do_quantize=True,
+         do_dequantize=True,
+     )
+
+
+ @torch.no_grad()
+ def _process_quantization(
+     x: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     args: QuantizationArgs,
+     dtype: Optional[torch.dtype] = None,
+     do_quantize: bool = True,
+     do_dequantize: bool = True,
+ ) -> torch.Tensor:
+     bit_range = 2**args.num_bits
+     q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
+     q_min = torch.tensor(-bit_range / 2, device=x.device)
+     group_size = args.group_size
+
+     # group
+     if args.strategy == QuantizationStrategy.GROUP:
+
+         if do_dequantize:  # if dequantizing, the output should be a float type
+             output = torch.zeros_like(x, dtype=scale.dtype)
+         else:
+             output_dtype = dtype if dtype is not None else x.dtype
+             output = torch.zeros_like(x, dtype=output_dtype)
+
+         # TODO: vectorize the for loop
+         # TODO: fix generic assumption about the tensor size for computing group
+
+         # TODO: make validation step for inputs
+
+         while scale.ndim < 2:
+             # pad scale and zero point dims for slicing
+             scale = scale.unsqueeze(1)
+             zero_point = zero_point.unsqueeze(1)
+
+         columns = x.shape[1]
+         if columns >= group_size:
+             if columns % group_size != 0:
+                 raise ValueError(
+                     "tensor column shape must be divisible "
+                     f"by the given group_size {group_size}"
+                 )
+         for i in range(ceil(columns / group_size)):
+             # scale.shape should be [nchan, ndim]
+             # sc.shape should be [nchan, 1] after view
+             sc = scale[:, i].view(-1, 1)
+             zp = zero_point[:, i].view(-1, 1)
+
+             idx = i * group_size
+             if do_quantize:
+                 output[:, idx : (idx + group_size)] = _quantize(
+                     x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                 )
+             if do_dequantize:
+                 input = (
+                     output[:, idx : (idx + group_size)]
+                     if do_quantize
+                     else x[:, idx : (idx + group_size)]
+                 )
+                 output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+
+     # channel-wise
+     elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
+         if do_quantize:
+             output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+         if do_dequantize:
+             output = _dequantize(output if do_quantize else x, scale, zero_point)
+
+     # per-token
+     elif args.strategy == QuantizationStrategy.TOKEN:
+         # before: scale shape = [num_tokens]
+         # after: scale shape = [num_tokens, 1]
+         # x.shape = [1, num_tokens, 1]
+         # scale gets broadcast as expected without having [1, num_tokens, 1] shape
+
+         scale = scale.unsqueeze(1)
+         zero_point = zero_point.unsqueeze(1)
+
+         if do_quantize:
+             output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+         if do_dequantize:
+             output = _dequantize(output if do_quantize else x, scale, zero_point)
+
+     else:
+         if do_quantize:
+             output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+         if do_dequantize:
+             output = _dequantize(output if do_quantize else x, scale, zero_point)
+
+     return output
+
+
+ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
+     # expects a module already initialized and injected with the parameters in
+     # initialize_module_for_quantization
+     forward_func_orig = module.forward.__func__
+
+     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+     def wrapped_forward(self, *args, **kwargs):
+         input_ = args[0]
+
+         if scheme.input_activations is not None:
+             # calibrate and (fake) quantize input activations when applicable
+             input_ = maybe_calibrate_or_quantize(
+                 module, input_, "input", scheme.input_activations
+             )
+
+         if scheme.weights is not None:
+             # calibrate and (fake) quantize weights when applicable
+             unquantized_weight = self.weight.data.clone()
+             self.weight.data = maybe_calibrate_or_quantize(
+                 module, self.weight, "weight", scheme.weights
+             )
+
+         # perform wrapped forward call
+         output = forward_func_orig.__get__(module, module.__class__)(
+             input_, *args[1:], **kwargs
+         )
+
+         if scheme.output_activations is not None:
+             # calibrate and (fake) quantize output activations when applicable
+             output = maybe_calibrate_or_quantize(
+                 module, output, "output", scheme.output_activations
+             )
+
+         # restore back to unquantized_value
+         if scheme.weights is not None:
+             self.weight.data = unquantized_weight
+
+         return output
+
+     # bind wrapped forward to module class so reference to `self` is correct
+     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+     # set forward to wrapped forward
+     setattr(module, "forward", bound_wrapped_forward)
+
+
+ def maybe_calibrate_or_quantize(
+     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
+ ) -> torch.Tensor:
+     # only run quantization for the included stages
+     if module.quantization_status not in {
+         QuantizationStatus.CALIBRATION,
+         QuantizationStatus.FROZEN,
+     }:
+         return value
+
+     if args.dynamic:
+         # dynamic quantization - get scale and zero point directly from observer
+         observer = getattr(module, f"{base_name}_observer")
+         scale, zero_point = observer(value)
+     else:
+         # static quantization - get previous scale and zero point from layer
+         scale = getattr(module, f"{base_name}_scale")
+         zero_point = getattr(module, f"{base_name}_zero_point")
+
+         if module.quantization_status == QuantizationStatus.CALIBRATION:
+             # calibration mode - get new quant params from observer
+             observer = getattr(module, f"{base_name}_observer")
+
+             updated_scale, updated_zero_point = observer(value)
+
+             # update scale and zero point
+             device = next(module.parameters()).device
+             scale.data = updated_scale.to(device)
+             zero_point.data = updated_zero_point.to(device)
+     return fake_quantize(value, scale, zero_point, args)
+
+
+ @torch.no_grad()
+ def _quantize(
+     x: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+     q_min: torch.Tensor,
+     q_max: torch.Tensor,
+     dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
+     quantized_value = torch.clamp(
+         torch.round(x / scale + zero_point),
+         q_min,
+         q_max,
+     )
+
+     if dtype is not None:
+         quantized_value = quantized_value.to(dtype)
+
+     return quantized_value
+
+
+ @torch.no_grad()
+ def _dequantize(
+     x_q: torch.Tensor,
+     scale: torch.Tensor,
+     zero_point: torch.Tensor,
+ ) -> torch.Tensor:
+     return (x_q - zero_point) * scale
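For reference, the three public helpers compose as quantize then dequantize, with `fake_quantize` doing both in a single call. A small per-tensor round-trip sketch, assuming `QuantizationArgs` accepts `num_bits` and `strategy` keyword arguments (both are read as attributes in `_process_quantization` above) and using a hand-computed symmetric scale in place of an observer:

import torch
from compressed_tensors.quantization.lifecycle.forward import (
    dequantize,
    fake_quantize,
    quantize,
)
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
)

# assumed constructor keywords; field names are inferred from attribute reads above
args = QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.TENSOR)

x = torch.randn(4, 16)
scale = x.abs().max() / 127.0      # hand-computed symmetric 8-bit scale
zero_point = torch.tensor(0)

x_q = quantize(x, scale, zero_point, args, dtype=torch.int8)  # int8 values
x_dq = dequantize(x_q, scale, zero_point, args)               # back to float
x_fake = fake_quantize(x, scale, zero_point, args)            # quantize + dequantize in one call

assert torch.allclose(x_dq, x_fake)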
@@ -0,0 +1,50 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from torch.nn import Module
+
+
+ __all__ = [
+     "freeze_module_quantization",
+ ]
+
+
+ def freeze_module_quantization(module: Module):
+     """
+     Deletes observers so static quantization is completed.
+
+     apply to full model with `model.apply(freeze_module_quantization)`
+
+     :param module: module to freeze quantization for
+     """
+     scheme = getattr(module, "quantization_scheme", None)
+     if not scheme:
+         # no quantization scheme, nothing to do
+         return
+
+     if module.quantization_status == QuantizationStatus.FROZEN:
+         # nothing to do, already frozen
+         return
+
+     # delete observers from module if not dynamic
+     if scheme.input_activations and not scheme.input_activations.dynamic:
+         delattr(module, "input_observer")
+     if scheme.weights and not scheme.weights.dynamic:
+         delattr(module, "weight_observer")
+     if scheme.output_activations and not scheme.output_activations.dynamic:
+         delattr(module, "output_observer")
+
+     module.quantization_status = QuantizationStatus.FROZEN
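As with the other lifecycle hooks, freezing is applied over a full model. A brief sketch, assuming a hypothetical `model` that has already been initialized and calibrated so its quantized layers still hold `*_observer` submodules and calibrated scale / zero point parameters:

from compressed_tensors.quantization.lifecycle.frozen import (
    freeze_module_quantization,
)

# `model` is any torch.nn.Module currently in the CALIBRATION stage
model.apply(freeze_module_quantization)
# static observers are deleted, each layer's status becomes FROZEN, and the
# calibrated scale / zero point parameters remain attached for inference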
@@ -0,0 +1,99 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+ from typing import Optional
+
+ import torch
+ from compressed_tensors.quantization.lifecycle.forward import (
+     wrap_module_forward_quantized,
+ )
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from torch.nn import Module, Parameter
+
+
+ __all__ = [
+     "initialize_module_for_quantization",
+ ]
+
+
+ _LOGGER = logging.getLogger(__name__)
+
+
+ def initialize_module_for_quantization(
+     module: Module,
+     scheme: Optional[QuantizationScheme] = None,
+ ):
+     """
+     Attaches appropriate scales, zero points, and observers to a layer
+     given its target quantization scheme
+
+     apply to full model with `model.apply(initialize_module_for_quantization)`
+
+     :param module: module to set for calibration
+     :param scheme: scheme to use for quantization. if None is provided,
+         will attempt to use the scheme stored in the module under
+         `quantization_scheme`; if neither is set, the layer will be skipped
+     """
+     scheme = scheme or getattr(module, "quantization_scheme", None)
+     if scheme is None:
+         # no scheme passed and layer not targeted for quantization - skip
+         return
+
+     if scheme.input_activations is not None:
+         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
+     if scheme.weights is not None:
+         if hasattr(module, "weight"):
+             _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+         else:
+             _LOGGER.warning(
+                 f"module type {type(module)} targeted for weight quantization but "
+                 "has no attribute weight, skipping weight quantization "
+                 f"for {type(module)}"
+             )
+     if scheme.output_activations is not None:
+         _initialize_scale_zero_point_observer(
+             module, "output", scheme.output_activations
+         )
+
+     module.quantization_scheme = scheme
+     module.quantization_status = QuantizationStatus.INITIALIZED
+
+     # wrap forward call of module to perform quantized actions based on call-time status
+     wrap_module_forward_quantized(module, scheme)
+
+
+ def _initialize_scale_zero_point_observer(
+     module: Module, base_name: str, quantization_args: QuantizationArgs
+ ):
+     # initialize observer module and attach as submodule
+     observer = quantization_args.get_observer()
+     module.register_module(f"{base_name}_observer", observer)
+
+     if quantization_args.dynamic:
+         return  # no need to register a scale and zero point for a dynamic observer
+
+     device = next(module.parameters()).device
+
+     # initializes empty scale and zero point parameters for the module
+     init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
+     module.register_parameter(f"{base_name}_scale", init_scale)
+
+     init_zero_point = Parameter(
+         torch.empty(0, device=device, dtype=int), requires_grad=False
+     )
+     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
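Putting the lifecycle pieces in this diff together, a single layer can be walked from initialization through calibration and freezing to compression. This is a sketch under stated assumptions: the `QuantizationScheme` constructor is assumed to take `targets`, `weights`, and `input_activations` keywords (only the latter two are confirmed as attributes by initialize.py and forward.py), and `QuantizationArgs(num_bits=8)` is assumed to provide a usable default observer:

import torch
from torch.nn import Linear
from compressed_tensors.quantization.lifecycle.compressed import (
    compress_quantized_weights,
)
from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

# assumed field names; `targets` is hypothetical here
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8),
    input_activations=QuantizationArgs(num_bits=8),
)

layer = Linear(16, 8)
initialize_module_for_quantization(layer, scheme)  # attaches observers, scale, zero point

layer.quantization_status = QuantizationStatus.CALIBRATION
layer(torch.randn(4, 16))            # observers record scale / zero point during this pass

freeze_module_quantization(layer)    # observers removed, status -> FROZEN
layer(torch.randn(4, 16))            # forward now fake-quantizes with the stored parameters

compress_quantized_weights(layer)    # weight data replaced with its int8 representation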
@@ -0,0 +1,21 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+ # isort: skip_file
+
+ from .helpers import *
+ from .base import *
+ from .memoryless import *
+ from .min_max import *