mct-nightly 2.2.0.20240916.525__py3-none-any.whl → 2.2.0.20240918.448__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/RECORD +30 -20
  3. {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/top_level.txt +1 -0
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/core/common/graph/base_node.py +3 -0
  6. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  7. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +1 -1
  8. model_compression_toolkit/core/keras/reader/node_builder.py +23 -1
  9. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  10. model_compression_toolkit/core/pytorch/reader/graph_builders.py +13 -4
  11. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +12 -3
  12. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +10 -1
  13. model_compression_toolkit/gptq/__init__.py +17 -5
  14. model_compression_toolkit/gptq/common/gptq_config.py +88 -75
  15. model_compression_toolkit/gptq/pytorch/gptq_training.py +18 -9
  16. model_compression_toolkit/gptq/pytorch/quantization_facade.py +49 -29
  17. model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py +80 -0
  18. model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +10 -10
  19. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +6 -49
  20. model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py +39 -0
  21. model_compression_toolkit/trainable_infrastructure/pytorch/util.py +29 -0
  22. tests_pytest/__init__.py +14 -0
  23. tests_pytest/pytorch/__init__.py +14 -0
  24. tests_pytest/pytorch/gptq/__init__.py +14 -0
  25. tests_pytest/pytorch/gptq/test_annealing_cfg.py +40 -0
  26. tests_pytest/pytorch/gptq/test_gradual_act_quantization.py +100 -0
  27. tests_pytest/pytorch/trainable_infrastructure/__init__.py +14 -0
  28. tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py +49 -0
  29. {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/LICENSE.md +0 -0
  30. {mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/WHEEL +0 -0
{mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 2.2.0.20240916.525
+ Version: 2.2.0.20240918.448
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
{mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/RECORD
@@ -1,4 +1,4 @@
- model_compression_toolkit/__init__.py,sha256=KF313UvQ5VFZNGpEDi7-0bok1wWBTtoHb0ZkfnVhHpY,1573
+ model_compression_toolkit/__init__.py,sha256=mbgbcZTqzAzq-hfFdFzcbNZgPkm70zf0uPjjSnCRs4E,1573
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -33,9 +33,9 @@ model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=8seu9jBpC7Har
  model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=lg5QaBkRbmvM3tGZ0Q34S3m0CbFql3LUv5BaXLe5TG8,37824
- model_compression_toolkit/core/common/graph/base_node.py,sha256=Tv_whLIy-Da0DWZIycnvZ2cf2Qa1rCwpcH8kTkkhv2s,31415
+ model_compression_toolkit/core/common/graph/base_node.py,sha256=W6xXj3U0vPlSAoEBuw1fZ1E5I1YNaeTcrNum4JDKdj8,31619
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
- model_compression_toolkit/core/common/graph/functional_node.py,sha256=J804e0gK_cykxkUZDI0dAB3rZYkhlacORGSoVVVw4No,3962
+ model_compression_toolkit/core/common/graph/functional_node.py,sha256=QpO9wjiYWuLzzy84Z6qRhVP6wlMrLnOTYCuNzNvJbNo,3958
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
  model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
  model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
@@ -163,7 +163,7 @@ model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nB
  model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=UIQgOOdexycrSKombTMJVvTthR7MlrCihoqM8Kg-rnE,2293
  model_compression_toolkit/core/keras/back2framework/float_model_builder.py,sha256=9SFHhX-JnkB8PvYIIHRYlReBDI_RkZY9LditzW_ElLk,2444
  model_compression_toolkit/core/keras/back2framework/instance_builder.py,sha256=fBj13c6zkVoWX4JJG18_uXPptiEJqXClE_zFbaFB6Q8,4517
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=XFSSaET4oPWB_cx-Q_c9pDJfWyQ1qXT9JXBl5FJCTa4,18137
+ model_compression_toolkit/core/keras/back2framework/keras_model_builder.py,sha256=EyMWjObq8DVG929dY5OquyYGx3kXhgob8XnzmGxmizc,18162
  model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py,sha256=ygIS1WIiftF1VC3oGhc8N6j7MryKtWgEg8nr50p7f4U,15587
  model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py,sha256=5wFb4nx_F0Wu4c8pLf6n6OzxOHtpOJ6_3mQsNSXIudU,2481
  model_compression_toolkit/core/keras/graph_substitutions/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
@@ -205,7 +205,7 @@ model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py,sha256=Up3-sbuA
  model_compression_toolkit/core/keras/reader/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
  model_compression_toolkit/core/keras/reader/common.py,sha256=eZWjBcvTDUX7fCWmy1OAH4lYLFTh59_UQ_nP_Gjp4yw,2594
  model_compression_toolkit/core/keras/reader/connectivity_handler.py,sha256=AgF6qXZOJMeXvc-pBnGY23BJz7wPBx2aTYxHiO8efec,11303
- model_compression_toolkit/core/keras/reader/node_builder.py,sha256=2LXL4Vv5nHiRIX9lBpY4nRrJwDm8JhHeybS9V_QtqJQ,14211
+ model_compression_toolkit/core/keras/reader/node_builder.py,sha256=fkuzNYTcihtjSOyhfWL7yT30JqPnAQo-JzZLiKtR4Io,15014
  model_compression_toolkit/core/keras/reader/reader.py,sha256=wS9UQ2wJKnkZYe9JHwQp7ygDr6CRlzrxmIyLDv1Qz6U,8109
  model_compression_toolkit/core/keras/reader/nested_model/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
  model_compression_toolkit/core/keras/reader/nested_model/edges_merger.py,sha256=K6KAH9o8KSG6baLmhKoCrYK-i-wb6gRKiZmoijFqEYA,7906
@@ -228,7 +228,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=b3RJ9XpbN2XXlCXEVjxLg3NenmtFfnp_UBRKDIEka8A,18698
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Oyro2qg7Bz8TFoimHtrn3JCwHEO9iCrTMy4HktaYZzg,18937
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -267,7 +267,7 @@ model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZ
  model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
  model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
  model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
- model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=-EGSQOdww-O9x0jT_0ggqz2RcrRuDDaWTKnsWgQyxDI,16114
+ model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=BvBj9uokKTvX-6d39yA4SKwRQAN8_X4T8l-rPibChJQ,16754
  model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=7XNc7-l1MZPJGcOESvtAwfIMxrU6kvt3YjF5B7qOqK4,1048
  model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -334,17 +334,17 @@ model_compression_toolkit/exporter/model_wrapper/fw_agnostic/get_inferable_quant
  model_compression_toolkit/exporter/model_wrapper/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py,sha256=SvSGpU0IEUcy6zwChtPm_9lOSNXf4bPN0pwqvVZToik,3929
  model_compression_toolkit/exporter/model_wrapper/keras/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=Ov28M0uJ_xZdvl9gk39psoqnBiv9i2irScKUNrEaGug,5536
+ model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=s59shKmWNtvyGXJu24hxS3jG13PGGsL4jrk1QXTrIxM,6243
  model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py,sha256=uL6tJWC4s2IWUy8GJVwtMWpwZZioRRztfKyPJHo14xI,9442
  model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=vQUGbCi8_pGoN8DwQ0IblSeN6L9t6Cr0reZNuCbBpkM,3469
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=qlPYvgpIEfvwxjjkxUB-lwsGOs7GA5eWoY5xznq7tFg,5395
+ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=-4AmWgTG9p8sH2mpns-PaRmvM6J853mrhNc0wt-9ovs,6076
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=4sN5z-6BXrTE5Dp2FX_jKO9ty5iZ2r4RM7XvXtDVLSI,9348
- model_compression_toolkit/gptq/__init__.py,sha256=YKg-tMj9D4Yd0xW9VRD5EN1J5JrmlRbNEF2fOSgodqA,1228
+ model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
  model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
  model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/gptq/common/gptq_config.py,sha256=U-NiVEedkOsVaFq-iXU2Xcqp99Rgf0f2I3oANdVMhMY,5672
+ model_compression_toolkit/gptq/common/gptq_config.py,sha256=xogD4mM2825NXyX7rKWBaKBhBFo31bMUmxECREGgtWc,6132
  model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
  model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
  model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
@@ -369,16 +369,17 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=-daninmlPGfKsBNPB2C3gT6rK0G5YeyJsuOLA0JlfBU,16633
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=bnL4DyPLBz2-pip3RV_jBmExvQKZ4N1vXzQudc1VgMY,17117
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=lw9pOV5SKOw9kqOsfskuUiSH_UGOPRczTMpyzN_WTjY,13953
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=Z1xCEDiRWE6xtjVjgVGpgGazuY9l9IhUOPNiRZegLMQ,15408
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
+ model_compression_toolkit/gptq/pytorch/quantizer/gradual_activation_quantization.py,sha256=nngu2TeXjngkqt_6-wciFmCvo-dbpeh_tJJxBV_cfHk,3686
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
  model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=5EyAzvlU01vLyXmMwY_8dNyb7GwYktXmnrvUON8n8WI,4696
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=mDWZERLwtDzqWeJUwHMVyGdlS8wPLjJ3NvZiKBP6BNA,1959
+ model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=lkeEBgAAhC1VHu4DHoqDz8GC7BIU4cU0HIAXFYfgUFU,2098
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=oO7WgsAHMnWoXNm_gTKAAe-Nd79mGL_m677ai-ui424,4132
+ model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=UZwVCpG8WOw7r0-cmPYXNkJYpTZciW66KWtKG004J6Q,2683
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=FgPSKoV8p8y-gLNz359XdOPD6w_wpDvcJFtTNLWqYb0,9099
  model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -504,8 +505,10 @@ model_compression_toolkit/trainable_infrastructure/keras/load_model.py,sha256=DJ
  model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha256=eVB5FSE3OmTLrhfLUcP2knwN1z2_unQLM-xFEGwdafA,5587
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
+ model_compression_toolkit/trainable_infrastructure/pytorch/annealing_schedulers.py,sha256=IdUBpZUcOXHLPp2OhwbO_Kytee3OTVuy2032N-tm694,1686
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=lWc5EG3ptrP85n69EHGKFkIadnrKEBMKnB5YXQ5AmXo,2745
  model_compression_toolkit/trainable_infrastructure/pytorch/quantizer_utils.py,sha256=1yOXKghUYfw2hmzbqTuNagIXBoM-wR2bP-ul66-mnDw,7767
+ model_compression_toolkit/trainable_infrastructure/pytorch/util.py,sha256=4Qv_rkfxaDf0YeLD5I_7cepUk8OFsMNvUTrw9wFp_kU,1082
  model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/__init__.py,sha256=73CXhqqNTvDpsvlJXclrGJq-vsCUYCI64ILu1y2mtvw,1056
  model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/base_activation_quantizer.py,sha256=X6E6mewWQot_aAkz3UxW5X0-Fjl_aMMjs3A-Af5eL6w,972
  model_compression_toolkit/trainable_infrastructure/pytorch/activation_quantizers/lsq/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
@@ -540,8 +543,15 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
- mct_nightly-2.2.0.20240916.525.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.2.0.20240916.525.dist-info/METADATA,sha256=fgmiM6pS-u3fVCv07c7QyGDsq1SCz_zCQeQiU-rqH0Y,20813
- mct_nightly-2.2.0.20240916.525.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- mct_nightly-2.2.0.20240916.525.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.2.0.20240916.525.dist-info/RECORD,,
+ tests_pytest/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
+ tests_pytest/pytorch/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
+ tests_pytest/pytorch/gptq/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
+ tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2JY75q5aswolJkbHkZM,2208
+ tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
+ tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
+ tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
+ mct_nightly-2.2.0.20240918.448.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.2.0.20240918.448.dist-info/METADATA,sha256=Atg7fbRWZ1KvrHeOc1jaJ6Gb2VrUdnOAs9gKc_v26VU,20813
+ mct_nightly-2.2.0.20240918.448.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ mct_nightly-2.2.0.20240918.448.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
+ mct_nightly-2.2.0.20240918.448.dist-info/RECORD,,
{mct_nightly-2.2.0.20240916.525.dist-info → mct_nightly-2.2.0.20240918.448.dist-info}/top_level.txt
@@ -1 +1,2 @@
  model_compression_toolkit
+ tests_pytest
model_compression_toolkit/__init__.py
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
- __version__ = "2.2.0.20240916.000525"
+ __version__ = "2.2.0.20240918.000448"
model_compression_toolkit/core/common/graph/base_node.py
@@ -40,6 +40,7 @@ class BaseNode:
  layer_class: type,
  reuse: bool = False,
  reuse_group: str = None,
+ inputs_as_list: bool = False,
  quantization_attr: Dict[str, Any] = None,
  has_activation: bool = True,
  is_custom: bool = False
@@ -58,6 +59,7 @@ class BaseNode:
  layer_class: Class path of the layer this node represents.
  reuse: Whether this node was duplicated and represents a reused layer.
  reuse_group: Name of group of nodes from the same reused layer.
+ inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
  quantization_attr: Attributes the node holds regarding how it should be quantized.
  has_activation: Whether the node has activations that we might want to quantize.
  is_custom: Whether the node is custom layer or not.
@@ -71,6 +73,7 @@ class BaseNode:
  self.layer_class = layer_class
  self.reuse = reuse
  self.reuse_group = reuse_group
+ self.inputs_as_list = inputs_as_list
  self.final_weights_quantization_cfg = None
  self.final_activation_quantization_cfg = None
  self.candidates_quantization_cfg = None
model_compression_toolkit/core/common/graph/functional_node.py
@@ -55,13 +55,13 @@ class FunctionalNode(BaseNode):
  layer_class,
  reuse,
  reuse_group,
+ inputs_as_list,
  quantization_attr,
  has_activation=has_activation)
 
  self.op_call_kwargs = op_call_kwargs
  self.op_call_args = list(op_call_args)
  self.functional_op = functional_op
- self.inputs_as_list = inputs_as_list
  self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
 
  @property
model_compression_toolkit/core/keras/back2framework/keras_model_builder.py
@@ -308,7 +308,7 @@ class KerasModelBuilder(BaseModelBuilder):
  else:
  # If operator expects a single input tensor, it cannot be a list as it should
  # have a dtype field.
- if len(input_tensors) == 1:
+ if len(input_tensors) == 1 and not n.inputs_as_list:
  input_tensors = input_tensors[0]
  out_tensors_of_n_float = op_func(input_tensors)
 
model_compression_toolkit/core/keras/reader/node_builder.py
@@ -30,10 +30,12 @@ if version.parse(tf.__version__) >= version.parse("2.13"):
  from keras.src.layers.core import TFOpLambda, SlicingOpLambda
  from keras.src.engine.keras_tensor import KerasTensor
  from keras.src.engine.node import Node as KerasNode
+ from keras.src.layers.merging.base_merge import _Merge
  else:
  from keras.layers.core import TFOpLambda, SlicingOpLambda
  from keras.engine.keras_tensor import KerasTensor
  from keras.engine.node import Node as KerasNode
+ from keras.layers.merging.base_merge import _Merge
 
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -287,6 +289,7 @@ def build_node(node: KerasNode,
  for i, arg in enumerate(op_call_args[0]):
  if is_const(arg):
  weights.update({i: to_numpy(arg, is_single_tensor=True)})
+ inputs_as_list = __is_node_inputs_a_list(op_call_args, keras_layer)
 
  node = BaseNode(node_name,
  layer_config,
@@ -296,6 +299,7 @@ def build_node(node: KerasNode,
  layer_class,
  is_reused,
  reuse_group,
+ inputs_as_list,
  is_custom=is_keras_custom_layer(layer_class))
 
  node_name_to_node[node_name] = node
@@ -316,6 +320,24 @@ def __is_functional_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
  """
 
  return (keras_layer.symbol in
- [TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol,TFOpLambda(tf.add_n).symbol] and
+ [TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol, TFOpLambda(tf.add_n).symbol] and
  len(op_call_args) > 0 and
  isinstance(op_call_args[0], list))
+
+
+ def __is_node_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
+ """
+ Check whether the input tensors should be passed as a list or not. This is relevant
+ only for layers that inherit from _Merge such as Concatenate and Add.
+
+ Args:
+ op_call_args: Arguments list to check.
+ keras_layer: Keras layer.
+
+ Returns:
+ Whether the input tensors should be passed as a list or not.
+ """
+
+ return (isinstance(keras_layer, _Merge) and
+ len(op_call_args) > 0 and
+ isinstance(op_call_args[0], (list, tuple)))
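The new `_Merge` check matters because Keras merge layers (Add, Concatenate, Average, etc.) are called with all of their inputs packed into a single list, unlike ordinary layers that take one tensor per call. A minimal, self-contained sketch of that calling convention in plain TensorFlow/Keras (not MCT code):

import tensorflow as tf

x1 = tf.keras.Input(shape=(8,))
x2 = tf.keras.Input(shape=(8,))

# Merge layers (subclasses of _Merge) receive their inputs as one list argument...
added = tf.keras.layers.Add()([x1, x2])

# ...while a regular layer receives a single tensor argument.
out = tf.keras.layers.Dense(4)(added)

model = tf.keras.Model(inputs=[x1, x2], outputs=out)
model.summary()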
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py
@@ -139,7 +139,11 @@ def _run_operation(n: BaseNode,
  _tensor_input_allocs = None
 
  if isinstance(n, FunctionalNode) and n.inputs_as_list:
- out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
+ if isinstance(op_func, PytorchQuantizationWrapper):
+ # in wrapped nodes, the op args & kwargs are already in the PytorchQuantizationWrapper.
+ out_tensors_of_n_float = op_func(*input_tensors)
+ else:
+ out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
  else:
  merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
  tensor_input_allocs=_tensor_input_allocs)
model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -232,10 +232,19 @@ def nodes_builder(model: GraphModule,
 
  # Add constants to weights dictionary.
  if node.op != PLACEHOLDER:
- for i, input_node in enumerate(node.all_input_nodes):
- if input_node in consts_dict:
- used_consts.add(input_node)
- weights.update({i: consts_dict[input_node]})
+ if len(node.args) and isinstance(node.args[0], (list, tuple)):
+ # handle weights in nodes with list input. Especially when there's a duplicate of a tensor
+ # in the input list (e.g. torch.concat([const1, x, const2, x, const3], 1)).
+ for input_node in node.all_input_nodes:
+ for i, input_arg in enumerate(node.args[0]):
+ if input_node is input_arg and input_node in consts_dict:
+ used_consts.add(input_node)
+ weights.update({i: consts_dict[input_node]})
+ else:
+ for i, input_node in enumerate(node.all_input_nodes):
+ if input_node in consts_dict:
+ used_consts.add(input_node)
+ weights.update({i: consts_dict[input_node]})
 
  # Extract input and output shapes of the node.
  input_shape, output_shape = _extract_input_and_output_shapes(node)
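To see why positional matching is needed here: in torch.fx, `node.all_input_nodes` is de-duplicated, so when a list argument repeats a tensor, enumerating `all_input_nodes` no longer gives the positions at which constants appear. A small hedged sketch in plain PyTorch/FX (the module and its names are illustrative only, not MCT code):

import torch
from torch import fx, nn

class CatWithConsts(nn.Module):
    """Concatenates the input with constant buffers; note x appears twice in the list."""
    def __init__(self):
        super().__init__()
        self.register_buffer('c1', torch.zeros(1, 2))
        self.register_buffer('c2', torch.ones(1, 2))

    def forward(self, x):
        return torch.cat([self.c1, x, self.c2, x], dim=1)

gm = fx.symbolic_trace(CatWithConsts())
for node in gm.graph.nodes:
    if node.op == 'call_function':
        # args[0] keeps every list position (including the repeated x),
        # while all_input_nodes is de-duplicated, so constant weight indices
        # have to be recovered from args[0], as the updated reader does.
        print('args[0]        :', node.args[0])
        print('all_input_nodes:', node.all_input_nodes)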
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
@@ -13,7 +13,7 @@
  # limitations under the License.
  # ==============================================================================
 
- from typing import Tuple, Callable
+ from typing import Tuple, Callable, Union
  from model_compression_toolkit.core import common
  from model_compression_toolkit.core.common import Graph
  from model_compression_toolkit.verify_packages import FOUND_TF
@@ -25,10 +25,12 @@ if FOUND_TF:
  import tensorflow as tf
  from tensorflow.keras.layers import Layer
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
  from mct_quantizers import KerasQuantizationWrapper
  from mct_quantizers import KerasActivationQuantizationHolder
+ from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS
 
- def _get_wrapper(node: common.BaseNode,
+ def _get_wrapper(node: Union[common.BaseNode, FunctionalNode],
  layer: Layer,
  fw_impl=None) -> Layer:
  """
@@ -45,9 +47,16 @@ if FOUND_TF:
  # for positional weights we need to extract the weight's value.
  weights_values = {attr: node.get_weights_by_keys(attr)
  for attr in weights_quantizers if isinstance(attr, int)}
+ # When wrapping functional nodes, need to set call args\kwargs in wrapper, because they
+ # are used during wrapper call method.
+ func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
+ OP_CALL_KWARGS: node.op_call_kwargs
+ } if isinstance(node, FunctionalNode) else {}
  return KerasQuantizationWrapper(layer,
  weights_quantizers,
- weights_values)
+ weights_values,
+ is_inputs_as_list=node.inputs_as_list,
+ **func_node_kwargs)
  return layer
 
 
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
@@ -24,7 +24,9 @@ import model_compression_toolkit.core as C
  if FOUND_TORCH:
  import torch
  from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
+ from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS
  from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
 
 
  def fully_quantized_wrapper(node: common.BaseNode,
@@ -46,7 +48,14 @@ if FOUND_TORCH:
  # for positional weights we need to extract the weight's value.
  weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
  for attr in weight_quantizers if isinstance(attr, int)}
- return PytorchQuantizationWrapper(module, weight_quantizers, weights_values)
+ # When wrapping functional nodes, need to set call args\kwargs in wrapper, because they
+ # are used during wrapper call method.
+ func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
+ OP_CALL_KWARGS: node.op_call_kwargs
+ } if isinstance(node, FunctionalNode) else {}
+ return PytorchQuantizationWrapper(module, weight_quantizers, weights_values,
+ is_inputs_as_list=node.inputs_as_list,
+ **func_node_kwargs)
  return module
 
 
model_compression_toolkit/gptq/__init__.py
@@ -13,8 +13,20 @@
  # limitations under the License.
  # ==============================================================================
 
- from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GPTQHessianScoresConfig
- from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
- from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
- from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
- from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
+ from model_compression_toolkit.gptq.common.gptq_config import (
+ GradientPTQConfig,
+ RoundingType,
+ GPTQHessianScoresConfig,
+ GradualActivationQuantizationConfig,
+ QFractionLinearAnnealingConfig
+ )
+
+ from model_compression_toolkit.verify_packages import FOUND_TF, FOUND_TORCH
+
+ if FOUND_TF:
+ from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
+ from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
+
+ if FOUND_TORCH:
+ from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
+ from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
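With these guards, the GPTQ config classes are always importable, while the Keras and PyTorch facades are exposed only when their framework is installed. A hedged sketch of how calling code might react to that (the names come from the diff above; the hasattr probe and argument values are illustrative only):

import model_compression_toolkit as mct

# Config classes are exported unconditionally.
hessian_cfg = mct.gptq.GPTQHessianScoresConfig(hessians_num_samples=8)

# The framework facades exist only if the corresponding framework is installed,
# so a mixed environment can probe for them before use.
if hasattr(mct.gptq, 'get_pytorch_gptq_config'):
    print('PyTorch GPTQ facade is available.')
else:
    print('PyTorch is not installed; the PyTorch GPTQ facade is unavailable.')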
model_compression_toolkit/gptq/common/gptq_config.py
@@ -12,8 +12,9 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ from dataclasses import dataclass, field
  from enum import Enum
- from typing import Callable, Any, Dict
+ from typing import Callable, Any, Dict, Optional
 
  from model_compression_toolkit.constants import GPTQ_HESSIAN_NUM_SAMPLES, ACT_HESSIAN_DEFAULT_BATCH_SIZE
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
@@ -32,91 +33,103 @@ class RoundingType(Enum):
  SoftQuantizer = 1
 
 
+ @dataclass
  class GPTQHessianScoresConfig:
  """
  Configuration to use for computing the Hessian-based scores for GPTQ loss metric.
+
+ Args:
+ hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores.
+ norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
+ log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
+ scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
+ hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective.
  """
+ hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES
+ norm_scores: bool = True
+ log_norm: bool = True
+ scale_log_norm: bool = False
+ hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE
 
- def __init__(self,
- hessians_num_samples: int = GPTQ_HESSIAN_NUM_SAMPLES,
- norm_scores: bool = True,
- log_norm: bool = True,
- scale_log_norm: bool = False,
- hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE):
 
- """
- Initialize a GPTQHessianWeightsConfig.
+ @dataclass
+ class QFractionLinearAnnealingConfig:
+ """
+ Config for the quantized fraction linear scheduler of Gradual Activation Quantization.
 
- Args:
- hessians_num_samples (int): Number of samples to use for computing the Hessian-based scores.
- norm_scores (bool): Whether to normalize the returned scores of the weighted loss function (to get values between 0 and 1).
- log_norm (bool): Whether to use log normalization for the GPTQ Hessian-based scores.
- scale_log_norm (bool): Whether to scale the final vector of the Hessian-based scores.
- hessian_batch_size (int): The Hessian computation batch size. used only if using GPTQ with Hessian-based objective.
- """
+ Args:
+ initial_q_fraction: initial quantized fraction
+ target_q_fraction: target quantized fraction
+ start_step: gradient step to begin annealing
+ end_step: gradient step to complete annealing. None means last step.
+ """
+ initial_q_fraction: float
+ target_q_fraction: float
+ start_step: int
+ end_step: Optional[int]
 
- self.hessians_num_samples = hessians_num_samples
- self.norm_scores = norm_scores
- self.log_norm = log_norm
- self.scale_log_norm = scale_log_norm
- self.hessian_batch_size = hessian_batch_size
+ def __post_init__(self):
+ if not (0 <= self.initial_q_fraction < self.target_q_fraction <= 1):
+ raise ValueError(f'Expected 0 <= initial_q_fraction < target_q_fraction <= 1, received initial_q_fraction '
+ f'{self.initial_q_fraction} and target_q_fraction {self.target_q_fraction}.')
+ if self.start_step < 0:
+ raise ValueError(f'Expected start_step >= 0. received {self.start_step}.')
+ if self.end_step is not None and self.end_step <= self.start_step:
+ raise ValueError(f'Expected start_step < end_step, '
+ f'received end_step {self.end_step} and start_step {self.start_step}.')
 
 
- class GradientPTQConfig:
- """
- Configuration to use for quantization with GradientPTQ.
- """
- def __init__(self,
- n_epochs: int,
- optimizer: Any,
- optimizer_rest: Any = None,
- loss: Callable = None,
- log_function: Callable = None,
- train_bias: bool = True,
- rounding_type: RoundingType = RoundingType.SoftQuantizer,
- use_hessian_based_weights: bool = True,
- optimizer_quantization_parameter: Any = None,
- optimizer_bias: Any = None,
- regularization_factor: float = REG_DEFAULT,
- hessian_weights_config: GPTQHessianScoresConfig = GPTQHessianScoresConfig(),
- gptq_quantizer_params_override: Dict[str, Any] = None):
- """
- Initialize a GradientPTQConfig.
+ @dataclass
+ class GradualActivationQuantizationConfig:
+ """ Configuration for Gradual Activation Quantization.
+
+ By default, the quantized fraction increases linearly from 0 to 1 throughout the training.
 
  Args:
- n_epochs (int): Number of representative dataset epochs to train.
- optimizer (Any): Optimizer to use.
- optimizer_rest (Any): Optimizer to use for bias and quantizer parameters.
- loss (Callable): The loss to use. should accept 6 lists of tensors. 1st list of quantized tensors, the 2nd list is the float tensors,
- the 3rd is a list of quantized weights, the 4th is a list of float weights, the 5th and 6th lists are the mean and std of the tensors
- accordingly. see example in multiple_tensors_mse_loss
- log_function (Callable): Function to log information about the GPTQ process.
- train_bias (bool): Whether to update the bias during the training or not.
- rounding_type (RoundingType): An enum that defines the rounding type.
- use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
- optimizer_quantization_parameter (Any): Optimizer to override the rest optimizer for quantizer parameters.
- optimizer_bias (Any): Optimizer to override the rest optimizer for bias.
- regularization_factor (float): A floating point number that defines the regularization factor.
- hessian_weights_config (GPTQHessianScoresConfig): A configuration that include all necessary arguments to run a computation of Hessian scores for the GPTQ loss.
- gptq_quantizer_params_override (dict): A dictionary of parameters to override in GPTQ quantizer instantiation. Defaults to None (no parameters).
-
- """
-
- self.n_epochs = n_epochs
- self.optimizer = optimizer
- self.optimizer_rest = optimizer_rest
- self.loss = loss
- self.log_function = log_function
- self.train_bias = train_bias
-
- self.rounding_type = rounding_type
- self.use_hessian_based_weights = use_hessian_based_weights
- self.optimizer_quantization_parameter = optimizer_quantization_parameter
- self.optimizer_bias = optimizer_bias
- self.regularization_factor = regularization_factor
- self.hessian_weights_config = hessian_weights_config
-
- self.gptq_quantizer_params_override = {} if gptq_quantizer_params_override is None \
- else gptq_quantizer_params_override
+ q_fraction_scheduler_policy: config for the scheduling of the quantized fraction.
+ Only linear annealing is currently supported.
+ """
+ q_fraction_scheduler_policy: QFractionLinearAnnealingConfig = field(
+ default_factory=lambda: QFractionLinearAnnealingConfig(initial_q_fraction=0,
+ target_q_fraction=1,
+ start_step=0,
+ end_step=None)
+ )
 
 
+ @dataclass
+ class GradientPTQConfig:
+ """
+ Configuration to use for quantization with GradientPTQ.
+
+ Args:
+ n_epochs: Number of representative dataset epochs to train.
+ optimizer: Optimizer to use.
+ optimizer_rest: Optimizer to use for bias and quantizer parameters.
+ loss: The loss to use. See 'multiple_tensors_mse_loss' for the expected interface.
+ log_function: Function to log information about the GPTQ process.
+ train_bias: Whether to update the bias during the training or not.
+ rounding_type: An enum that defines the rounding type.
+ use_hessian_based_weights: Whether to use Hessian-based weights for weighted average loss.
+ optimizer_quantization_parameter: Optimizer to override the rest optimizer for quantizer parameters.
+ optimizer_bias: Optimizer to override the rest optimizer for bias.
+ regularization_factor: A floating point number that defines the regularization factor.
+ hessian_weights_config: A configuration that include all necessary arguments to run a computation of
+ Hessian scores for the GPTQ loss.
+ gradual_activation_quantization_config: A configuration for Gradual Activation Quantization.
+ gptq_quantizer_params_override: A dictionary of parameters to override in GPTQ quantizer instantiation.
+ """
+ n_epochs: int
+ optimizer: Any
+ optimizer_rest: Any = None
+ loss: Callable = None
+ log_function: Callable = None
+ train_bias: bool = True
+ rounding_type: RoundingType = RoundingType.SoftQuantizer
+ use_hessian_based_weights: bool = True
+ optimizer_quantization_parameter: Any = None
+ optimizer_bias: Any = None
+ regularization_factor: float = REG_DEFAULT
+ hessian_weights_config: GPTQHessianScoresConfig = field(default_factory=GPTQHessianScoresConfig)
+ gradual_activation_quantization_config: Optional[GradualActivationQuantizationConfig] = None
+ gptq_quantizer_params_override: Dict[str, Any] = field(default_factory=dict)
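The config classes above are now plain dataclasses, so a GPTQ run with gradual activation quantization can be assembled directly from keyword arguments. A hedged usage sketch based only on the fields shown in this diff (the epoch count, learning rate, and dummy optimizer instance are illustrative; in practice the optimizer would typically come from get_pytorch_gptq_config):

import torch
from model_compression_toolkit.gptq import (GradientPTQConfig,
                                             GradualActivationQuantizationConfig,
                                             QFractionLinearAnnealingConfig)

# Quantized fraction annealed linearly from 10% to 100%, starting at step 0
# and finishing at the last gradient step (end_step=None).
annealing = QFractionLinearAnnealingConfig(initial_q_fraction=0.1,
                                           target_q_fraction=1.0,
                                           start_step=0,
                                           end_step=None)

gptq_cfg = GradientPTQConfig(
    n_epochs=5,
    optimizer=torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-4),
    gradual_activation_quantization_config=GradualActivationQuantizationConfig(
        q_fraction_scheduler_policy=annealing))

# __post_init__ rejects inconsistent schedules, e.g. a target fraction below the initial one:
try:
    QFractionLinearAnnealingConfig(initial_q_fraction=0.5, target_q_fraction=0.2,
                                   start_step=0, end_step=None)
except ValueError as e:
    print(e)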