mct-nightly 1.8.0.27022023.post430__py3-none-any.whl → 1.8.0.27032023.post403__py3-none-any.whl

This diff shows the changes between two publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (66)
  1. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/METADATA +7 -7
  2. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/RECORD +65 -59
  3. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +9 -15
  5. model_compression_toolkit/core/common/logger.py +10 -2
  6. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +6 -1
  7. model_compression_toolkit/core/keras/quantization_facade.py +1 -1
  8. model_compression_toolkit/core/pytorch/constants.py +4 -0
  9. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py +4 -10
  10. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +16 -2
  11. model_compression_toolkit/exporter/__init__.py +5 -0
  12. model_compression_toolkit/exporter/model_exporter/__init__.py +0 -3
  13. model_compression_toolkit/exporter/model_exporter/tflite/fakely_quant_tflite_exporter.py +1 -1
  14. model_compression_toolkit/exporter/model_wrapper/__init__.py +4 -8
  15. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +45 -39
  16. model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py +39 -24
  17. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +50 -42
  18. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +43 -36
  19. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py +24 -5
  20. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +25 -18
  21. model_compression_toolkit/gptq/__init__.py +6 -0
  22. model_compression_toolkit/gptq/common/gptq_config.py +60 -106
  23. model_compression_toolkit/gptq/common/gptq_constants.py +0 -7
  24. model_compression_toolkit/gptq/common/gptq_training.py +28 -38
  25. model_compression_toolkit/gptq/keras/gptq_training.py +10 -28
  26. model_compression_toolkit/gptq/keras/graph_info.py +8 -33
  27. model_compression_toolkit/gptq/keras/quantization_facade.py +6 -12
  28. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +0 -1
  29. model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +2 -2
  30. model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +45 -0
  31. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +112 -0
  32. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py +22 -128
  33. model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py +11 -41
  34. model_compression_toolkit/gptq/pytorch/gptq_training.py +12 -4
  35. model_compression_toolkit/gptq/pytorch/graph_info.py +9 -6
  36. model_compression_toolkit/gptq/pytorch/quantization_facade.py +9 -22
  37. model_compression_toolkit/gptq/pytorch/quantizer/__init__.py +3 -1
  38. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +0 -20
  39. model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py +10 -1
  40. model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +2 -2
  41. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +45 -0
  42. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py +14 -0
  43. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +115 -0
  44. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py +236 -0
  45. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py +196 -0
  46. model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py +9 -31
  47. model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py +30 -37
  48. model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py +27 -36
  49. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py +21 -21
  50. model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py +25 -26
  51. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_all_subclasses.py +1 -2
  52. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/common/get_quantizers.py +1 -1
  53. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantize_wrapper.py +12 -0
  54. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/__init__.py +4 -0
  55. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/keras/quantizers/constants.py +1 -0
  56. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantize_wrapper.py +12 -0
  57. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/__init__.py +6 -0
  58. model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/pytorch/quantizers/constants.py +3 -0
  59. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/base_trainable_quantizer.py +53 -2
  60. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/common/get_quantizers.py +2 -1
  61. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/keras/base_keras_quantizer.py +22 -4
  62. model_compression_toolkit/quantizers_infrastructure/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +24 -3
  63. model_compression_toolkit/gptq/common/gptq_quantizer_config.py +0 -93
  64. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/LICENSE.md +0 -0
  65. {mct_nightly-1.8.0.27022023.post430.dist-info → mct_nightly-1.8.0.27032023.post403.dist-info}/top_level.txt +0 -0
  66. /model_compression_toolkit/quantizers_infrastructure/inferable_infrastructure/{common → pytorch/quantizers/activation_inferable_quantizers}/activation_lut_pot_inferable_quantizer.py +0 -0
model_compression_toolkit/gptq/pytorch/graph_info.py
@@ -12,13 +12,14 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ import torch
  import torch.nn as nn
  from typing import List
-
  from model_compression_toolkit.core.pytorch.constants import BIAS
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
  from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup


  def get_gptq_trainable_parameters(fxp_model: nn.Module,
@@ -38,24 +39,26 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,
  trainable_aux_weights = nn.ParameterList()
  trainable_threshold = nn.ParameterList()
  trainable_bias = nn.ParameterList()
- trainable_temperature = nn.ParameterList()

  for layer in fxp_model.modules():
  if isinstance(layer, PytorchQuantizationWrapper):
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
  fw_info=DEFAULT_PYTORCH_INFO)

- trainable_aux_weights.extend(layer.weights_quantizers[kernel_attribute].get_aux_variable())
- trainable_threshold.extend(layer.weights_quantizers[kernel_attribute].get_quantization_variable())
+ # collect trainable weights per quantizer
+ quantizer_trainable_weights = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.WEIGHTS)
+ quantizer_trainable_threshold = layer.weights_quantizers[kernel_attribute].get_trainable_variables(VariableGroup.QPARAMS)
+ trainable_aux_weights.extend(quantizer_trainable_weights)
+ trainable_threshold.extend(quantizer_trainable_threshold)

  if add_bias and hasattr(layer.layer, BIAS):
  bias = getattr(layer.layer, BIAS)
  trainable_bias.append(bias)

- return trainable_aux_weights, trainable_bias, trainable_threshold, trainable_temperature
+ return trainable_aux_weights, trainable_bias, trainable_threshold


- def get_weights_for_loss(fxp_model: nn.Module) -> [List, List]:
+ def get_weights_for_loss(fxp_model: nn.Module) -> [List[nn.Parameter], List[torch.Tensor]]:
  """
  Get all float and quantized kernels for the GPTQ loss

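For context, this refactor replaces the per-quantizer get_aux_variable/get_quantization_variable calls with a single get_trainable_variables(VariableGroup) accessor, so callers can decide which group each optimizer trains. A minimal sketch of how the returned groups could be routed to two optimizers (the learning rates and the assumption that the lists are non-empty are illustrative, not taken from the package):

import torch
import torch.nn as nn

def build_gptq_optimizers(aux_weights: nn.ParameterList,
                          biases: nn.ParameterList,
                          thresholds: nn.ParameterList):
    # Auxiliary rounding variables (VariableGroup.WEIGHTS) drive the main optimizer.
    weights_opt = torch.optim.Adam(list(aux_weights), lr=1e-3)
    # Biases and quantization parameters (VariableGroup.QPARAMS) go to a second,
    # slower optimizer, mirroring the optimizer / optimizer_rest split in GradientPTQConfigV2.
    rest_opt = torch.optim.Adam(list(biases) + list(thresholds), lr=1e-4)
    return weights_opt, rest_opt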
model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -17,14 +17,15 @@ from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common.constants import FOUND_TORCH
  from model_compression_toolkit.core.common import Logger
  from model_compression_toolkit.core.common.constants import PYTORCH
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2, RoundingType
+ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
  from model_compression_toolkit.core.common.target_platform import TargetPlatformCapabilities
  from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
  from model_compression_toolkit.core.runner import core_runner, _init_tensorboard_writer
+ from model_compression_toolkit.gptq.keras.quantization_facade import GPTQ_MOMENTUM
  from model_compression_toolkit.gptq.runner import gptq_runner
  from model_compression_toolkit.core.exporter import export_model
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
- from model_compression_toolkit import CoreConfig, GPTQQuantizerConfig
+ from model_compression_toolkit import CoreConfig
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
  MixedPrecisionQuantizationConfigV2

@@ -71,33 +72,19 @@ if FOUND_TORCH:
  Import MCT and Create a GradientPTQConfigV2 to run for 5 epochs:

  >>> import model_compression_toolkit as mct
- >>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=5)
+ >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=5)

  Other PyTorch optimizers can be passed with dummy params:

  >>> import torch
- >>> gptq_conf = mct.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))
+ >>> gptq_conf = mct.gptq.get_pytorch_gptq_config(n_epochs=3, optimizer=torch.optim.Adam([torch.Tensor(1)]))

  The configuration can be passed to :func:`~model_compression_toolkit.pytorch_post_training_quantization` in order to quantize a pytorch model using gptq.

  """
- bias_optimizer = Adam([torch.Tensor([])], lr=LR_BIAS_DEFAULT)
- optimizer_quantization_parameter = Adam([torch.Tensor([])], lr=LR_QUANTIZATION_PARAM_DEFAULT)
- # TODO: Once implementing Soft Quantizer for GPTQ in Pytorch:
- # - change default quantization_parameters_learning to True.
- # - remove explicit rounding_type and quantizer_config (and let it use the default GradientPTQConfig).
- return GradientPTQConfigV2(n_epochs,
- optimizer,
- optimizer_rest=optimizer_rest,
- loss=loss,
- log_function=log_function,
- train_bias=True,
- optimizer_quantization_parameter=optimizer_quantization_parameter,
- optimizer_bias=bias_optimizer,
- rounding_type=RoundingType.STE,
- quantizer_config=GPTQQuantizerConfig(),
- quantization_parameters_learning=False,
- )
+ bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
+ return GradientPTQConfigV2(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
+ log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer)


  def pytorch_gradient_post_training_quantization_experimental(model: Module,
@@ -159,7 +146,7 @@ if FOUND_TORCH:

  Pass the module, the representative dataset generator and the configuration (optional) to get a quantized module

- >>> quantized_module, quantization_info = mct.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)
+ >>> quantized_module, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization_experimental(module, repr_datagen, core_config=config, gptq_config=gptq_conf)

  """

model_compression_toolkit/gptq/pytorch/quantizer/__init__.py
@@ -13,4 +13,6 @@
  # limitations under the License.
  # ==============================================================================

- import model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.symmetric_ste
+ import model_compression_toolkit.gptq.pytorch.quantizer.ste_rounding.symmetric_ste
+ import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.symmetric_soft_quantizer
+ import model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.uniform_soft_quantizer
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py
@@ -71,26 +71,6 @@ if FOUND_TORCH:

  return weights, quant_config, {}

- def get_aux_variable(self) -> List[Tensor]:
- """
- This function return a list with the quantizer's quantization auxiliary variables.
-
- Returns: A list with the quantization auxiliary variables.
-
- """
-
- return [] # pragma: no cover
-
- def get_quantization_variable(self) -> List[Tensor]:
- """
- This function return a list with the quantizer's quantization parameters variables.
-
- Returns: A list with the quantization parameters.
-
- """
-
- return [] # pragma: no cover
-
  @abstractmethod
  def get_quant_config(self):
  """
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py
@@ -30,11 +30,20 @@ def calculate_delta(max_tensor: torch.Tensor,
  num_bits: int,
  signed: bool) -> torch.Tensor:
  """
- Compute the step size for the quantization.
+ Compute the step size for the symmetric quantization.
  """
  return max_tensor / (2 ** (num_bits - int(signed)))


+ def calculate_delta_uniform(min_tensor: torch.Tensor,
+ max_tensor: torch.Tensor,
+ num_bits: int) -> torch.Tensor:
+ """
+ Compute the step size for the uniform quantization.
+ """
+ return (max_tensor-min_tensor) / (2 ** num_bits - 1)
+
+
  def ste_ceil(x: torch.Tensor) -> torch.Tensor:
  """
  Return the ceil values of a tensor.
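As a quick numeric check of the two step-size formulas (values picked only for illustration):

import torch

# Symmetric: delta = threshold / 2 ** (num_bits - signed). Threshold 1.0, 8 bits, signed:
delta_sym = torch.tensor(1.0) / (2 ** (8 - 1))                        # 1 / 128 = 0.0078125

# Uniform: delta = (max - min) / (2 ** num_bits - 1). Range [-0.5, 1.0], 8 bits:
delta_uni = (torch.tensor(1.0) - torch.tensor(-0.5)) / (2 ** 8 - 1)   # 1.5 / 255 ~ 0.00588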
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py
@@ -14,7 +14,7 @@
  # ==============================================================================
  from typing import List, Dict, Tuple

- from model_compression_toolkit import GradientPTQConfigV2
+ from model_compression_toolkit.gptq import GradientPTQConfigV2
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.pytorch.constants import KERNEL
  from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
@@ -59,7 +59,7 @@ def quantization_builder(n: common.BaseNode,
  quant_method=quant_method,
  quantizer_base_class=BasePytorchGPTQTrainableQuantizer)
  weights_quantizers.update({KERNEL: quantizer_class(get_trainable_quantizer_weights_config(n),
- **gptq_config.get_extended_quantizer_parametes())})
+ **gptq_config.gptq_quantizer_params_override)})
  activation_quantizers = []
  if n.is_activation_quantization_enabled():
  quant_method = n.final_activation_quantization_cfg.activation_quantization_method
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py
@@ -0,0 +1,45 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import Callable
+
+ from model_compression_toolkit.gptq import RoundingType, GradientPTQConfigV2, GradientPTQConfig
+ from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
+ SoftQuantizerRegularization
+
+
+ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen: Callable) -> Callable:
+ """
+ Returns a function that computes the regularization term for GPTQ training based on the given
+ rounding type in the GPTQ configuration.
+
+ Args:
+ gptq_config: A GPTQ configuration.
+ representative_data_gen: Dataset used for the GPTQ training.
+
+ Returns: A function for computing the regularization. If there is no regularization function defined for the given
+ rounding type, then it returns a function that just returns 0.
+
+ """
+ if gptq_config.rounding_type == RoundingType.SoftQuantizer:
+ # dry run on the representative dataset to count number of batches
+ num_batches = 0
+ for _ in representative_data_gen():
+ num_batches += 1
+
+ n_epochs = GradientPTQConfigV2.from_v1(n_ptq_iter=num_batches, config_v1=gptq_config).n_epochs if \
+ not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
+ return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
+ else:
+ return lambda m, e_reg: 0
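The dry run above only counts how many batches one pass over the representative dataset yields; the soft-quantizer regularization then needs batches x epochs as its total number of gradient steps. A small sketch with a made-up generator (shapes and counts are illustrative only):

import numpy as np

def representative_data_gen():
    # Hypothetical representative dataset: 8 batches of random images.
    for _ in range(8):
        yield [np.random.randn(4, 3, 224, 224).astype(np.float32)]

num_batches = sum(1 for _ in representative_data_gen())    # 8
n_epochs = 5                                               # e.g. taken from the GPTQ config
total_gradient_steps = num_batches * n_epochs              # 40 steps feed the temperature decay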
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -0,0 +1,115 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ from typing import List
+
+ import torch
+ import numpy as np
+ from torch import nn
+
+ from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
+ from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
+ from model_compression_toolkit.quantizers_infrastructure import PytorchQuantizationWrapper
+
+
+ class LinearTempDecay:
+ """
+ Annealing process for the soft quantizer regularization temperature term.
+ """
+
+ def __init__(self, t_max: int, rel_start_decay: float = 0.2, start_b: int = 20, end_b: int = 2):
+ """
+ Initializes a LinearTempDecay object.
+
+ Args:
+ t_max: maximal time step.
+ rel_start_decay: Decay step size at the beginning of the process.
+ start_b: Starting value of the regularization term.
+ end_b: Target value of the regularization term.
+ """
+
+ self.t_max = t_max
+ self.start_decay = rel_start_decay * t_max
+ self.start_b = start_b
+ self.end_b = end_b
+
+ def __call__(self, t: float) -> float:
+ """
+ Cosine annealing scheduler for soft quantizer regularization temperature term.
+
+ Args:
+ t: The current time step.
+
+ Returns: Scheduled temperature.
+ """
+
+ is_before_start_decay = (t < self.start_decay)
+
+ rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
+
+ return self.start_b * is_before_start_decay + \
+ (1 - is_before_start_decay) * \
+ (self.end_b + (self.start_b - self.end_b) * torch.maximum(to_torch_tensor(np.array([0.0])),
+ to_torch_tensor(np.array((1 - rel_t)))))
+
+
+ class SoftQuantizerRegularization:
+ """
+ A class to handle the computation of soft quantizer regularization for GPTQ training.
+ """
+
+ def __init__(self, total_gradient_steps: int):
+ """
+ Initializes the regularization computation object with a LinearDecay object.
+
+ Args:
+ total_gradient_steps: The number of gradient steps during optimization.
+ """
+
+ # Initializing the temperature decay according to the number of expected gradient steps
+ self.linear_decay = LinearTempDecay(total_gradient_steps)
+
+ self.count_iter = 0
+
+ def __call__(self, model: nn.Module, entropy_reg: float):
+ """
+ Returns the soft quantizer regularization value for SoftRounding.
+
+ Args:
+ model: A model to be quantized with SoftRounding.
+ entropy_reg: Entropy value to scale the quantizer regularization.
+
+ Returns: Regularization value.
+ """
+
+ soft_reg_aux: List[torch.Tensor] = []
+ for layer in model.modules():
+ if isinstance(layer, PytorchQuantizationWrapper):
+ kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+ fw_info=DEFAULT_PYTORCH_INFO)
+
+ st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
+ b = self.linear_decay(self.count_iter)
+
+ soft_reg_aux.append((1 - torch.pow(torch.abs(st - .5) * 2, b)).sum())
+
+ reg = 0
+
+ for sq in soft_reg_aux:
+ reg += sq
+
+ self.count_iter += 1
+
+ return entropy_reg * reg
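To make the schedule concrete: the temperature b stays at start_b for the first 20% of the steps and then decays linearly to end_b, and each weight contributes 1 - (2|h - 0.5|)^b, which is largest for undecided rounding targets (h near 0.5) and vanishes once h saturates at 0 or 1. A self-contained NumPy sketch with made-up numbers (the real class works on torch tensors pulled from the wrapped model):

import numpy as np

def temp_decay(t, t_max, rel_start_decay=0.2, start_b=20, end_b=2):
    # Hold the temperature at start_b for the first rel_start_decay * t_max steps,
    # then decay linearly towards end_b.
    start_decay = rel_start_decay * t_max
    if t < start_decay:
        return float(start_b)
    rel_t = (t - start_decay) / (t_max - start_decay)
    return end_b + (start_b - end_b) * max(0.0, 1 - rel_t)

total_steps = 40
soft_targets = np.array([0.05, 0.5, 0.95])     # hypothetical per-weight rounding targets h
for t in (0, 10, 20, 40):
    b = temp_decay(t, total_steps)
    penalty = np.sum(1 - (np.abs(soft_targets - 0.5) * 2) ** b)
    print(f"step {t:2d}: b = {b:5.2f}, regularization = {penalty:.4f}")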
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py
@@ -0,0 +1,236 @@
+ # Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import torch
+ import torch.nn as nn
+ from typing import Dict
+ import numpy as np
+
+ from model_compression_toolkit.core.common import max_power_of_two
+ from model_compression_toolkit import quantizers_infrastructure as qi
+ from model_compression_toolkit.core.common.target_platform import QuantizationMethod
+ from model_compression_toolkit.gptq.common.gptq_config import RoundingType
+ from model_compression_toolkit.gptq.pytorch.quantizer.base_pytorch_gptq_quantizer import \
+ BasePytorchGPTQTrainableQuantizer
+ from model_compression_toolkit.core.pytorch.utils import to_torch_tensor, torch_tensor_to_numpy
+ from model_compression_toolkit.gptq.pytorch.quantizer import quant_utils as qutils
+ from model_compression_toolkit.gptq.common.gptq_constants import PTQ_THRESHOLD, SCALE_PTQ, \
+ SOFT_ROUNDING_GAMMA, SOFT_ROUNDING_ZETA, AUXVAR
+ from model_compression_toolkit.core.common.constants import THRESHOLD, MIN_THRESHOLD
+ from model_compression_toolkit.quantizers_infrastructure import TrainableQuantizerWeightsConfig
+ from model_compression_toolkit.quantizers_infrastructure.inferable_infrastructure.common.base_inferable_quantizer import mark_quantizer
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.quant_utils import \
+ get_threshold_reshape_shape
+ from model_compression_toolkit.quantizers_infrastructure.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+
+
+ def soft_rounding_symmetric_quantizer(input_tensor: torch.Tensor,
+ auxvar_tensor: torch.Tensor,
+ threshold_tensor: torch.Tensor,
+ num_bits: int,
+ signed: bool,
+ power_of_two: bool) -> torch.Tensor:
+ """
+ Quantize a tensor symmetrically for GPTQ quantizers.
+
+ Args:
+ input_tensor: Tensor to quantize. values of this tensor are not changed during gptq.
+ auxvar_tensor: Tensor that manifests the bit shift of the quantized weights due to gptq training.
+ threshold_tensor: Tensor with values to compute the threshold.
+ num_bits: Num of bits to use.
+ signed: Signedness of the quantization range.
+ power_of_two: Whether the threshold should be constrained or not.
+
+ Returns:
+ A quantized tensor.
+ """
+
+ if power_of_two:
+ threshold_tensor = qutils.power_of_two_max(threshold_tensor)
+ delta = qutils.calculate_delta(threshold_tensor, num_bits, signed)
+ with torch.no_grad():
+ input_tensor_int = torch.floor(input_tensor / delta)
+ tensor_q = input_tensor_int + auxvar_tensor
+ int_threshold = 2 ** (num_bits - int(signed))
+ return delta * qutils.ste_clip(tensor_q,
+ min_val=-int(signed) * int_threshold,
+ max_val=int_threshold - 1)
+
+
+ @mark_quantizer(quantization_target=qi.QuantizationTarget.Weights,
+ quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+ quantizer_type=RoundingType.SoftQuantizer)
+ class SymmetricSoftRoundingGPTQ(BasePytorchGPTQTrainableQuantizer):
+ """
+ Trainable symmetric quantizer to optimize the rounding of the quantized values using a soft quantization method.
+ """
+
+ def __init__(self,
+ quantization_config: TrainableQuantizerWeightsConfig,
+ quantization_parameter_learning: bool = False):
+ """
+ Construct a Pytorch model that utilize a fake weight quantizer of soft-quantizer for symmetric quantizer.
+
+ Args:
+ quantization_config: Trainable weights quantizer config.
+ quantization_parameter_learning (Bool): Whether to learn the threshold or not
+ """
+
+ super().__init__(quantization_config)
+ self.num_bits = quantization_config.weights_n_bits
+ self.per_channel = quantization_config.weights_per_channel_threshold
+
+ threshold_values = quantization_config.weights_quantization_params[THRESHOLD]
+ self.threshold_shape = np.asarray(threshold_values).shape
+ self.threshold_values = np.reshape(np.asarray(threshold_values), [-1]) if self.per_channel else float(
+ threshold_values)
+
+ self.quantization_axis = quantization_config.weights_channels_axis
+ self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
+ self.quantization_parameter_learning = quantization_parameter_learning
+
+ # gamma and zeta are stretch parameters for computing the rectified sigmoind function.
+ # See: https://arxiv.org/pdf/2004.10568.pdf
+ self.gamma = SOFT_ROUNDING_GAMMA
+ self.zeta = SOFT_ROUNDING_ZETA
+
+ self.quantizer_parameters = {}
+
+ def initialize_quantization(self,
+ tensor_shape: torch.Size,
+ name: str,
+ layer: qi.PytorchQuantizationWrapper):
+ """
+ Add quantizer parameters to the quantizer parameters dictionary
+
+ Args:
+ tensor_shape: tensor shape of the quantized tensor.
+ name: Tensor name.
+ layer: Layer to quantize.
+ """
+
+ if self.per_channel:
+ threshold_tensor = to_torch_tensor(self.threshold_values)
+ else:
+ threshold_tensor = torch.tensor(self.threshold_values)
+ layer.register_parameter(f"{name}_{PTQ_THRESHOLD}",
+ nn.Parameter(threshold_tensor, requires_grad=False))
+
+ w = layer.layer.weight
+ delta = qutils.calculate_delta(threshold_tensor.reshape(self.threshold_shape), self.num_bits, signed=True)
+ w_clipped_normed = torch.clip(w / delta, -2**(self.num_bits-1), 2**(self.num_bits-1)-1)
+ rest = w_clipped_normed - torch.floor(w_clipped_normed) # rest of rounding [0, 1)
+ # Note that (rest - self.gamma) can't be zero since rest is positive and gamma is negative, so the division
+ # is safe
+ alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1) # => sigmoid(alpha) = rest
+
+ layer.register_parameter(f"{name}_{AUXVAR}", nn.Parameter(alpha, requires_grad=True))
+
+ # save the quantizer added parameters for later calculations
+ self.add_quantizer_variable(PTQ_THRESHOLD, layer.get_parameter(f"{name}_{PTQ_THRESHOLD}"), VariableGroup.QPARAMS)
+ self.add_quantizer_variable(AUXVAR, layer.get_parameter(f"{name}_{AUXVAR}"), VariableGroup.WEIGHTS)
+
+ if self.quantization_parameter_learning:
+ layer.register_parameter(f"{name}_{SCALE_PTQ}",
+ nn.Parameter(to_torch_tensor(torch.ones_like(torch.Tensor(self.threshold_values))),
+ requires_grad=True))
+ self.add_quantizer_variable(SCALE_PTQ, layer.get_parameter(f"{name}_{SCALE_PTQ}"), VariableGroup.QPARAMS)
+
+ def get_soft_targets(self) -> torch.Tensor:
+ """
+ Computes the rectified sigmoid function for the quantization target parameters.
+
+ Returns:
+ A tensor with the soft rounding targets values.
+
+ """
+ scaled_sigmoid = torch.sigmoid(self.get_quantizer_variable(AUXVAR)) * (self.zeta - self.gamma) + self.gamma
+ return torch.clip(scaled_sigmoid, min=0, max=1)
+
+ def get_quant_config(self) -> Dict[str, np.ndarray]:
+ """
+ Returns the config used to edit NodeQuantizationConfig after GPTQ retraining
+
+ Returns:
+ A dictionary of attributes the quantize_config retraining has changed during GPTQ retraining.
+ Keys must match NodeQuantizationConfig attributes
+
+ """
+ old_threshold = torch_tensor_to_numpy(self.get_quantizer_variable(PTQ_THRESHOLD))
+ old_threshold = np.resize(old_threshold, self.threshold_shape)
+ if self.power_of_two:
+ old_threshold = max_power_of_two(old_threshold, MIN_THRESHOLD)
+ else:
+ if self.quantization_parameter_learning:
+ scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), self.threshold_shape)
+ old_threshold = old_threshold * torch_tensor_to_numpy(scale)
+ old_threshold = old_threshold.reshape(self.threshold_shape)
+ return {THRESHOLD: old_threshold}
+
+ def __call__(self,
+ inputs: nn.Parameter,
+ training: bool) -> torch.Tensor:
+ """
+ Quantize a tensor.
+
+ Args:
+ inputs: Input tensor to quantize.
+ training: whether in training mode or not
+
+ Returns:
+ quantized tensor
+ """
+ auxvar = self.get_quantizer_variable(AUXVAR)
+ ptq_threshold_tensor = self.get_quantizer_variable(PTQ_THRESHOLD)
+
+ #####################################################
+ # Soft Rounding
+ #####################################################
+ aux_var = self.get_soft_targets()
+ if not training:
+ aux_var = (aux_var >= 0.5).to(auxvar.dtype)
+
+ if self.per_channel:
+ reshape_shape = get_threshold_reshape_shape(inputs.shape,
+ quant_axis=self.quantization_axis,
+ quant_axis_dim=-1)
+
+ ##########################################################
+ # Calculate soft rounding targets and optimized threshold
+ ##########################################################
+ ptq_threshold_tensor_hat = torch.reshape(ptq_threshold_tensor, reshape_shape)
+
+ #####################################################
+ # Quantized Input
+ #####################################################
+ q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
+ auxvar_tensor=aux_var,
+ threshold_tensor=ptq_threshold_tensor_hat,
+ num_bits=self.num_bits,
+ signed=True,
+ power_of_two=self.power_of_two)
+
+ if self.quantization_parameter_learning and not self.power_of_two:
+ scale = torch.reshape(self.get_quantizer_variable(SCALE_PTQ), reshape_shape)
+ q_tensor *= scale
+
+ else:
+ q_tensor = soft_rounding_symmetric_quantizer(input_tensor=inputs,
+ auxvar_tensor=aux_var,
+ threshold_tensor=ptq_threshold_tensor,
+ num_bits=self.num_bits,
+ signed=True,
+ power_of_two=self.power_of_two)
+
+ return q_tensor
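The quantizer above boils down to: freeze the floored integer grid of the weights, learn a per-weight rounding offset through a rectified sigmoid, and clip to the signed integer range. A hedged, self-contained sketch of that forward path (the 8-bit width, the threshold, and the gamma/zeta stretch values of -0.1/1.1 from the AdaRound paper are assumptions for illustration; the straight-through clip is replaced by a plain clamp):

import torch

gamma, zeta = -0.1, 1.1                 # stretch parameters (assumed AdaRound defaults)
num_bits, signed = 8, True
threshold = torch.tensor(1.0)           # made-up symmetric threshold
delta = threshold / (2 ** (num_bits - int(signed)))     # step size, as in calculate_delta

w = torch.randn(16) * 0.3                               # hypothetical weight tensor
alpha = torch.zeros(16, requires_grad=True)             # trainable rounding variable (AUXVAR)

# Rectified sigmoid: soft rounding target h(alpha), clipped into [0, 1]
h = torch.clip(torch.sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1)

# Floor the weights onto the grid (kept fixed), add the soft rounding decision,
# then clip to the signed integer range and rescale by the step size.
w_floor = torch.floor(w / delta).detach()
int_threshold = 2 ** (num_bits - int(signed))
q = delta * torch.clamp(w_floor + h, -int_threshold, int_threshold - 1)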