mct-nightly 2.2.0.20241111.513__py3-none-any.whl → 2.2.0.20241113.521__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 2.2.0.20241111.513
+ Version: 2.2.0.20241113.521
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -1,4 +1,4 @@
- model_compression_toolkit/__init__.py,sha256=Fw-5L3IgVjoHACxhUUWu7J7Obhrbfc3uR1xU069wk4g,1573
+ model_compression_toolkit/__init__.py,sha256=FqZ6XbAbgDSAhK2i7UqlDeDmXsSUCvu0RyBhUeMPDp0,1573
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -47,7 +47,7 @@ model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256
  model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
  model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
  model_compression_toolkit/core/common/hessian/__init__.py,sha256=E7LK3K_1AwMCQokanNc1JODMwUKNOKmwXQiGQ7GO10I,1033
- model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=OH8Xadv0ZSD_yoymgSfaNg8tqr4vxUfAbNLCBMRz6pQ,13233
+ model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=YynbVHdHH2gPlk1QHXH6GygIkXRZ9qxR14cpgKrHPT0,13238
  model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
  model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
  model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=U2n5fz6fK633HWzIvEuQ7N6dekMqH9-DecOXAgd3v4E,3140
@@ -350,19 +350,19 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
  model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
  model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
  model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/gptq/common/gptq_config.py,sha256=Z6T5B3q4k2Tlr2bBWvC6TAF3d2opyA7ZT_D_mz6D1_0,6297
- model_compression_toolkit/gptq/common/gptq_constants.py,sha256=D1x2n4-NdAx6g_1Wc2hwwh4vX9vmx5VnQWN26H107kg,766
+ model_compression_toolkit/gptq/common/gptq_config.py,sha256=QwSEZZlC6OpnpoBQoAFfgXTrdBgewgqlgaCV2hoJEso,6143
+ model_compression_toolkit/gptq/common/gptq_constants.py,sha256=8HB0yiX75zZ1IKgQUPWpFCM5sS8HAqslws5XrOhxJQ0,750
  model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
  model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
- model_compression_toolkit/gptq/common/gptq_training.py,sha256=tt4O8PjSChquzl4c6NojvQWZmvCdTxcMLtmEVIGx1ns,13252
+ model_compression_toolkit/gptq/common/gptq_training.py,sha256=EnG-17U6kGDgTeMkOJQmRoMs0KUldROss683_Bo5oHQ,13249
  model_compression_toolkit/gptq/common/gradual_activation_quantization.py,sha256=EgpzMs_aDoB0wQiTagqvcxCTfrgNUuCfdXEXmfNiyb0,3780
  model_compression_toolkit/gptq/common/regularization_factory.py,sha256=hyunpXepVeHyoAFJw6zNLK-3ZHBmiut3lmNisJN_L3E,2514
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=TEWqAU8JZnZVZ-dIkINA0x1NmSrYpEkXTdG835JdKnI,20848
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=yBiAod9hbzh2bp4xhVO5szmtCHm6bLUa7-kjUVVwo40,20845
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=DhEEpW0rK4JRdk5WQlN-_DOUuzlwOBqpiwTBOySjn2g,16820
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=e3O835Ol5ML0XuqNsCmoTbnnfs-gEgrSGT1ijUZLX7Q,17102
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -376,9 +376,9 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phIqeu5-_7_5VBT4DT-FCw7Do,3892
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=2KwJFlJj6hFJClsJbC9aaWDAGbZUNDbSx1d-QX4LShc,22132
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=iuZJcoG2w-7qjWGntXWTdU2XUuMPy5IwzZbiolThuI4,22145
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=lY7_lNtS1SqaaJ0gc6C7_HO71bBalsxQY37QQlWpu70,15479
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=hZFU_ZY-LYcpRZyzzX7NsJievkIYKGdkgBzEoB4rsRQ,16020
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -558,8 +558,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
- mct_nightly-2.2.0.20241111.513.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.2.0.20241111.513.dist-info/METADATA,sha256=McWeuqci7NFAVzdRUjfGJWF-nPncUEACLapRtf-Dx2Y,20830
- mct_nightly-2.2.0.20241111.513.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
- mct_nightly-2.2.0.20241111.513.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.2.0.20241111.513.dist-info/RECORD,,
+ mct_nightly-2.2.0.20241113.521.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.2.0.20241113.521.dist-info/METADATA,sha256=vVYCluqgh4ApdcolpQzjr1vaNDLvkAqWpZAYn6kLz3I,20830
+ mct_nightly-2.2.0.20241113.521.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ mct_nightly-2.2.0.20241113.521.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.2.0.20241113.521.dist-info/RECORD,,
model_compression_toolkit/__init__.py
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

- __version__ = "2.2.0.20241111.000513"
+ __version__ = "2.2.0.20241113.000521"
model_compression_toolkit/core/common/hessian/hessian_info_service.py
@@ -204,7 +204,7 @@ class HessianInfoService:
  target_nodes = [n for n in orig_request.target_nodes if n.name in missing]
  request = request.clone(target_nodes=target_nodes)
  self._compute_hessians(request, n_iterations, count_by_cache=True)
- res, missing = self.cache.fetch_hessian(request)
+ res, missing = self.cache.fetch_hessian(orig_request)
  assert not missing
  return res

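The one-line fix above makes the post-compute fetch use the original request (all target nodes) rather than the reduced request that covers only the nodes whose Hessians were just computed, so the returned result merges previously cached and freshly computed scores. A minimal standalone sketch of that fetch/compute/re-fetch flow, with hypothetical names and a plain dict standing in for the cache (not the MCT API):

    # Illustrative sketch only; names and the dict-based cache are hypothetical.
    from typing import Dict, List, Tuple

    def fetch(cache: Dict[str, float], nodes: List[str]) -> Tuple[Dict[str, float], List[str]]:
        """Return cached scores for `nodes` plus the list of nodes not yet cached."""
        found = {n: cache[n] for n in nodes if n in cache}
        missing = [n for n in nodes if n not in cache]
        return found, missing

    def fetch_hessian(cache: Dict[str, float], requested: List[str]) -> Dict[str, float]:
        res, missing = fetch(cache, requested)
        if missing:
            for n in missing:        # compute only the scores that are absent
                cache[n] = 0.0       # placeholder for a real Hessian-score computation
            # Re-fetch with the original request so cached and new scores are merged;
            # fetching with only the reduced request (the old behavior) would drop
            # entries that were already cached before this call.
            res, missing = fetch(cache, requested)
        assert not missing
        return res

    print(fetch_hessian({"conv1": 0.7}, ["conv1", "conv2"]))  # {'conv1': 0.7, 'conv2': 0.0}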
model_compression_toolkit/gptq/common/gptq_config.py
@@ -16,8 +16,7 @@ from dataclasses import dataclass, field
  from enum import Enum
  from typing import Callable, Any, Dict, Optional

- from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
- from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
+ from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE


  class RoundingType(Enum):
@@ -39,20 +38,26 @@ class GPTQHessianScoresConfig:
  Configuration to use for computing the Hessian-based scores for GPTQ loss metric.

  Args:
+ per_sample (bool): Whether to use per sample attention score.
  hessians_num_samples (int|None): Number of samples to use for computing the Hessian-based scores.
  If None, compute Hessian for all images.
  norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
  log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
  scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
  hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective.
- per_sample (bool): Whether to use per sample attention score.
  """
- hessians_num_samples: Optional[int] = GPTQ_HESSIAN_NUM_SAMPLES
- norm_scores: bool = True
- log_norm: bool = True
+ per_sample: bool
+ hessians_num_samples: Optional[int]
+ norm_scores: bool = None
+ log_norm: bool = None
  scale_log_norm: bool = False
  hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
- per_sample: bool = False
+
+ def __post_init__(self):
+ if self.norm_scores is None:
+ self.norm_scores = not self.per_sample
+ if self.log_norm is None:
+ self.log_norm = not self.per_sample


  @dataclass
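GPTQHessianScoresConfig now requires per_sample and hessians_num_samples explicitly and, through __post_init__, turns normalization off by default whenever per-sample (sample-layer-attention) scores are requested. A minimal standalone sketch of that sentinel-default pattern (hypothetical class name, not the MCT dataclass):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ScoresConfig:
        per_sample: bool
        norm_scores: Optional[bool] = None   # None means "derive from per_sample"
        log_norm: Optional[bool] = None

        def __post_init__(self):
            # Normalization is enabled by default only in the non-per-sample mode.
            if self.norm_scores is None:
                self.norm_scores = not self.per_sample
            if self.log_norm is None:
                self.log_norm = not self.per_sample

    print(ScoresConfig(per_sample=True))   # norm_scores=False, log_norm=False
    print(ScoresConfig(per_sample=False))  # norm_scores=True, log_norm=True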
@@ -107,32 +112,30 @@ class GradientPTQConfig:

  Args:
  n_epochs: Number of representative dataset epochs to train.
- optimizer: Optimizer to use.
- optimizer_rest: Optimizer to use for bias and quantizer parameters.
  loss: The loss to use. See 'multiple_tensors_mse_loss' for the expected interface.
- log_function: Function to log information about the GPTQ process.
+ optimizer: Optimizer to use.
+ optimizer_rest: Default optimizer to use for bias and quantizer parameters.
  train_bias: Whether to update the bias during the training or not.
- rounding_type: An enum that defines the rounding type.
- use_hessian_based_weights: Whether to use Hessian-based weights for weighted average loss.
- optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
- optimizer_bias: Optimizer to override the rest optimizer for bias.
- regularization_factor: A floating point number that defines the regularization factor.
  hessian_weights_config: A configuration that include all necessary arguments to run a computation of
  Hessian scores for the GPTQ loss.
  gradual_activation_quantization_config: A configuration for Gradual Activation Quantization.
+ regularization_factor: A floating point number that defines the regularization factor.
+ rounding_type: An enum that defines the rounding type.
+ optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
+ optimizer_bias: Optimizer to override the rest optimizer for bias.
+ log_function: Function to log information about the GPTQ process.
  gptq_quantizer_params_override: A dictionary of parameters to override in GPTQ quantizer instantiation.
  """
  n_epochs: int
+ loss: Callable
  optimizer: Any
- optimizer_rest: Any = None
- loss: Callable = None
- log_function: Callable = None
- train_bias: bool = True
+ optimizer_rest: Any
+ train_bias: bool
+ hessian_weights_config: Optional[GPTQHessianScoresConfig]
+ gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig]
+ regularization_factor: float
  rounding_type: RoundingType = RoundingType.SoftQuantizer
- use_hessian_based_weights: bool = True
  optimizer_quantization_parameter: Any = None
  optimizer_bias: Any = None
- regularization_factor: float = REG_DEFAULT
- hessian_weights_config: GPTQHessianScoresConfig = field(default_factory=GPTQHessianScoresConfig)
- gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] = None
+ log_function: Callable = None
  gptq_quantizer_params_override: Dict[str, Any] = field(default_factory=dict)
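GradientPTQConfig likewise makes loss, optimizer_rest, train_bias, hessian_weights_config, gradual_activation_quantization_config and regularization_factor required, and drops the use_hessian_based_weights flag: passing hessian_weights_config=None is now how Hessian-based loss weighting is disabled, and the trainer hunks below switch to checking the config's presence. A minimal standalone sketch of this optional-config-as-switch pattern (illustrative names, not the MCT classes):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class HessianWeightsConfig:
        hessian_batch_size: int = 32

    @dataclass
    class TrainerConfig:
        # Before: a separate `use_hessian_based_weights: bool` flag next to an always-present config.
        # After: the Optional config object itself is the switch.
        hessian_weights_config: Optional[HessianWeightsConfig] = None

    def loss_weights_mode(cfg: TrainerConfig) -> str:
        if cfg.hessian_weights_config:  # replaces `if cfg.use_hessian_based_weights:`
            return f"hessian-weighted (batch={cfg.hessian_weights_config.hessian_batch_size})"
        return "uniform"

    print(loss_weights_mode(TrainerConfig()))                          # uniform
    print(loss_weights_mode(TrainerConfig(HessianWeightsConfig(16))))  # hessian-weighted (batch=16)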
model_compression_toolkit/gptq/common/gptq_constants.py
@@ -14,6 +14,7 @@ N_CYCLES = 4
  MIM_TEMP = 0.5
  MAX_TEMP = 1.0
  REG_DEFAULT = 0.01
+ REG_DEFAULT_SLA = 10
  MAX_LSB_CHANGE = 1

  # Soft rounding arguments values
@@ -27,6 +28,5 @@ MAX_LSB_STR = 'max_lsbs_change_map'
  # GPTQ learning hyperparameters
  LR_DEFAULT = 3e-2
  LR_REST_DEFAULT = 1e-4
- LR_BIAS_DEFAULT = 1e-3
- LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
+ LR_BIAS_DEFAULT = 1e-4
  GPTQ_MOMENTUM = 0.9
model_compression_toolkit/gptq/common/gptq_training.py
@@ -75,7 +75,7 @@ class GPTQTrainer(ABC):
  fw_info=self.fw_info)

  self.fxp_model, self.gptq_user_info = self.build_gptq_model()
- if self.gptq_config.use_hessian_based_weights:
+ if self.gptq_config.hessian_weights_config:
  if not isinstance(hessian_info_service, HessianInfoService):
  Logger.critical(f"When using Hessian-based approximations for sensitivity evaluation, "
  f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.") # pragma: no cover
model_compression_toolkit/gptq/keras/gptq_training.py
@@ -139,7 +139,7 @@ class KerasGPTQTrainer(GPTQTrainer):

  def _get_compare_points_loss_weights(self):
  """ Get compare points weights for the distillation loss. """
- if self.gptq_config.use_hessian_based_weights:
+ if self.gptq_config.hessian_weights_config:
  hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn,
  batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
  return self.compute_hessian_based_weights(hess_dataloader)
model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -21,7 +21,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
  LR_BIAS_DEFAULT, GPTQ_MOMENTUM
  from model_compression_toolkit.logger import Logger
- from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE
+ from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
  from model_compression_toolkit.verify_packages import FOUND_TF
  from model_compression_toolkit.core.common.user_info import UserInformation
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
@@ -117,16 +117,20 @@ if FOUND_TF:
  raise TypeError(f'gradual_activation_quantization argument should be bool or '
  f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')

- return GradientPTQConfig(n_epochs,
- optimizer,
+ hessian_weights_config = None
+ if use_hessian_based_weights:
+ hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
+ hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
+ hessian_batch_size=hessian_batch_size)
+ return GradientPTQConfig(n_epochs=n_epochs,
+ optimizer=optimizer,
  optimizer_rest=optimizer_rest,
  loss=loss,
  log_function=log_function,
  train_bias=True,
  optimizer_bias=bias_optimizer,
- use_hessian_based_weights=use_hessian_based_weights,
  regularization_factor=regularization_factor,
- hessian_weights_config=GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size),
+ hessian_weights_config=hessian_weights_config,
  gradual_activation_quantization_config=gradual_quant_config)


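The Keras facade now builds the Hessian-scores config only when Hessian-based weighting is requested, so a config created with use_hessian_based_weights=False carries hessian_weights_config=None instead of a default GPTQHessianScoresConfig. A hedged usage sketch of the resulting behavior, assuming this nightly and TensorFlow are installed:

    import model_compression_toolkit as mct

    # Default: Hessian-based weighting with GPTQ_HESSIAN_NUM_SAMPLES samples and per_sample=False.
    gptq_cfg = mct.gptq.get_keras_gptq_config(n_epochs=5)
    assert gptq_cfg.hessian_weights_config is not None

    # Disabling Hessian weighting now yields no Hessian-scores config at all.
    plain_cfg = mct.gptq.get_keras_gptq_config(n_epochs=5, use_hessian_based_weights=False)
    assert plain_cfg.hessian_weights_config is None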
model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -116,7 +116,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
  trainable_threshold)
  hessian_cfg = self.gptq_config.hessian_weights_config

- self.use_sample_layer_attention = hessian_cfg.per_sample
+ self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
  if self.use_sample_layer_attention:
  # normalization is currently not supported, make sure the config reflects it.
  if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
@@ -178,7 +178,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
  dataset = IterableDatasetFromGenerator(data_gen_fn)
  num_nodes = len(self.compare_points)

- if self.gptq_config.use_hessian_based_weights:
+ if self.gptq_config.hessian_weights_config:
  hess_dataloader = DataLoader(dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
  loss_weights = torch.from_numpy(self.compute_hessian_based_weights(hess_dataloader))
  else:
model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -15,7 +15,7 @@
  import copy
  from typing import Callable, Union

- from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
+ from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH, GPTQ_HESSIAN_NUM_SAMPLES
  from model_compression_toolkit.core import CoreConfig
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
@@ -27,7 +27,7 @@ from model_compression_toolkit.core.runner import core_runner
  from model_compression_toolkit.gptq.common.gptq_config import (
  GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig)
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
- LR_BIAS_DEFAULT, GPTQ_MOMENTUM
+ LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
  from model_compression_toolkit.gptq.runner import gptq_runner
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.metadata import create_model_metadata
@@ -55,10 +55,10 @@ if FOUND_TORCH:
  loss: Callable = None,
  log_function: Callable = None,
  use_hessian_based_weights: bool = True,
- regularization_factor: float = REG_DEFAULT,
+ regularization_factor: float = None,
  hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
- use_hessian_sample_attention: bool = False,
- gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False,
+ use_hessian_sample_attention: bool = True,
+ gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = True,
  ) -> GradientPTQConfig:
  """
  Create a GradientPTQConfig instance for Pytorch models.
@@ -94,25 +94,26 @@ if FOUND_TORCH:
94
94
  """
95
95
  optimizer = optimizer or Adam([torch.Tensor([])], lr=LR_DEFAULT)
96
96
  optimizer_rest = optimizer_rest or Adam([torch.Tensor([])], lr=LR_REST_DEFAULT)
97
-
97
+ # TODO this contradicts the docstring for optimizer_rest
98
98
  bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
99
99
 
100
+ if regularization_factor is None:
101
+ regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
102
+
103
+ loss = loss or multiple_tensors_mse_loss
104
+ hessian_weights_config = None
100
105
  if use_hessian_sample_attention:
101
106
  if not use_hessian_based_weights: # pragma: no cover
102
107
  raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')
103
108
 
104
- hessian_weights_config = GPTQHessianScoresConfig(
105
- hessians_num_samples=None,
106
- norm_scores=False,
107
- log_norm=False,
108
- scale_log_norm=False,
109
- hessian_batch_size=hessian_batch_size,
110
- per_sample=True,
111
- )
109
+ hessian_weights_config = GPTQHessianScoresConfig(per_sample=True,
110
+ hessians_num_samples=None,
111
+ hessian_batch_size=hessian_batch_size)
112
112
  loss = loss or sample_layer_attention_loss
113
- else:
114
- hessian_weights_config = GPTQHessianScoresConfig(hessian_batch_size=hessian_batch_size)
115
- loss = loss or multiple_tensors_mse_loss
113
+ elif use_hessian_based_weights:
114
+ hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
115
+ hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
116
+ hessian_batch_size=hessian_batch_size)
116
117
 
117
118
  if isinstance(gradual_activation_quantization, bool):
118
119
  gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
@@ -122,12 +123,16 @@ if FOUND_TORCH:
122
123
  raise TypeError(f'gradual_activation_quantization argument should be bool or '
123
124
  f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')
124
125
 
125
- return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
126
- log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
127
- use_hessian_based_weights=use_hessian_based_weights,
126
+ return GradientPTQConfig(n_epochs=n_epochs,
127
+ loss=loss,
128
+ optimizer=optimizer,
129
+ optimizer_rest=optimizer_rest,
130
+ optimizer_bias=bias_optimizer,
131
+ train_bias=True,
128
132
  regularization_factor=regularization_factor,
129
133
  hessian_weights_config=hessian_weights_config,
130
- gradual_activation_quantization_config=gradual_quant_config)
134
+ gradual_activation_quantization_config=gradual_quant_config,
135
+ log_function=log_function)
131
136
 
132
137
  def pytorch_gradient_post_training_quantization(model: Module,
133
138
  representative_data_gen: Callable,
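Taken together, get_pytorch_gptq_config now defaults to Sample Layer Attention (per-sample Hessian scores with sample_layer_attention_loss) and gradual activation quantization, and resolves an unspecified regularization_factor to REG_DEFAULT_SLA (10) in that mode or REG_DEFAULT (0.01) otherwise. A hedged usage sketch of the new defaults, assuming this nightly and PyTorch are installed:

    import model_compression_toolkit as mct

    # New defaults: per-sample (Sample Layer Attention) Hessian scores and
    # gradual activation quantization enabled; regularization_factor resolves to 10.
    sla_cfg = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
    assert sla_cfg.hessian_weights_config.per_sample

    # Opting out restores the previous behavior: per-layer Hessian weighting with
    # GPTQ_HESSIAN_NUM_SAMPLES samples and regularization_factor resolved to 0.01.
    legacy_cfg = mct.gptq.get_pytorch_gptq_config(n_epochs=5,
                                                  use_hessian_sample_attention=False,
                                                  gradual_activation_quantization=False)
    assert not legacy_cfg.hessian_weights_config.per_sample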