mct-nightly 2.2.0.20240916.525__py3-none-any.whl → 2.2.0.20240918.448__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/RECORD +30 -20
- {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/top_level.txt +1 -0
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +3 -0
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -1
- model_compression_toolkit/core/keras/reader/node_builder.py +23 -1
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
- model_compression_toolkit/core/pytorch/reader/graph_builders.py +13 -4
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +12 -3
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +10 -1
- model_compression_toolkit/gptq/__init__.py +17 -5
- model_compression_toolkit/gptq/common/gptq_config.py +88 -75
- model_compression_toolkit/gptq/pytorch/gptq_training.py +18 -9
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +49 -29
- model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py +80 -0
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +10 -10
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +6 -49
- model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py +39 -0
- model_compression_toolkit/trainable_infrastructure/pytorch/util.py +29 -0
- tests_pytest/__init__.py +14 -0
- tests_pytest/pytorch/__init__.py +14 -0
- tests_pytest/pytorch/gptq/__init__.py +14 -0
- tests_pytest/pytorch/gptq/test_annealing_cfg.py +40 -0
- tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +100 -0
- tests_pytest/pytorch/trainable_infrastructure/__init__.py +14 -0
- tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py +49 -0
- {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/WHEEL +0 -0
{mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=mbgbcZTqzAzq-hfFdFzcbNZgPkm70zf0uPjjSnCRs4E,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -33,9 +33,9 @@ model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=8seu9jBpC7Har
|
|
33
33
|
model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
|
34
34
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
35
35
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=lg5QaBkRbmvM3tGZ0Q34S3m0CbFql3LUv5BaXLe5TG8,37824
|
36
|
-
model_compression_toolkit/core/common/graph/base_node.py,sha256=
|
36
|
+
model_compression_toolkit/core/common/graph/base_node.py,sha256=W6xXj3U0vPlSAoEBuw1fZ1E5I1YNaeTcrNum4JDKdj8,31619
|
37
37
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
38
|
-
model_compression_toolkit/core/common/graph/functional_node.py,sha256=
|
38
|
+
model_compression_toolkit/core/common/graph/functional_node.py,sha256=QpO9wjiYWuLzzy84Z6qRhVP6wlMrLnOTYCuNzNvJbNo,3958
|
39
39
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
40
40
|
model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
|
41
41
|
model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
|
@@ -163,7 +163,7 @@ model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nB
|
|
163
163
|
model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=UIQgOOdexycrSKombTMJVvTthR7MlrCihoqM8Kg-rnE,2293
|
164
164
|
model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
|
165
165
|
model_compression_toolkit/core/keras/back2framework/instance_builder.py,sha256=fBj13c6zkVoWX4JJG18_uXPptiEJqXClE_zFbaFB6Q8,4517
|
166
|
-
model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=
|
166
|
+
model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=EyMWjObq8DVG929dY5OquyYGx3kXhgob8XnzmGxmizc,18162
|
167
167
|
model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=ygIS1WIiftF1VC3oGhc8N6j7MryKtWgEg8nr50p7f4U,15587
|
168
168
|
model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py,sha256=5wFb4nx_F0Wu4c8pLf6n6OzxOHtpOJ6_3mQsNSXIudU,2481
|
169
169
|
model_compression_toolkit/core/keras/graph_substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
@@ -205,7 +205,7 @@ model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py,sha256=Up3-sbuA
|
|
205
205
|
model_compression_toolkit/core/keras/reader/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
206
206
|
model_compression_toolkit/core/keras/reader/common.py,sha256=eZWjBcvTDUX7fCWmy1OAH4lYLFTh59_UQ_nP_Gjp4yw,2594
|
207
207
|
model_compression_toolkit/core/keras/reader/connectivity_handler.py,sha256=AgF6qXZOJMeXvc-pBnGY23BJz7wPBx2aTYxHiO8efec,11303
|
208
|
-
model_compression_toolkit/core/keras/reader/node_builder.py,sha256=
|
208
|
+
model_compression_toolkit/core/keras/reader/node_builder.py,sha256=fkuzNYTcihtjSOyhfWL7yT30JqPnAQo-JzZLiKtR4Io,15014
|
209
209
|
model_compression_toolkit/core/keras/reader/reader.py,sha256=wS9UQ2wJKnkZYe9JHwQp7ygDr6CRlzrxmIyLDv1Qz6U,8109
|
210
210
|
model_compression_toolkit/core/keras/reader/nested_model/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
211
211
|
model_compression_toolkit/core/keras/reader/nested_model/edges_merger.py,sha256=K6KAH9o8KSG6baLmhKoCrYK-i-wb6gRKiZmoijFqEYA,7906
|
@@ -228,7 +228,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
|
|
228
228
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
229
229
|
model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
|
230
230
|
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
|
231
|
-
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=
|
231
|
+
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Oyro2qg7Bz8TFoimHtrn3JCwHEO9iCrTMy4HktaYZzg,18937
|
232
232
|
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
|
233
233
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
234
234
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
|
@@ -267,7 +267,7 @@ model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZ
|
|
267
267
|
model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
|
268
268
|
model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
|
269
269
|
model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
270
|
-
model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256
|
270
|
+
model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=BvBj9uokKTvX-6d39yA4SKwRQAN8_X4T8l-rPibChJQ,16754
|
271
271
|
model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=7XNc7-l1MZPJGcOESvtAwfIMxrU6kvt3YjF5B7qOqK4,1048
|
272
272
|
model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
|
273
273
|
model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
@@ -334,17 +334,17 @@ model_compression_toolkit/exporter/model_wrapper/fw_agnostic/get_inferable_quant
|
|
334
334
|
model_compression_toolkit/exporter/model_wrapper/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
335
335
|
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py,sha256=SvSGpU0IEUcy6zwChtPm_9lOSNXf4bPN0pwqvVZToik,3929
|
336
336
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
337
|
-
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=
|
337
|
+
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=s59shKmWNtvyGXJu24hxS3jG13PGGsL4jrk1QXTrIxM,6243
|
338
338
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py,sha256=uL6tJWC4s2IWUy8GJVwtMWpwZZioRRztfKyPJHo14xI,9442
|
339
339
|
model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
340
340
|
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=vQUGbCi8_pGoN8DwQ0IblSeN6L9t6Cr0reZNuCbBpkM,3469
|
341
341
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
342
|
-
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256
|
342
|
+
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=-4AmWgTG9p8sH2mpns-PaRmvM6J853mrhNc0wt-9ovs,6076
|
343
343
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=4sN5z-6BXrTE5Dp2FX_jKO9ty5iZ2r4RM7XvXtDVLSI,9348
|
344
|
-
model_compression_toolkit/gptq/__init__.py,sha256=
|
344
|
+
model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
|
345
345
|
model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
|
346
346
|
model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
347
|
-
model_compression_toolkit/gptq/common/gptq_config.py,sha256=
|
347
|
+
model_compression_toolkit/gptq/common/gptq_config.py,sha256=xogD4mM2825NXyX7rKWBaKBhBFo31bMUmxECREGgtWc,6132
|
348
348
|
model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
|
349
349
|
model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
|
350
350
|
model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
|
@@ -369,16 +369,17 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
|
|
369
369
|
model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
370
370
|
model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
|
371
371
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
372
|
-
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256
|
372
|
+
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=bnL4DyPLBz2-pip3RV_jBmExvQKZ4N1vXzQudc1VgMY,17117
|
373
373
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
|
374
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
|
374
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=Z1xCEDiRWE6xtjVjgVGpgGazuY9l9IhUOPNiRZegLMQ,15408
|
375
375
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
376
376
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
|
377
|
+
model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py,sha256=nngu2TeXjngkqt_6-wciFmCvo-dbpeh_tJJxBV_cfHk,3686
|
377
378
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
378
379
|
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=5EyAzvlU01vLyXmMwY_8dNyb7GwYktXmnrvUON8n8WI,4696
|
379
|
-
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=
|
380
|
+
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=lkeEBgAAhC1VHu4DHoqDz8GC7BIU4cU0HIAXFYfgUFU,2098
|
380
381
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
381
|
-
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=
|
382
|
+
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=UZwVCpG8WOw7r0-cmPYXNkJYpTZciW66KWtKG004J6Q,2683
|
382
383
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
|
383
384
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=FgPSKoV8p8y-gLNz359XdOPD6w_wpDvcJFtTNLWqYb0,9099
|
384
385
|
model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
@@ -504,8 +505,10 @@ model_compression_toolkit/trainable_infrastructure/keras/load_model.py,sha256=DJ
|
|
504
505
|
model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha256=eVB5FSE3OmTLrhfLUcP2knwN1z2_unQLM-xFEGwdafA,5587
|
505
506
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
506
507
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
508
|
+
model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py,sha256=IdUBpZUcOXHLPp2OhwbO_Kytee3OTVuy2032N-tm694,1686
|
507
509
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=lWc5EG3ptrP85n69EHGKFkIadnrKEBMKnB5YXQ5AmXo,2745
|
508
510
|
model_compression_toolkit/trainable_infrastructure/pytorch/quantizer_utils.py,sha256=1yOXKghUYfw2hmzbqTuNagIXBoM-wR2bP-ul66-mnDw,7767
|
511
|
+
model_compression_toolkit/trainable_infrastructure/pytorch/util.py,sha256=4Qv_rkfxaDf0YeLD5I_7cepUk8OFsMNvUTrw9wFp_kU,1082
|
509
512
|
model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/__init__.py,sha256=73CXhqqNTvDpsvlJXclrGJq-vsCUYCI64ILu1y2mtvw,1056
|
510
513
|
model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/base_activation_quantizer.py,sha256=X6E6mewWQot_aAkz3UxW5X0-Fjl_aMMjs3A-Af5eL6w,972
|
511
514
|
model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/lsq/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
@@ -540,8 +543,15 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
540
543
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
|
541
544
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
542
545
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
546
|
+
tests_pytest/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
547
|
+
tests_pytest/pytorch/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
548
|
+
tests_pytest/pytorch/gptq/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
549
|
+
tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2JY75q5aswolJkbHkZM,2208
|
550
|
+
tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
|
551
|
+
tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
552
|
+
tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
|
553
|
+
mct_nightly-2.2.0.20240918.448.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
554
|
+
mct_nightly-2.2.0.20240918.448.dist-info/METADATA,sha256=Atg7fbRWZ1KvrHeOc1jaJ6Gb2VrUdnOAs9gKc_v26VU,20813
|
555
|
+
mct_nightly-2.2.0.20240918.448.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
556
|
+
mct_nightly-2.2.0.20240918.448.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
|
557
|
+
mct_nightly-2.2.0.20240918.448.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.2.0.20240916.000525"
|
30
|
+
__version__ = "2.2.0.20240918.000448"
|
@@ -40,6 +40,7 @@ class BaseNode:
|
|
40
40
|
layer_class: type,
|
41
41
|
reuse: bool = False,
|
42
42
|
reuse_group: str = None,
|
43
|
+
inputs_as_list: bool = False,
|
43
44
|
quantization_attr: Dict[str, Any] = None,
|
44
45
|
has_activation: bool = True,
|
45
46
|
is_custom: bool = False
|
@@ -58,6 +59,7 @@ class BaseNode:
|
|
58
59
|
layer_class: Class path of the layer this node represents.
|
59
60
|
reuse: Whether this node was duplicated and represents a reused layer.
|
60
61
|
reuse_group: Name of group of nodes from the same reused layer.
|
62
|
+
inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
|
61
63
|
quantization_attr: Attributes the node holds regarding how it should be quantized.
|
62
64
|
has_activation: Whether the node has activations that we might want to quantize.
|
63
65
|
is_custom: Whether the node is custom layer or not.
|
@@ -71,6 +73,7 @@ class BaseNode:
|
|
71
73
|
self.layer_class = layer_class
|
72
74
|
self.reuse = reuse
|
73
75
|
self.reuse_group = reuse_group
|
76
|
+
self.inputs_as_list = inputs_as_list
|
74
77
|
self.final_weights_quantization_cfg = None
|
75
78
|
self.final_activation_quantization_cfg = None
|
76
79
|
self.candidates_quantization_cfg = None
|
@@ -55,13 +55,13 @@ class FunctionalNode(BaseNode):
|
|
55
55
|
layer_class,
|
56
56
|
reuse,
|
57
57
|
reuse_group,
|
58
|
+
inputs_as_list,
|
58
59
|
quantization_attr,
|
59
60
|
has_activation=has_activation)
|
60
61
|
|
61
62
|
self.op_call_kwargs = op_call_kwargs
|
62
63
|
self.op_call_args = list(op_call_args)
|
63
64
|
self.functional_op = functional_op
|
64
|
-
self.inputs_as_list = inputs_as_list
|
65
65
|
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
|
66
66
|
|
67
67
|
@property
|
@@ -308,7 +308,7 @@ class KerasModelBuilder(BaseModelBuilder):
|
|
308
308
|
else:
|
309
309
|
# If operator expects a single input tensor, it cannot be a list as it should
|
310
310
|
# have a dtype field.
|
311
|
-
if len(input_tensors) == 1:
|
311
|
+
if len(input_tensors) == 1 and not n.inputs_as_list:
|
312
312
|
input_tensors = input_tensors[0]
|
313
313
|
out_tensors_of_n_float = op_func(input_tensors)
|
314
314
|
|
@@ -30,10 +30,12 @@ if version.parse(tf.__version__) >= version.parse("2.13"):
|
|
30
30
|
from keras.src.layers.core import TFOpLambda, SlicingOpLambda
|
31
31
|
from keras.src.engine.keras_tensor import KerasTensor
|
32
32
|
from keras.src.engine.node import Node as KerasNode
|
33
|
+
from keras.src.layers.merging.base_merge import _Merge
|
33
34
|
else:
|
34
35
|
from keras.layers.core import TFOpLambda, SlicingOpLambda
|
35
36
|
from keras.engine.keras_tensor import KerasTensor
|
36
37
|
from keras.engine.node import Node as KerasNode
|
38
|
+
from keras.layers.merging.base_merge import _Merge
|
37
39
|
|
38
40
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
39
41
|
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
@@ -287,6 +289,7 @@ def build_node(node: KerasNode,
|
|
287
289
|
for i, arg in enumerate(op_call_args[0]):
|
288
290
|
if is_const(arg):
|
289
291
|
weights.update({i: to_numpy(arg, is_single_tensor=True)})
|
292
|
+
inputs_as_list = __is_node_inputs_a_list(op_call_args, keras_layer)
|
290
293
|
|
291
294
|
node = BaseNode(node_name,
|
292
295
|
layer_config,
|
@@ -296,6 +299,7 @@ def build_node(node: KerasNode,
|
|
296
299
|
layer_class,
|
297
300
|
is_reused,
|
298
301
|
reuse_group,
|
302
|
+
inputs_as_list,
|
299
303
|
is_custom=is_keras_custom_layer(layer_class))
|
300
304
|
|
301
305
|
node_name_to_node[node_name] = node
|
@@ -316,6 +320,24 @@ def __is_functional_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
|
|
316
320
|
"""
|
317
321
|
|
318
322
|
return (keras_layer.symbol in
|
319
|
-
[TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol,TFOpLambda(tf.add_n).symbol] and
|
323
|
+
[TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol, TFOpLambda(tf.add_n).symbol] and
|
320
324
|
len(op_call_args) > 0 and
|
321
325
|
isinstance(op_call_args[0], list))
|
326
|
+
|
327
|
+
|
328
|
+
def __is_node_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
|
329
|
+
"""
|
330
|
+
Check whether the input tensors should be passed as a list or not. This is relevant
|
331
|
+
only for layers that inherit from _Merge such as Concatenate and Add.
|
332
|
+
|
333
|
+
Args:
|
334
|
+
op_call_args: Arguments list to check.
|
335
|
+
keras_layer: Keras layer.
|
336
|
+
|
337
|
+
Returns:
|
338
|
+
Whether the input tensors should be passed as a list or not.
|
339
|
+
"""
|
340
|
+
|
341
|
+
return (isinstance(keras_layer, _Merge) and
|
342
|
+
len(op_call_args) > 0 and
|
343
|
+
isinstance(op_call_args[0], (list, tuple)))
|
@@ -139,7 +139,11 @@ def _run_operation(n: BaseNode,
|
|
139
139
|
_tensor_input_allocs = None
|
140
140
|
|
141
141
|
if isinstance(n, FunctionalNode) and n.inputs_as_list:
|
142
|
-
|
142
|
+
if isinstance(op_func, PytorchQuantizationWrapper):
|
143
|
+
# in wrapped nodes, the op args & kwargs are already in the PytorchQuantizationWrapper.
|
144
|
+
out_tensors_of_n_float = op_func(*input_tensors)
|
145
|
+
else:
|
146
|
+
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
|
143
147
|
else:
|
144
148
|
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
|
145
149
|
tensor_input_allocs=_tensor_input_allocs)
|
@@ -232,10 +232,19 @@ def nodes_builder(model: GraphModule,
|
|
232
232
|
|
233
233
|
# Add constants to weights dictionary.
|
234
234
|
if node.op != PLACEHOLDER:
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
235
|
+
if len(node.args) and isinstance(node.args[0], (list, tuple)):
|
236
|
+
# handle weights in nodes with list input. Especially when there's a duplicate of a tensor
|
237
|
+
# in the input list (e.g. torch.concat([const1, x, const2, x, const3], 1)).
|
238
|
+
for input_node in node.all_input_nodes:
|
239
|
+
for i, input_arg in enumerate(node.args[0]):
|
240
|
+
if input_node is input_arg and input_node in consts_dict:
|
241
|
+
used_consts.add(input_node)
|
242
|
+
weights.update({i: consts_dict[input_node]})
|
243
|
+
else:
|
244
|
+
for i, input_node in enumerate(node.all_input_nodes):
|
245
|
+
if input_node in consts_dict:
|
246
|
+
used_consts.add(input_node)
|
247
|
+
weights.update({i: consts_dict[input_node]})
|
239
248
|
|
240
249
|
# Extract input and output shapes of the node.
|
241
250
|
input_shape, output_shape = _extract_input_and_output_shapes(node)
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from typing import Tuple, Callable
|
16
|
+
from typing import Tuple, Callable, Union
|
17
17
|
from model_compression_toolkit.core import common
|
18
18
|
from model_compression_toolkit.core.common import Graph
|
19
19
|
from model_compression_toolkit.verify_packages import FOUND_TF
|
@@ -25,10 +25,12 @@ if FOUND_TF:
|
|
25
25
|
import tensorflow as tf
|
26
26
|
from tensorflow.keras.layers import Layer
|
27
27
|
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
|
28
|
+
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
28
29
|
from mct_quantizers import KerasQuantizationWrapper
|
29
30
|
from mct_quantizers import KerasActivationQuantizationHolder
|
31
|
+
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS
|
30
32
|
|
31
|
-
def _get_wrapper(node: common.BaseNode,
|
33
|
+
def _get_wrapper(node: Union[common.BaseNode, FunctionalNode],
|
32
34
|
layer: Layer,
|
33
35
|
fw_impl=None) -> Layer:
|
34
36
|
"""
|
@@ -45,9 +47,16 @@ if FOUND_TF:
|
|
45
47
|
# for positional weights we need to extract the weight's value.
|
46
48
|
weights_values = {attr: node.get_weights_by_keys(attr)
|
47
49
|
for attr in weights_quantizers if isinstance(attr, int)}
|
50
|
+
# When wrapping functional nodes, need to set call args\kwargs in wrapper, because they
|
51
|
+
# are used during wrapper call method.
|
52
|
+
func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
|
53
|
+
OP_CALL_KWARGS: node.op_call_kwargs
|
54
|
+
} if isinstance(node, FunctionalNode) else {}
|
48
55
|
return KerasQuantizationWrapper(layer,
|
49
56
|
weights_quantizers,
|
50
|
-
weights_values
|
57
|
+
weights_values,
|
58
|
+
is_inputs_as_list=node.inputs_as_list,
|
59
|
+
**func_node_kwargs)
|
51
60
|
return layer
|
52
61
|
|
53
62
|
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
CHANGED
@@ -24,7 +24,9 @@ import model_compression_toolkit.core as C
|
|
24
24
|
if FOUND_TORCH:
|
25
25
|
import torch
|
26
26
|
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
|
27
|
+
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS
|
27
28
|
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
|
29
|
+
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
|
28
30
|
|
29
31
|
|
30
32
|
def fully_quantized_wrapper(node: common.BaseNode,
|
@@ -46,7 +48,14 @@ if FOUND_TORCH:
|
|
46
48
|
# for positional weights we need to extract the weight's value.
|
47
49
|
weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
|
48
50
|
for attr in weight_quantizers if isinstance(attr, int)}
|
49
|
-
|
51
|
+
# When wrapping functional nodes, need to set call args\kwargs in wrapper, because they
|
52
|
+
# are used during wrapper call method.
|
53
|
+
func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
|
54
|
+
OP_CALL_KWARGS: node.op_call_kwargs
|
55
|
+
} if isinstance(node, FunctionalNode) else {}
|
56
|
+
return PytorchQuantizationWrapper(module, weight_quantizers, weights_values,
|
57
|
+
is_inputs_as_list=node.inputs_as_list,
|
58
|
+
**func_node_kwargs)
|
50
59
|
return module
|
51
60
|
|
52
61
|
|
@@ -13,8 +13,20 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from model_compression_toolkit.gptq.common.gptq_config import
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
16
|
+
from model_compression_toolkit.gptq.common.gptq_config import (
|
17
|
+
GradientPTQConfig,
|
18
|
+
RoundingType,
|
19
|
+
GPTQHessianScoresConfig,
|
20
|
+
GradualActivationQuantizationConfig,
|
21
|
+
QFractionLinearAnnealingConfig
|
22
|
+
)
|
23
|
+
|
24
|
+
from model_compression_toolkit.verify_packages import FOUND_TF, FOUND_TORCH
|
25
|
+
|
26
|
+
if FOUND_TF:
|
27
|
+
from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
|
28
|
+
from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
|
29
|
+
|
30
|
+
if FOUND_TORCH:
|
31
|
+
from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
|
32
|
+
from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
|
@@ -12,8 +12,9 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
from dataclasses import dataclass, field
|
15
16
|
from enum import Enum
|
16
|
-
from typing import Callable, Any, Dict
|
17
|
+
from typing import Callable, Any, Dict, Optional
|
17
18
|
|
18
19
|
from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
19
20
|
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
|
@@ -32,91 +33,103 @@ class RoundingType(Enum):
|
|
32
33
|
SoftQuantizer = 1
|
33
34
|
|
34
35
|
|
36
|
+
@dataclass
|
35
37
|
class GPTQHessianScoresConfig:
|
36
38
|
"""
|
37
39
|
Configuration to use for computing the Hessian-based scores for GPTQ loss metric.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores.
|
43
|
+
norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
|
44
|
+
log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
|
45
|
+
scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
|
46
|
+
hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective.
|
38
47
|
"""
|
48
|
+
hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES
|
49
|
+
norm_scores: bool = True
|
50
|
+
log_norm: bool = True
|
51
|
+
scale_log_norm: bool = False
|
52
|
+
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
|
39
53
|
|
40
|
-
def __init__(self,
|
41
|
-
hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES,
|
42
|
-
norm_scores: bool = True,
|
43
|
-
log_norm: bool = True,
|
44
|
-
scale_log_norm: bool = False,
|
45
|
-
hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE):
|
46
54
|
|
47
|
-
|
48
|
-
|
55
|
+
@dataclass
class QFractionLinearAnnealingConfig:
    """
    Config for the quantized fraction linear scheduler of Gradual Activation Quantization.

    Args:
        initial_q_fraction: initial quantized fraction
        target_q_fraction: target quantized fraction
        start_step: gradient step to begin annealing
        end_step: gradient step to complete annealing. None means last step.

    Raises:
        ValueError: if the fractions do not satisfy 0 <= initial_q_fraction < target_q_fraction <= 1,
            if start_step is negative, or if end_step is given and is not greater than start_step.
    """
    initial_q_fraction: float
    target_q_fraction: float
    start_step: int
    end_step: Optional[int]

    def __post_init__(self):
        # Validate eagerly so that an invalid schedule is rejected at construction time.
        if not (0 <= self.initial_q_fraction < self.target_q_fraction <= 1):
            raise ValueError(f'Expected 0 <= initial_q_fraction < target_q_fraction <= 1, received initial_q_fraction '
                             f'{self.initial_q_fraction} and target_q_fraction {self.target_q_fraction}.')
        if self.start_step < 0:
            raise ValueError(f'Expected start_step >= 0. received {self.start_step}.')
        # end_step is optional (None means the last step), so it is only checked when provided.
        if self.end_step is not None and self.end_step <= self.start_step:
            raise ValueError('Expected start_step < end_step, '
                             f'received end_step {self.end_step} and start_step {self.start_step}.')
|
63
80
|
|
64
81
|
|
65
|
-
|
66
|
-
|
67
|
-
Configuration
|
68
|
-
|
69
|
-
|
70
|
-
n_epochs: int,
|
71
|
-
optimizer: Any,
|
72
|
-
optimizer_rest: Any = None,
|
73
|
-
loss: Callable = None,
|
74
|
-
log_function: Callable = None,
|
75
|
-
train_bias: bool = True,
|
76
|
-
rounding_type: RoundingType = RoundingType.SoftQuantizer,
|
77
|
-
use_hessian_based_weights: bool = True,
|
78
|
-
optimizer_quantization_parameter: Any = None,
|
79
|
-
optimizer_bias: Any = None,
|
80
|
-
regularization_factor: float = REG_DEFAULT,
|
81
|
-
hessian_weights_config: GPTQHessianScoresConfig = GPTQHessianScoresConfig(),
|
82
|
-
gptq_quantizer_params_override: Dict[str, Any] = None):
|
83
|
-
"""
|
84
|
-
Initialize a GradientPTQConfig.
|
82
|
+
def _default_q_fraction_scheduler_policy():
    """ Build the default linear annealing policy: fraction 0 -> 1, from step 0 with no explicit end step. """
    return QFractionLinearAnnealingConfig(initial_q_fraction=0,
                                          target_q_fraction=1,
                                          start_step=0,
                                          end_step=None)


@dataclass
class GradualActivationQuantizationConfig:
    """ Settings for Gradual Activation Quantization.

    Unless a policy is supplied, the quantized fraction is annealed linearly from 0 to 1 over the whole training.

    Args:
        q_fraction_scheduler_policy: scheduling config of the quantized fraction.
            Linear annealing is the only policy currently supported.
    """
    # A factory is used so that every instance receives its own policy object.
    q_fraction_scheduler_policy: QFractionLinearAnnealingConfig = field(
        default_factory=_default_q_fraction_scheduler_policy
    )
|
121
98
|
|
122
99
|
|
100
|
+
@dataclass
class GradientPTQConfig:
    """
    Configuration to use for quantization with GradientPTQ.

    Args:
        n_epochs: Number of representative dataset epochs to train.
        optimizer: Optimizer to use.
        optimizer_rest: Optimizer to use for bias and quantizer parameters.
        loss: The loss to use. See 'multiple_tensors_mse_loss' for the expected interface.
        log_function: Function to log information about the GPTQ process.
        train_bias: Whether to update the bias during the training or not.
        rounding_type: An enum that defines the rounding type.
        use_hessian_based_weights: Whether to use Hessian-based weights for weighted average loss.
        optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
        optimizer_bias: Optimizer to override the rest optimizer for bias.
        regularization_factor: A floating point number that defines the regularization factor.
        hessian_weights_config: A configuration that include all necessary arguments to run a computation of
            Hessian scores for the GPTQ loss.
        gradual_activation_quantization_config: A configuration for Gradual Activation Quantization.
        gptq_quantizer_params_override: A dictionary of parameters to override in GPTQ quantizer instantiation.
            None is accepted and stored as an empty dictionary (no parameters).
    """
    n_epochs: int
    optimizer: Any
    optimizer_rest: Any = None
    loss: Optional[Callable] = None
    log_function: Optional[Callable] = None
    train_bias: bool = True
    rounding_type: RoundingType = RoundingType.SoftQuantizer
    use_hessian_based_weights: bool = True
    optimizer_quantization_parameter: Any = None
    optimizer_bias: Any = None
    regularization_factor: float = REG_DEFAULT
    hessian_weights_config: GPTQHessianScoresConfig = field(default_factory=GPTQHessianScoresConfig)
    gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] = None
    gptq_quantizer_params_override: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # The previous hand-written constructor accepted gptq_quantizer_params_override=None and stored
        # an empty dict instead. Keep that behavior so callers that pass None explicitly still get a dict.
        if self.gptq_quantizer_params_override is None:
            self.gptq_quantizer_params_override = {}
|