mct-nightly 2.2.0.20241111.513__py3-none-any.whl → 2.2.0.20241113.521__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/RECORD +14 -14
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +1 -1
- model_compression_toolkit/gptq/common/gptq_config.py +26 -23
- model_compression_toolkit/gptq/common/gptq_constants.py +2 -2
- model_compression_toolkit/gptq/common/gptq_training.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +9 -5
- model_compression_toolkit/gptq/pytorch/gptq_training.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +26 -21
- {mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/top_level.txt +0 -0
{mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=FqZ6XbAbgDSAhK2i7UqlDeDmXsSUCvu0RyBhUeMPDp0,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -47,7 +47,7 @@ model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256
|
|
47
47
|
model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
|
48
48
|
model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
|
49
49
|
model_compression_toolkit/core/common/hessian/__init__.py,sha256=E7LK3K_1AwMCQokanNc1JODMwUKNOKmwXQiGQ7GO10I,1033
|
50
|
-
model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=
|
50
|
+
model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=YynbVHdHH2gPlk1QHXH6GygIkXRZ9qxR14cpgKrHPT0,13238
|
51
51
|
model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
|
52
52
|
model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=Pe4uKerx-MeDQPJ7Slr8fvFUHfv02q33w3gbQK5kBKs,4186
|
53
53
|
model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=U2n5fz6fK633HWzIvEuQ7N6dekMqH9-DecOXAgd3v4E,3140
|
@@ -350,19 +350,19 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
|
|
350
350
|
model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
|
351
351
|
model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
|
352
352
|
model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
353
|
-
model_compression_toolkit/gptq/common/gptq_config.py,sha256=
|
354
|
-
model_compression_toolkit/gptq/common/gptq_constants.py,sha256=
|
353
|
+
model_compression_toolkit/gptq/common/gptq_config.py,sha256=QwSEZZlC6OpnpoBQoAFfgXTrdBgewgqlgaCV2hoJEso,6143
|
354
|
+
model_compression_toolkit/gptq/common/gptq_constants.py,sha256=8HB0yiX75zZ1IKgQUPWpFCM5sS8HAqslws5XrOhxJQ0,750
|
355
355
|
model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
|
356
356
|
model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
|
357
|
-
model_compression_toolkit/gptq/common/gptq_training.py,sha256=
|
357
|
+
model_compression_toolkit/gptq/common/gptq_training.py,sha256=EnG-17U6kGDgTeMkOJQmRoMs0KUldROss683_Bo5oHQ,13249
|
358
358
|
model_compression_toolkit/gptq/common/gradual_activation_quantization.py,sha256=EgpzMs_aDoB0wQiTagqvcxCTfrgNUuCfdXEXmfNiyb0,3780
|
359
359
|
model_compression_toolkit/gptq/common/regularization_factory.py,sha256=hyunpXepVeHyoAFJw6zNLK-3ZHBmiut3lmNisJN_L3E,2514
|
360
360
|
model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
361
361
|
model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
|
362
362
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
363
|
-
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=
|
363
|
+
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=yBiAod9hbzh2bp4xhVO5szmtCHm6bLUa7-kjUVVwo40,20845
|
364
364
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
|
365
|
-
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
|
365
|
+
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=e3O835Ol5ML0XuqNsCmoTbnnfs-gEgrSGT1ijUZLX7Q,17102
|
366
366
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
367
367
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
|
368
368
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
@@ -376,9 +376,9 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
|
|
376
376
|
model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
377
377
|
model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phIqeu5-_7_5VBT4DT-FCw7Do,3892
|
378
378
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
379
|
-
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=
|
379
|
+
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=iuZJcoG2w-7qjWGntXWTdU2XUuMPy5IwzZbiolThuI4,22145
|
380
380
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
|
381
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
|
381
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=hZFU_ZY-LYcpRZyzzX7NsJievkIYKGdkgBzEoB4rsRQ,16020
|
382
382
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
383
383
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
|
384
384
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
@@ -558,8 +558,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
558
558
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
|
559
559
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
560
560
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
561
|
-
mct_nightly-2.2.0.20241111.513.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
562
|
-
mct_nightly-2.2.0.
|
563
|
-
mct_nightly-2.2.0.20241111.513.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
564
|
-
mct_nightly-2.2.0.20241111.513.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
565
|
-
mct_nightly-2.2.0.20241111.513.dist-info/RECORD,,
|
561
|
+
mct_nightly-2.2.0.20241113.521.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
562
|
+
mct_nightly-2.2.0.20241113.521.dist-info/METADATA,sha256=vVYCluqgh4ApdcolpQzjr1vaNDLvkAqWpZAYn6kLz3I,20830
|
563
|
+
mct_nightly-2.2.0.20241113.521.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
564
|
+
mct_nightly-2.2.0.20241113.521.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
565
|
+
mct_nightly-2.2.0.20241113.521.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.2.0.20241111.000513"
|
30
|
+
__version__ = "2.2.0.20241113.000521"
|
@@ -204,7 +204,7 @@ class HessianInfoService:
|
|
204
204
|
target_nodes = [n for n in orig_request.target_nodes if n.name in missing]
|
205
205
|
request = request.clone(target_nodes=target_nodes)
|
206
206
|
self._compute_hessians(request, n_iterations, count_by_cache=True)
|
207
|
-
res, missing = self.cache.fetch_hessian(request)
|
207
|
+
res, missing = self.cache.fetch_hessian(orig_request)
|
208
208
|
assert not missing
|
209
209
|
return res
|
210
210
|
|
@@ -16,8 +16,7 @@ from dataclasses import dataclass, field
|
|
16
16
|
from enum import Enum
|
17
17
|
from typing import Callable, Any, Dict, Optional
|
18
18
|
|
19
|
-
from model_compression_toolkit.constants import
|
20
|
-
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
|
19
|
+
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
21
20
|
|
22
21
|
|
23
22
|
class RoundingType(Enum):
|
@@ -39,20 +38,26 @@ class GPTQHessianScoresConfig:
|
|
39
38
|
Configuration to use for computing the Hessian-based scores for GPTQ loss metric.
|
40
39
|
|
41
40
|
Args:
|
41
|
+
per_sample (bool): Whether to use per sample attention score.
|
42
42
|
hessians_num_samples (int|None): Number of samples to use for computing the Hessian-based scores.
|
43
43
|
If None, compute Hessian for all images.
|
44
44
|
norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
|
45
45
|
log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
|
46
46
|
scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
|
47
47
|
hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective.
|
48
|
-
per_sample (bool): Whether to use per sample attention score.
|
49
48
|
"""
|
50
|
-
|
51
|
-
|
52
|
-
|
49
|
+
per_sample: bool
|
50
|
+
hessians_num_samples: Optional[int]
|
51
|
+
norm_scores: bool = None
|
52
|
+
log_norm: bool = None
|
53
53
|
scale_log_norm: bool = False
|
54
54
|
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
55
|
-
|
55
|
+
|
56
|
+
def __post_init__(self):
|
57
|
+
if self.norm_scores is None:
|
58
|
+
self.norm_scores = not self.per_sample
|
59
|
+
if self.log_norm is None:
|
60
|
+
self.log_norm = not self.per_sample
|
56
61
|
|
57
62
|
|
58
63
|
@dataclass
|
@@ -107,32 +112,30 @@ class GradientPTQConfig:
|
|
107
112
|
|
108
113
|
Args:
|
109
114
|
n_epochs: Number of representative dataset epochs to train.
|
110
|
-
optimizer: Optimizer to use.
|
111
|
-
optimizer_rest: Optimizer to use for bias and quantizer parameters.
|
112
115
|
loss: The loss to use. See 'multiple_tensors_mse_loss' for the expected interface.
|
113
|
-
|
116
|
+
optimizer: Optimizer to use.
|
117
|
+
optimizer_rest: Default optimizer to use for bias and quantizer parameters.
|
114
118
|
train_bias: Whether to update the bias during the training or not.
|
115
|
-
rounding_type: An enum that defines the rounding type.
|
116
|
-
use_hessian_based_weights: Whether to use Hessian-based weights for weighted average loss.
|
117
|
-
optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
|
118
|
-
optimizer_bias: Optimizer to override the rest optimizer for bias.
|
119
|
-
regularization_factor: A floating point number that defines the regularization factor.
|
120
119
|
hessian_weights_config: A configuration that include all necessary arguments to run a computation of
|
121
120
|
Hessian scores for the GPTQ loss.
|
122
121
|
gradual_activation_quantization_config: A configuration for Gradual Activation Quantization.
|
122
|
+
regularization_factor: A floating point number that defines the regularization factor.
|
123
|
+
rounding_type: An enum that defines the rounding type.
|
124
|
+
optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
|
125
|
+
optimizer_bias: Optimizer to override the rest optimizer for bias.
|
126
|
+
log_function: Function to log information about the GPTQ process.
|
123
127
|
gptq_quantizer_params_override: A dictionary of parameters to override in GPTQ quantizer instantiation.
|
124
128
|
"""
|
125
129
|
n_epochs: int
|
130
|
+
loss: Callable
|
126
131
|
optimizer: Any
|
127
|
-
optimizer_rest: Any
|
128
|
-
|
129
|
-
|
130
|
-
|
132
|
+
optimizer_rest: Any
|
133
|
+
train_bias: bool
|
134
|
+
hessian_weights_config: Optional[GPTQHessianScoresConfig]
|
135
|
+
gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig]
|
136
|
+
regularization_factor: float
|
131
137
|
rounding_type: RoundingType = RoundingType.SoftQuantizer
|
132
|
-
use_hessian_based_weights: bool = True
|
133
138
|
optimizer_quantization_parameter: Any = None
|
134
139
|
optimizer_bias: Any = None
|
135
|
-
|
136
|
-
hessian_weights_config: GPTQHessianScoresConfig = field(default_factory=GPTQHessianScoresConfig)
|
137
|
-
gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] = None
|
140
|
+
log_function: Callable = None
|
138
141
|
gptq_quantizer_params_override: Dict[str, Any] = field(default_factory=dict)
|
@@ -14,6 +14,7 @@ N_CYCLES = 4
|
|
14
14
|
MIM_TEMP = 0.5
|
15
15
|
MAX_TEMP = 1.0
|
16
16
|
REG_DEFAULT = 0.01
|
17
|
+
REG_DEFAULT_SLA = 10
|
17
18
|
MAX_LSB_CHANGE = 1
|
18
19
|
|
19
20
|
# Soft rounding arguments values
|
@@ -27,6 +28,5 @@ MAX_LSB_STR = 'max_lsbs_change_map'
|
|
27
28
|
# GPTQ learning hyperparameters
|
28
29
|
LR_DEFAULT = 3e-2
|
29
30
|
LR_REST_DEFAULT = 1e-4
|
30
|
-
LR_BIAS_DEFAULT = 1e-
|
31
|
-
LR_QUANTIZATION_PARAM_DEFAULT = 1e-3
|
31
|
+
LR_BIAS_DEFAULT = 1e-4
|
32
32
|
GPTQ_MOMENTUM = 0.9
|
@@ -75,7 +75,7 @@ class GPTQTrainer(ABC):
|
|
75
75
|
fw_info=self.fw_info)
|
76
76
|
|
77
77
|
self.fxp_model, self.gptq_user_info = self.build_gptq_model()
|
78
|
-
if self.gptq_config.use_hessian_based_weights:
|
78
|
+
if self.gptq_config.hessian_weights_config:
|
79
79
|
if not isinstance(hessian_info_service, HessianInfoService):
|
80
80
|
Logger.critical(f"When using Hessian-based approximations for sensitivity evaluation, "
|
81
81
|
f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.") # pragma: no cover
|
@@ -139,7 +139,7 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
139
139
|
|
140
140
|
def _get_compare_points_loss_weights(self):
|
141
141
|
""" Get compare points weights for the distillation loss. """
|
142
|
-
if self.gptq_config.use_hessian_based_weights:
|
142
|
+
if self.gptq_config.hessian_weights_config:
|
143
143
|
hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn,
|
144
144
|
batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
|
145
145
|
return self.compute_hessian_based_weights(hess_dataloader)
|
@@ -21,7 +21,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
|
|
21
21
|
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
|
22
22
|
LR_BIAS_DEFAULT, GPTQ_MOMENTUM
|
23
23
|
from model_compression_toolkit.logger import Logger
|
24
|
-
from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
24
|
+
from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
|
25
25
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
26
26
|
from model_compression_toolkit.core.common.user_info import UserInformation
|
27
27
|
from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig, \
|
@@ -117,16 +117,20 @@ if FOUND_TF:
|
|
117
117
|
raise TypeError(f'gradual_activation_quantization argument should be bool or '
|
118
118
|
f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')
|
119
119
|
|
120
|
-
|
121
|
-
|
120
|
+
hessian_weights_config = None
|
121
|
+
if use_hessian_based_weights:
|
122
|
+
hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
|
123
|
+
hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
|
124
|
+
hessian_batch_size=hessian_batch_size)
|
125
|
+
return GradientPTQConfig(n_epochs=n_epochs,
|
126
|
+
optimizer=optimizer,
|
122
127
|
optimizer_rest=optimizer_rest,
|
123
128
|
loss=loss,
|
124
129
|
log_function=log_function,
|
125
130
|
train_bias=True,
|
126
131
|
optimizer_bias=bias_optimizer,
|
127
|
-
use_hessian_based_weights=use_hessian_based_weights,
|
128
132
|
regularization_factor=regularization_factor,
|
129
|
-
hessian_weights_config=
|
133
|
+
hessian_weights_config=hessian_weights_config,
|
130
134
|
gradual_activation_quantization_config=gradual_quant_config)
|
131
135
|
|
132
136
|
|
@@ -116,7 +116,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
116
116
|
trainable_threshold)
|
117
117
|
hessian_cfg = self.gptq_config.hessian_weights_config
|
118
118
|
|
119
|
-
self.use_sample_layer_attention = hessian_cfg.per_sample
|
119
|
+
self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
|
120
120
|
if self.use_sample_layer_attention:
|
121
121
|
# normalization is currently not supported, make sure the config reflects it.
|
122
122
|
if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
|
@@ -178,7 +178,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
178
178
|
dataset = IterableDatasetFromGenerator(data_gen_fn)
|
179
179
|
num_nodes = len(self.compare_points)
|
180
180
|
|
181
|
-
if self.gptq_config.use_hessian_based_weights:
|
181
|
+
if self.gptq_config.hessian_weights_config:
|
182
182
|
hess_dataloader = DataLoader(dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
|
183
183
|
loss_weights = torch.from_numpy(self.compute_hessian_based_weights(hess_dataloader))
|
184
184
|
else:
|
@@ -15,7 +15,7 @@
|
|
15
15
|
import copy
|
16
16
|
from typing import Callable, Union
|
17
17
|
|
18
|
-
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH
|
18
|
+
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH, GPTQ_HESSIAN_NUM_SAMPLES
|
19
19
|
from model_compression_toolkit.core import CoreConfig
|
20
20
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
21
21
|
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
|
@@ -27,7 +27,7 @@ from model_compression_toolkit.core.runner import core_runner
|
|
27
27
|
from model_compression_toolkit.gptq.common.gptq_config import (
|
28
28
|
GradientPTQConfig, GPTQHessianScoresConfig, GradualActivationQuantizationConfig)
|
29
29
|
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
|
30
|
-
LR_BIAS_DEFAULT, GPTQ_MOMENTUM
|
30
|
+
LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
|
31
31
|
from model_compression_toolkit.gptq.runner import gptq_runner
|
32
32
|
from model_compression_toolkit.logger import Logger
|
33
33
|
from model_compression_toolkit.metadata import create_model_metadata
|
@@ -55,10 +55,10 @@ if FOUND_TORCH:
|
|
55
55
|
loss: Callable = None,
|
56
56
|
log_function: Callable = None,
|
57
57
|
use_hessian_based_weights: bool = True,
|
58
|
-
regularization_factor: float = REG_DEFAULT,
|
58
|
+
regularization_factor: float = None,
|
59
59
|
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
|
60
|
-
use_hessian_sample_attention: bool = False,
|
61
|
-
gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False,
|
60
|
+
use_hessian_sample_attention: bool = True,
|
61
|
+
gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = True,
|
62
62
|
) -> GradientPTQConfig:
|
63
63
|
"""
|
64
64
|
Create a GradientPTQConfig instance for Pytorch models.
|
@@ -94,25 +94,26 @@ if FOUND_TORCH:
|
|
94
94
|
"""
|
95
95
|
optimizer = optimizer or Adam([torch.Tensor([])], lr=LR_DEFAULT)
|
96
96
|
optimizer_rest = optimizer_rest or Adam([torch.Tensor([])], lr=LR_REST_DEFAULT)
|
97
|
-
|
97
|
+
# TODO this contradicts the docstring for optimizer_rest
|
98
98
|
bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
|
99
99
|
|
100
|
+
if regularization_factor is None:
|
101
|
+
regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
|
102
|
+
|
103
|
+
loss = loss or multiple_tensors_mse_loss
|
104
|
+
hessian_weights_config = None
|
100
105
|
if use_hessian_sample_attention:
|
101
106
|
if not use_hessian_based_weights: # pragma: no cover
|
102
107
|
raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')
|
103
108
|
|
104
|
-
hessian_weights_config = GPTQHessianScoresConfig(
|
105
|
-
|
106
|
-
|
107
|
-
log_norm=False,
|
108
|
-
scale_log_norm=False,
|
109
|
-
hessian_batch_size=hessian_batch_size,
|
110
|
-
per_sample=True,
|
111
|
-
)
|
109
|
+
hessian_weights_config = GPTQHessianScoresConfig(per_sample=True,
|
110
|
+
hessians_num_samples=None,
|
111
|
+
hessian_batch_size=hessian_batch_size)
|
112
112
|
loss = loss or sample_layer_attention_loss
|
113
|
-
|
114
|
-
hessian_weights_config = GPTQHessianScoresConfig(
|
115
|
-
|
113
|
+
elif use_hessian_based_weights:
|
114
|
+
hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
|
115
|
+
hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
|
116
|
+
hessian_batch_size=hessian_batch_size)
|
116
117
|
|
117
118
|
if isinstance(gradual_activation_quantization, bool):
|
118
119
|
gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
|
@@ -122,12 +123,16 @@ if FOUND_TORCH:
|
|
122
123
|
raise TypeError(f'gradual_activation_quantization argument should be bool or '
|
123
124
|
f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')
|
124
125
|
|
125
|
-
return GradientPTQConfig(n_epochs
|
126
|
-
|
127
|
-
|
126
|
+
return GradientPTQConfig(n_epochs=n_epochs,
|
127
|
+
loss=loss,
|
128
|
+
optimizer=optimizer,
|
129
|
+
optimizer_rest=optimizer_rest,
|
130
|
+
optimizer_bias=bias_optimizer,
|
131
|
+
train_bias=True,
|
128
132
|
regularization_factor=regularization_factor,
|
129
133
|
hessian_weights_config=hessian_weights_config,
|
130
|
-
gradual_activation_quantization_config=gradual_quant_config
|
134
|
+
gradual_activation_quantization_config=gradual_quant_config,
|
135
|
+
log_function=log_function)
|
131
136
|
|
132
137
|
def pytorch_gradient_post_training_quantization(model: Module,
|
133
138
|
representative_data_gen: Callable,
|
{mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/LICENSE.md
RENAMED
File without changes
|
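{mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/WHEEL
RENAMED
File without changes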
|
{mct_nightly-2.2.0.20241111.513.dist-info → mct_nightly-2.2.0.20241113.521.dist-info}/top_level.txt
RENAMED
File without changes
|