mct-nightly 1.11.0.20240305.post352__py3-none-any.whl → 1.11.0.20240306.post426__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240305.post352.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/METADATA +4 -4
- {mct_nightly-1.11.0.20240305.post352.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/RECORD +19 -19
- model_compression_toolkit/gptq/__init__.py +1 -1
- model_compression_toolkit/gptq/common/gptq_config.py +5 -72
- model_compression_toolkit/gptq/keras/gptq_training.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +15 -29
- model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py +3 -3
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +2 -4
- model_compression_toolkit/gptq/pytorch/gptq_training.py +2 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +11 -28
- model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py +3 -3
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +2 -4
- model_compression_toolkit/gptq/runner.py +3 -3
- model_compression_toolkit/ptq/__init__.py +2 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +9 -23
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +9 -24
- {mct_nightly-1.11.0.20240305.post352.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240305.post352.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240305.post352.dist-info → mct_nightly-1.11.0.20240306.post426.dist-info}/top_level.txt +0 -0
mct_nightly-1.11.0.20240305.post352.dist-info/METADATA → mct_nightly-1.11.0.20240306.post426.dist-info/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mct-nightly
-Version: 1.11.0.20240305.post352
+Version: 1.11.0.20240306.post426
 Summary: A Model Compression Toolkit for neural networks
 Home-page: UNKNOWN
 License: UNKNOWN
@@ -173,10 +173,9 @@ This pruning technique is designed to compress models for specific hardware architectures,
 taking into account the target platform's Single Instruction, Multiple Data (SIMD) capabilities.
 By pruning groups of channels (SIMD groups), our approach not only reduces model size
 and complexity, but ensures that better utilization of channels is in line with the SIMD architecture
-for a target KPI of weights memory footprint.
+for a target KPI of weights memory footprint.
 [Keras API](https://sony.github.io/model_optimization/docs/api/experimental_api_docs/methods/keras_pruning_experimental.html)
-[Pytorch API](https://
-
+[Pytorch API](https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/pruning/pytorch/pruning_facade.py#L43)
 
 #### Results
 
@@ -209,3 +208,4 @@ MCT aims at keeping a more up-to-date fork and welcomes contributions from anyone
 
 [4] Gordon, O., Habi, H. V., & Netzer, A., 2023. [EPTQ: Enhanced Post-Training Quantization via Label-Free Hessian. arXiv preprint](https://arxiv.org/abs/2309.11531)
 
+
```
mct_nightly-1.11.0.20240305.post352.dist-info/RECORD → mct_nightly-1.11.0.20240306.post426.dist-info/RECORD

```diff
@@ -328,10 +328,10 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py
 model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=SJ5fetbUMkmB0tkHkmVhMrLksh7eqMQJLFuMD08ZKWM,3921
 model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=y7VfnsSbiKhEt5I_W7GHsuRNYLW-HgYmAiVy_rg65SI,8724
 model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizers.py,sha256=hinP-wtyxZyoW860GdJAk6M3iPjmwwPXQTUxd56yhq8,2086
-model_compression_toolkit/gptq/__init__.py,sha256=
-model_compression_toolkit/gptq/runner.py,sha256=
+model_compression_toolkit/gptq/__init__.py,sha256=YfSaHsSmijG0cUUDv6_lqmiQR9wsw7kjltoHx7ybNMM,1247
+model_compression_toolkit/gptq/runner.py,sha256=MIg-oBtR1nbHkexySdCJD_XfjRoHSknLotmGBMuD5qM,5924
 model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/gptq/common/gptq_config.py,sha256=
+model_compression_toolkit/gptq/common/gptq_config.py,sha256=eRtAmBKoq4w8D5LM075HJjotC831oGJxG-q4lCIc9Uc,5496
 model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
 model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
 model_compression_toolkit/gptq/common/gptq_graph.py,sha256=LfxpkMJb87h1NF1q4HoC88wA_0MW-B820alpiuZpZFo,2826
@@ -339,14 +339,14 @@ model_compression_toolkit/gptq/common/gptq_training.py
 model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
 model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
-model_compression_toolkit/gptq/keras/gptq_training.py,sha256=
+model_compression_toolkit/gptq/keras/gptq_training.py,sha256=pd3_5N8uXV6PipjEV6HxzKbnsS9jRUnUVcl05U-XyKk,17701
 model_compression_toolkit/gptq/keras/graph_info.py,sha256=FIGqzJbG6GkdHenvdMu-tGTjp4j9BewdF_spmWCb4Mo,4627
-model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
+model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=Tyzd-fOpjkkA1XMXCqIJeuteIk1q4HcCzUiZQXbFU2Y,14283
 model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
 model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=8NrJBftKFbMAF_jYaAbLP6GBwpCv3Ln1NKURaV75zko,4770
 model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
-model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=
-model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py,sha256=
+model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=SyBQcKIUUXt4yQnfCUlFB9uCrqkM_RnLRsge3MmdxPE,4423
+model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py,sha256=7kvQQz2zHTRkIzJpsOPe8PWtfsOpcGZ2hjVIxbc-qJo,1906
 model_compression_toolkit/gptq/keras/quantizer/soft_rounding/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
 model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=qUuMKysUpjWYjNbchFuyb_UFwzV1HL7R3Y7o0Z5rf60,4016
 model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=BBSDWLmeywjSM5N6oJkMgcuo7zrXTesB4zLwRGG8QB0,12159
@@ -356,14 +356,14 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py
 model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
 model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
-model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=
+model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=SBTwdIxeNR1_8RJjysSuGxaKJ3cB2cbwjfa7Ah1BeY4,15074
 model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=-0GDC2cr-XXS7cTFTnDflJivGN7VaPnzVPsxCE-vZNU,3955
-model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
+model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=zrHs-cQua7GlTW6k_DmaLODvfDIaSrfMlkt6RyNgjYU,12278
 model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
 model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=Zb-P0yRyZHHBlDvUBdRwxDpdduEJyJp6OT9pfKFF5ks,4171
 model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
-model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=
-model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256
+model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=5-RCtQWeD8v_s4Ba9ULqdzm6ISkwAyYA8VCaPeKc_k4,4283
+model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=-6fn6U6y2HZXluOfShYLeFKiuiDMVvsF64OTUDCrne4,1908
 model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
 model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=oO7WgsAHMnWoXNm_gTKAAe-Nd79mGL_m677ai-ui424,4132
 model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
@@ -375,12 +375,12 @@ model_compression_toolkit/pruning/keras/__init__.py
 model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=PHKZYBHrVyR348-a6gw44NrV8Ra9iaeFJ0WbWYpzX8k,8020
 model_compression_toolkit/pruning/pytorch/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
 model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=hO13kbCIyoAgVW6hRoJCUpHr6ThOuqz1Vkg-cftJY5k,8906
-model_compression_toolkit/ptq/__init__.py,sha256=
+model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
 model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
 model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
+model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=ergUI8RDA2h4_SHU05x2pYJatt-U-fZUrShdHJDLo_o,8844
 model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=
+model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=WKzokgg_gGcEHipVH26shneiAiTdSa7d_UUQKoS8ALY,7438
 model_compression_toolkit/qat/__init__.py,sha256=BYKgH1NwB9fqF1TszULQ5tDfLI-GqgZV5sao-lDN9EM,1091
 model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
 model_compression_toolkit/qat/common/qat_config.py,sha256=kbSxFL6_u28furq5mW_75STWDmyX4clPt-seJAnX3IQ,3445
@@ -472,8 +472,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py
 model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
 model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
 model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
-mct_nightly-1.11.0.
-mct_nightly-1.11.0.
-mct_nightly-1.11.0.
-mct_nightly-1.11.0.
-mct_nightly-1.11.0.
+mct_nightly-1.11.0.20240306.post426.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-1.11.0.20240306.post426.dist-info/METADATA,sha256=VmLquWXvg3pTvowSn7Wl_OvRbk762Yz0WSw1ohQYBfM,17379
+mct_nightly-1.11.0.20240306.post426.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+mct_nightly-1.11.0.20240306.post426.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-1.11.0.20240306.post426.dist-info/RECORD,,
```
model_compression_toolkit/gptq/__init__.py

```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfigV2, GPTQHessianScoresConfig
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GradientPTQConfig, GPTQHessianScoresConfig
 from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
 from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
 from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
```
model_compression_toolkit/gptq/common/gptq_config.py

```diff
@@ -61,8 +61,8 @@ class GradientPTQConfig:
     """
     Configuration to use for quantization with GradientPTQ.
     """
-
-
+    def __init__(self,
+                 n_epochs: int,
                  optimizer: Any,
                  optimizer_rest: Any = None,
                  loss: Callable = None,
@@ -79,7 +79,7 @@ class GradientPTQConfig:
         Initialize a GradientPTQConfig.
 
         Args:
-
+            n_epochs (int): Number of representative dataset epochs to train.
             optimizer (Any): Optimizer to use.
             optimizer_rest (Any): Optimizer to use for bias and quantizer parameters.
             loss (Callable): The loss to use. should accept 6 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors,
@@ -96,7 +96,8 @@ class GradientPTQConfig:
             gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
 
         """
-
+
+        self.n_epochs = n_epochs
         self.optimizer = optimizer
         self.optimizer_rest = optimizer_rest
         self.loss = loss
@@ -114,71 +115,3 @@ class GradientPTQConfig:
             else gptq_quantizer_params_override
 
 
-class GradientPTQConfigV2(GradientPTQConfig):
-    """
-    Configuration to use for quantization with GradientPTQV2.
-    """
-    def __init__(self, n_epochs: int,
-                 optimizer: Any,
-                 optimizer_rest: Any = None,
-                 loss: Callable = None,
-                 log_function: Callable = None,
-                 train_bias: bool = True,
-                 rounding_type: RoundingType = RoundingType.SoftQuantizer,
-                 use_hessian_based_weights: bool = True,
-                 optimizer_quantization_parameter: Any = None,
-                 optimizer_bias: Any = None,
-                 regularization_factor: float = REG_DEFAULT,
-                 hessian_weights_config: GPTQHessianScoresConfig = GPTQHessianScoresConfig(),
-                 gptq_quantizer_params_override: Dict[str, Any] = None):
-        """
-        Initialize a GradientPTQConfigV2.
-
-        Args:
-            n_epochs (int): Number of representative dataset epochs to train.
-            optimizer (Any): Optimizer to use.
-            optimizer_rest (Any): Optimizer to use for bias and quantizer parameters.
-            loss (Callable): The loss to use. should accept 6 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors,
-             the 3rd is a list of quantized weights, the 4th is a list of float weights, the 5th and 6th lists are the mean and std of the tensors
-             accordingly. see example in multiple_tensors_mse_loss
-            log_function (Callable): Function to log information about the GPTQ process.
-            train_bias (bool): Whether to update the bias during the training or not.
-            rounding_type (RoundingType): An enum that defines the rounding type.
-            use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
-            optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
-            optimizer_bias (Any): Optimizer to override the rest optimizerfor bias.
-            regularization_factor (float): A floating point number that defines the regularization factor.
-            hessian_weights_config (GPTQHessianScoresConfig): A configuration that include all necessary arguments to run a computation of Hessian scores for the GPTQ loss.
-            gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
-
-        """
-
-        super().__init__(n_iter=None,
-                         optimizer=optimizer,
-                         optimizer_rest=optimizer_rest,
-                         loss=loss,
-                         log_function=log_function,
-                         train_bias=train_bias,
-                         rounding_type=rounding_type,
-                         use_hessian_based_weights=use_hessian_based_weights,
-                         optimizer_quantization_parameter=optimizer_quantization_parameter,
-                         optimizer_bias=optimizer_bias,
-                         regularization_factor=regularization_factor,
-                         hessian_weights_config=hessian_weights_config,
-                         gptq_quantizer_params_override=gptq_quantizer_params_override)
-        self.n_epochs = n_epochs
-
-    @classmethod
-    def from_v1(cls, n_ptq_iter: int, config_v1: GradientPTQConfig):
-        """
-        Initialize a GradientPTQConfigV2 from GradientPTQConfig instance.
-
-        Args:
-            n_ptq_iter (int): Number of PTQ calibration iters (length of representative dataset).
-            config_v1 (GradientPTQConfig): A GPTQ config to convert to V2.
-
-        """
-        n_epochs = int(round(config_v1.n_iter) / n_ptq_iter)
-        v1_params = config_v1.__dict__
-        v1_params = {k: v for k, v in v1_params.items() if k != 'n_iter'}
-        return cls(n_epochs, **v1_params)
```
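For code that constructed these configs directly, the change is just the class name; a minimal migration sketch (the Adam optimizer and epoch count are illustrative, not defaults taken from this diff):

```python
import tensorflow as tf
from model_compression_toolkit.gptq import GradientPTQConfig

# Before this release, the epoch-based config was the V2 class:
#   config = GradientPTQConfigV2(n_epochs=50, optimizer=tf.keras.optimizers.Adam())
# From 1.11.0.20240306.post426 on, GradientPTQConfig itself takes n_epochs:
config = GradientPTQConfig(n_epochs=50, optimizer=tf.keras.optimizers.Adam())
```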
model_compression_toolkit/gptq/keras/gptq_training.py

```diff
@@ -37,7 +37,7 @@ else:
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 from model_compression_toolkit.core import common
 from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common import Graph
 from model_compression_toolkit.gptq.keras.graph_info import get_weights_for_loss, get_gptq_trainable_parameters
 from model_compression_toolkit.gptq.keras.quantizer.regularization_factory import get_regularization
@@ -56,7 +56,7 @@ class KerasGPTQTrainer(GPTQTrainer):
     def __init__(self,
                  graph_float: Graph,
                  graph_quant: Graph,
-                 gptq_config: GradientPTQConfigV2,
+                 gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
                  fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
```
model_compression_toolkit/gptq/keras/quantization_facade.py

```diff
@@ -21,7 +21,7 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
 from model_compression_toolkit.core.common.user_info import UserInformation
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
@@ -66,7 +66,7 @@ if FOUND_TF:
                              loss: Callable = GPTQMultipleTensorsLoss(),
                              log_function: Callable = None,
                              use_hessian_based_weights: bool = True,
-                             regularization_factor: float = REG_DEFAULT) -> GradientPTQConfigV2:
+                             regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig:
         """
         Create a GradientPTQConfigV2 instance for Keras models.
 
@@ -102,26 +102,25 @@ if FOUND_TF:
         """
         bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
                                                  momentum=GPTQ_MOMENTUM)
-        return
-
-
-
-
-
-
-
-
+        return GradientPTQConfig(n_epochs,
+                                 optimizer,
+                                 optimizer_rest=optimizer_rest,
+                                 loss=loss,
+                                 log_function=log_function,
+                                 train_bias=True,
+                                 optimizer_bias=bias_optimizer,
+                                 use_hessian_based_weights=use_hessian_based_weights,
+                                 regularization_factor=regularization_factor)
 
 
     def keras_gradient_post_training_quantization(in_model: Model,
                                                   representative_data_gen: Callable,
-                                                  gptq_config: GradientPTQConfigV2,
+                                                  gptq_config: GradientPTQConfig,
                                                   gptq_representative_data_gen: Callable = None,
                                                   target_kpi: KPI = None,
                                                   core_config: CoreConfig = CoreConfig(),
                                                   fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
-                                                  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC,
-                                                  new_experimental_exporter: bool = True) -> Tuple[Model, UserInformation]:
+                                                  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a
         symmetric constraint quantization thresholds (power of two).
@@ -141,13 +140,12 @@ if FOUND_TF:
         Args:
             in_model (Model): Keras model to quantize.
             representative_data_gen (Callable): Dataset used for calibration.
-            gptq_config (GradientPTQConfigV2): Configuration for using gptq (e.g. optimizer).
+            gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
             gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
            target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
            core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
            fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
            target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
-           new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.
 
         Returns:
 
@@ -232,19 +230,7 @@ if FOUND_TF:
         if core_config.debug_config.analyze_similarity:
             analyzer_model_quantization(representative_data_gen, tb_w, tg_gptq, fw_impl, fw_info)
 
-
-            Logger.warning('Using new experimental wrapped and ready for export models. To '
-                           'disable it, please set new_experimental_exporter to False when '
-                           'calling keras_gradient_post_training_quantization. '
-                           'If you encounter an issue please file a bug.')
-
-            return get_exportable_keras_model(tg_gptq)
-
-        return export_model(tg_gptq,
-                            fw_info,
-                            fw_impl,
-                            tb_w,
-                            bit_widths_config)
+        return get_exportable_keras_model(tg_gptq)
 
     else:
         # If tensorflow is not installed,
```
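Putting the Keras facade changes together, a hedged usage sketch against the new signatures; the model and random-data generator are placeholders, and get_keras_gptq_config's optimizer default is assumed unchanged (it sits outside the captured hunks):

```python
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

model = tf.keras.applications.MobileNetV2()  # placeholder float model

def representative_dataset():
    # Placeholder calibration data; yield batches shaped like the model input.
    for _ in range(2):
        yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

# Returns a GradientPTQConfig (previously a GradientPTQConfigV2).
gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)

# new_experimental_exporter is gone; the facade now always returns the
# exportable (wrapped) model.
quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
    model,
    representative_dataset,
    gptq_config=gptq_config)
```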
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py

```diff
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import Dict, List, Tuple
 
-from model_compression_toolkit.gptq import GradientPTQConfigV2
+from model_compression_toolkit.gptq import GradientPTQConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.exporter.model_wrapper.keras.builder.node_to_quantizer import \
@@ -33,7 +33,7 @@ from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import
 
 
 def quantization_builder(n: common.BaseNode,
-                         gptq_config: GradientPTQConfigV2
+                         gptq_config: GradientPTQConfig
                          ) -> Tuple[Dict[str, BaseKerasGPTQTrainableQuantizer], List[BaseKerasInferableQuantizer]]:
     """
     Build quantizers for a node according to its quantization configuration and
@@ -41,7 +41,7 @@ def quantization_builder(n: common.BaseNode,
 
     Args:
         n: Node to build its QuantizeConfig.
-        gptq_config (GradientPTQConfigV2): GradientPTQConfigV2 configuration.
+        gptq_config (GradientPTQConfig): GradientPTQConfigV2 configuration.
 
     Returns:
         A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
```
model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py

```diff
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import Callable
 
-from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfigV2
+from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
 from model_compression_toolkit.gptq.keras.quantizer.soft_rounding.soft_quantizer_reg import \
     SoftQuantizerRegularization
 
@@ -38,8 +38,6 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
         for _ in representative_data_gen():
            num_batches += 1
 
-
-            not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
-        return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
+        return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
     else:
         return lambda m, e_reg: 0
```
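With the V2 type check removed, the soft-rounding schedule length falls straight out of the config; a toy illustration of the arithmetic (the numbers are made up):

```python
num_batches = 32  # batches counted from representative_data_gen()
n_epochs = 5      # now read directly from gptq_config.n_epochs
total_gradient_steps = num_batches * n_epochs  # 160 steps for the regularizer schedule
```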
model_compression_toolkit/gptq/pytorch/gptq_training.py

```diff
@@ -25,7 +25,7 @@ from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
@@ -46,7 +46,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
     def __init__(self,
                  graph_float: Graph,
                  graph_quant: Graph,
-                 gptq_config: GradientPTQConfigV2,
+                 gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
                  fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
```
model_compression_toolkit/gptq/pytorch/quantization_facade.py

```diff
@@ -19,7 +19,7 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer import
 from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import PYTORCH
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
 from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
 from model_compression_toolkit.core.runner import core_runner
@@ -54,7 +54,7 @@ if FOUND_TORCH:
                                loss: Callable = multiple_tensors_mse_loss,
                                log_function: Callable = None,
                                use_hessian_based_weights: bool = True,
-                               regularization_factor: float = REG_DEFAULT) -> GradientPTQConfigV2:
+                               regularization_factor: float = REG_DEFAULT) -> GradientPTQConfig:
         """
         Create a GradientPTQConfigV2 instance for Pytorch models.
 
@@ -86,21 +86,19 @@ if FOUND_TORCH:
 
         """
         bias_optimizer = torch.optim.SGD([torch.Tensor([])], lr=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
-        return
-
-
-
+        return GradientPTQConfig(n_epochs, optimizer, optimizer_rest=optimizer_rest, loss=loss,
+                                 log_function=log_function, train_bias=True, optimizer_bias=bias_optimizer,
+                                 use_hessian_based_weights=use_hessian_based_weights,
+                                 regularization_factor=regularization_factor)
 
 
     def pytorch_gradient_post_training_quantization(model: Module,
                                                     representative_data_gen: Callable,
                                                     target_kpi: KPI = None,
                                                     core_config: CoreConfig = CoreConfig(),
-                                                    gptq_config: GradientPTQConfigV2 = None,
+                                                    gptq_config: GradientPTQConfig = None,
                                                     gptq_representative_data_gen: Callable = None,
-                                                    target_platform_capabilities: TargetPlatformCapabilities =
-                                                    DEFAULT_PYTORCH_TPC,
-                                                    new_experimental_exporter: bool = True):
+                                                    target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
         """
         Quantize a trained Pytorch module using post-training quantization.
         By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -122,10 +120,9 @@ if FOUND_TORCH:
             representative_data_gen (Callable): Dataset used for calibration.
             target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
-            gptq_config (GradientPTQConfigV2): Configuration for using gptq (e.g. optimizer).
+            gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
             gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
-            new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.
 
         Returns:
             A quantized module and information the user may need to handle the quantized module.
@@ -194,22 +191,8 @@ if FOUND_TORCH:
         if core_config.debug_config.analyze_similarity:
             analyzer_model_quantization(representative_data_gen, tb_w, graph_gptq, fw_impl, DEFAULT_PYTORCH_INFO)
 
-
-
-        # ---------------------- #
-        if new_experimental_exporter:
-            Logger.warning('Using new experimental wrapped and ready for export models. To '
-                           'disable it, please set new_experimental_exporter to False when '
-                           'calling pytorch_gradient_post_training_quantization_experimental. '
-                           'If you encounter an issue please file a bug.')
-
-            return get_exportable_pytorch_model(graph_gptq)
-
-        return export_model(graph_gptq,
-                            DEFAULT_PYTORCH_INFO,
-                            fw_impl,
-                            tb_w,
-                            bit_widths_config)
+        return get_exportable_pytorch_model(graph_gptq)
+
 
 
     else:
         # If torch is not installed,
```
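The PyTorch GPTQ facade mirrors the Keras change; a usage sketch with torchvision's mobilenet_v2 and random data as stand-ins, assuming the config helper keeps its get_pytorch_gptq_config name (the name lies outside the captured hunks):

```python
import numpy as np
import model_compression_toolkit as mct
from torchvision.models import mobilenet_v2

model = mobilenet_v2()  # placeholder float module

def representative_dataset():
    # Placeholder calibration data in NCHW layout.
    for _ in range(2):
        yield [np.random.randn(1, 3, 224, 224).astype(np.float32)]

gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)

# As in the Keras facade, the new_experimental_exporter flag is removed and
# the exportable model is returned unconditionally.
quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(
    model,
    representative_dataset,
    gptq_config=gptq_config)
```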
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py

```diff
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import List, Dict, Tuple
 
-from model_compression_toolkit.gptq import GradientPTQConfigV2
+from model_compression_toolkit.gptq import GradientPTQConfig
 from model_compression_toolkit.core import common
 from model_compression_toolkit.core.pytorch.constants import KERNEL
 from model_compression_toolkit.exporter.model_wrapper.pytorch.builder.node_to_quantizer import \
@@ -34,7 +34,7 @@ from model_compression_toolkit.trainable_infrastructure.common.get_quantizers import
 
 
 def quantization_builder(n: common.BaseNode,
-                         gptq_config: GradientPTQConfigV2,
+                         gptq_config: GradientPTQConfig,
                          ) -> Tuple[Dict[str, BasePytorchQATTrainableQuantizer],
                                     List[BasePyTorchInferableQuantizer]]:
     """
@@ -43,7 +43,7 @@ def quantization_builder(n: common.BaseNode,
 
     Args:
         n: Node to build its QuantizeConfig.
-        gptq_config (GradientPTQConfigV2): GradientPTQConfigV2 configuration.
+        gptq_config (GradientPTQConfig): GradientPTQConfigV2 configuration.
 
     Returns:
         A dictionary which maps the weights kernel attribute to a quantizer for GPTQ training.
```
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py

```diff
@@ -14,7 +14,7 @@
 # ==============================================================================
 from typing import Callable
 
-from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfigV2
+from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
 from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import \
     SoftQuantizerRegularization
 
@@ -38,8 +38,6 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
         for _ in representative_data_gen():
            num_batches += 1
 
-
-            not type(gptq_config) == GradientPTQConfigV2 else gptq_config.n_epochs
-        return SoftQuantizerRegularization(total_gradient_steps=num_batches * n_epochs)
+        return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
     else:
         return lambda m, e_reg: 0
```
model_compression_toolkit/gptq/runner.py

```diff
@@ -20,7 +20,7 @@ from model_compression_toolkit.core import common
 from model_compression_toolkit.core.common.hessian import HessianInfoService
 from model_compression_toolkit.core.common.statistics_correction.statistics_correction import \
     apply_statistics_correction
-from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfigV2
+from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -32,7 +32,7 @@ from model_compression_toolkit.core.common.statistics_correction.apply_bias_correction
 from model_compression_toolkit.logger import Logger
 
 
-def _apply_gptq(gptq_config: GradientPTQConfigV2,
+def _apply_gptq(gptq_config: GradientPTQConfig,
                 representative_data_gen: Callable,
                 tb_w: TensorboardWriter,
                 tg: Graph,
@@ -74,7 +74,7 @@ def _apply_gptq(gptq_config: GradientPTQConfigV2,
 
 def gptq_runner(tg: Graph,
                 core_config: CoreConfig,
-                gptq_config: GradientPTQConfigV2,
+                gptq_config: GradientPTQConfig,
                 representative_data_gen: Callable,
                 gptq_representative_data_gen: Callable,
                 fw_info: FrameworkInfo,
```
model_compression_toolkit/ptq/__init__.py

```diff
@@ -13,5 +13,5 @@
 # limitations under the License.
 # ==============================================================================
 
-from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization_experimental
-from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization_experimental
+from model_compression_toolkit.ptq.pytorch.quantization_facade import pytorch_post_training_quantization
+from model_compression_toolkit.ptq.keras.quantization_facade import keras_post_training_quantization
```
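For callers the rename is mechanical; as far as this diff shows, the argument lists are unchanged:

```python
import model_compression_toolkit as mct

# Before: mct.ptq.keras_post_training_quantization_experimental(model, repr_datagen)
#         mct.ptq.pytorch_post_training_quantization_experimental(module, repr_datagen)
# After:  mct.ptq.keras_post_training_quantization(model, repr_datagen)
#         mct.ptq.pytorch_post_training_quantization(module, repr_datagen)
```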
model_compression_toolkit/ptq/keras/quantization_facade.py

```diff
@@ -40,12 +40,11 @@ if FOUND_TF:
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)
 
 
-    def keras_post_training_quantization_experimental(in_model: Model,
-
-
-
-
-                                                      new_experimental_exporter: bool = True):
+    def keras_post_training_quantization(in_model: Model,
+                                         representative_data_gen: Callable,
+                                         target_kpi: KPI = None,
+                                         core_config: CoreConfig = CoreConfig(),
+                                         target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a
         symmetric constraint quantization thresholds (power of two).
@@ -65,7 +64,6 @@ if FOUND_TF:
             target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
-            new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.
 
         Returns:
 
@@ -111,7 +109,7 @@ if FOUND_TF:
     Pass the model, the representative dataset generator, the configuration and the target KPI to get a
     quantized model:
 
-    >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization_experimental(model, repr_datagen, kpi, core_config=config)
+    >>> quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(model, repr_datagen, kpi, core_config=config)
 
     For more configuration options, please take a look at our `API documentation <https://sony.github.io/model_optimization/api/api_docs/modules/mixed_precision_quantization_config.html>`_.
 
@@ -150,26 +148,14 @@ if FOUND_TF:
                  fw_impl,
                  fw_info)
 
-
-        Logger.warning('Using new experimental wrapped and ready for export models. To '
-                       'disable it, please set new_experimental_exporter to False when '
-                       'calling keras_post_training_quantization_experimental. '
-                       'If you encounter an issue please file a bug.')
-
-        return get_exportable_keras_model(tg)
-
-        return export_model(tg,
-                            fw_info,
-                            fw_impl,
-                            tb_w,
-                            bit_widths_config)
+        return get_exportable_keras_model(tg)
 
 
 
     else:
     # If tensorflow is not installed,
     # we raise an exception when trying to use these functions.
-    def keras_post_training_quantization_experimental(*args, **kwargs):
+    def keras_post_training_quantization(*args, **kwargs):
         Logger.critical('Installing tensorflow is mandatory '
-                        'when using keras_post_training_quantization_experimental. '
+                        'when using keras_post_training_quantization. '
                         'Could not find Tensorflow package.') # pragma: no cover
```
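A usage sketch for the renamed Keras PTQ entry point, with a placeholder model and random calibration data:

```python
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

model = tf.keras.applications.MobileNetV2()  # placeholder float model

def representative_dataset():
    for _ in range(2):
        yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

# Previously keras_post_training_quantization_experimental(..., new_experimental_exporter=...);
# the flag is gone and an exportable model is always returned.
quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
    model, representative_dataset)
```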
model_compression_toolkit/ptq/pytorch/quantization_facade.py

```diff
@@ -39,12 +39,11 @@ if FOUND_TORCH:
 
     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
 
-    def pytorch_post_training_quantization_experimental(in_module: Module,
-
-
-
-
-                                                        new_experimental_exporter: bool = True):
+    def pytorch_post_training_quantization(in_module: Module,
+                                           representative_data_gen: Callable,
+                                           target_kpi: KPI = None,
+                                           core_config: CoreConfig = CoreConfig(),
+                                           target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
         """
         Quantize a trained Pytorch module using post-training quantization.
         By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -64,7 +63,6 @@ if FOUND_TORCH:
             target_kpi (KPI): KPI object to limit the search of the mixed-precision configuration as desired.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
-            new_experimental_exporter (bool): Whether to wrap the quantized model using quantization information or not. Enabled by default. Experimental and subject to future changes.
 
         Returns:
             A quantized module and information the user may need to handle the quantized module.
@@ -89,7 +87,7 @@ if FOUND_TORCH:
     Set number of clibration iterations to 1:
 
     >>> import model_compression_toolkit as mct
-    >>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization_experimental(module, repr_datagen)
+    >>> quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization(module, repr_datagen)
 
     """
 
@@ -123,26 +121,13 @@ if FOUND_TORCH:
                      fw_impl,
                      DEFAULT_PYTORCH_INFO)
 
-
-        Logger.warning('Using new experimental wrapped and ready for export models. To '
-                       'disable it, please set new_experimental_exporter to False when '
-                       'calling pytorch_post_training_quantization_experimental. '
-                       'If you encounter an issue please file a bug.')
+        return get_exportable_pytorch_model(tg)
 
-        return get_exportable_pytorch_model(tg)
-
-        quantized_model, user_info = export_model(tg,
-                                                  DEFAULT_PYTORCH_INFO,
-                                                  fw_impl,
-                                                  tb_w,
-                                                  bit_widths_config)
-
-        return quantized_model, user_info
 
     else:
         # If torch is not installed,
         # we raise an exception when trying to use these functions.
-        def pytorch_post_training_quantization_experimental(*args, **kwargs):
+        def pytorch_post_training_quantization(*args, **kwargs):
             Logger.critical('Installing Pytorch is mandatory '
-                            'when using pytorch_post_training_quantization_experimental. '
+                            'when using pytorch_post_training_quantization. '
                             'Could not find the torch package.') # pragma: no cover
```
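And the PyTorch counterpart, with the same caveats (placeholder module, random calibration data):

```python
import numpy as np
import model_compression_toolkit as mct
from torchvision.models import mobilenet_v2

module = mobilenet_v2()  # placeholder float module

def representative_dataset():
    for _ in range(2):
        yield [np.random.randn(1, 3, 224, 224).astype(np.float32)]

quantized_module, quantization_info = mct.ptq.pytorch_post_training_quantization(
    module, representative_dataset)
```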
File without changes: LICENSE.md
File without changes: WHEEL
File without changes: top_level.txt