mct-nightly 2.1.0.20240616.65727__py3-none-any.whl → 2.1.0.20240618.432__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/METADATA +2 -2
  2. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/RECORD +43 -17
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/functional_node.py +3 -3
  5. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +23 -13
  6. model_compression_toolkit/core/pytorch/constants.py +1 -1
  7. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py +3 -3
  8. model_compression_toolkit/core/pytorch/reader/graph_builders.py +12 -6
  9. model_compression_toolkit/gptq/keras/gptq_training.py +1 -1
  10. model_compression_toolkit/gptq/keras/graph_info.py +1 -1
  11. model_compression_toolkit/gptq/pytorch/gptq_training.py +5 -2
  12. model_compression_toolkit/gptq/pytorch/graph_info.py +2 -1
  13. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -2
  14. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -2
  15. model_compression_toolkit/xquant/__init__.py +19 -0
  16. model_compression_toolkit/xquant/common/__init__.py +15 -0
  17. model_compression_toolkit/xquant/common/constants.py +38 -0
  18. model_compression_toolkit/xquant/common/core_report_generator.py +83 -0
  19. model_compression_toolkit/xquant/common/dataset_utils.py +43 -0
  20. model_compression_toolkit/xquant/common/framework_report_utils.py +89 -0
  21. model_compression_toolkit/xquant/common/model_analyzer.py +99 -0
  22. model_compression_toolkit/xquant/common/model_folding_utils.py +104 -0
  23. model_compression_toolkit/xquant/common/similarity_calculator.py +194 -0
  24. model_compression_toolkit/xquant/common/similarity_functions.py +81 -0
  25. model_compression_toolkit/xquant/common/tensorboard_utils.py +101 -0
  26. model_compression_toolkit/xquant/common/xquant_config.py +39 -0
  27. model_compression_toolkit/xquant/keras/__init__.py +15 -0
  28. model_compression_toolkit/xquant/keras/dataset_utils.py +57 -0
  29. model_compression_toolkit/xquant/keras/facade_xquant_report.py +63 -0
  30. model_compression_toolkit/xquant/keras/keras_report_utils.py +60 -0
  31. model_compression_toolkit/xquant/keras/model_analyzer.py +136 -0
  32. model_compression_toolkit/xquant/keras/similarity_functions.py +75 -0
  33. model_compression_toolkit/xquant/keras/tensorboard_utils.py +84 -0
  34. model_compression_toolkit/xquant/pytorch/__init__.py +15 -0
  35. model_compression_toolkit/xquant/pytorch/dataset_utils.py +76 -0
  36. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +62 -0
  37. model_compression_toolkit/xquant/pytorch/model_analyzer.py +132 -0
  38. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +61 -0
  39. model_compression_toolkit/xquant/pytorch/similarity_functions.py +68 -0
  40. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +87 -0
  41. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/LICENSE.md +0 -0
  42. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/WHEEL +0 -0
  43. {mct_nightly-2.1.0.20240616.65727.dist-info → mct_nightly-2.1.0.20240618.432.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.1.0.20240616.65727
3
+ Version: 2.1.0.20240618.432
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -14,7 +14,7 @@ Description-Content-Type: text/markdown
14
14
  Requires-Dist: networkx !=2.8.1
15
15
  Requires-Dist: tqdm
16
16
  Requires-Dist: Pillow
17
- Requires-Dist: numpy
17
+ Requires-Dist: numpy <2.0
18
18
  Requires-Dist: opencv-python
19
19
  Requires-Dist: scikit-image
20
20
  Requires-Dist: scikit-learn
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=tiJgaUR3HmSVpDI-Vv2wT3GOa8-4cZ_SfpIM_6GpwV4,1573
1
+ model_compression_toolkit/__init__.py,sha256=VC4_Q3irB2XPx2tSa3QkUZSTQtm_TP2S5WwZu1g2liM,1573
2
2
  model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -33,7 +33,7 @@ model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaU
33
33
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
34
34
  model_compression_toolkit/core/common/graph/base_node.py,sha256=X_0zqHrKYAsmnj9tAKjVYasbFcZD8OHpjdiMj9ugQs0,29436
35
35
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
36
- model_compression_toolkit/core/common/graph/functional_node.py,sha256=71_4TrCdqR_r0mtgxmAyqI05iP5YoQQGeSmDgynuzTw,3902
36
+ model_compression_toolkit/core/common/graph/functional_node.py,sha256=_6HsBeLlrpLvXhLPRJswcyDa4z16-O3xzHzGuv46zBc,3897
37
37
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
38
38
  model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
39
39
  model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=3el-A7j1oyoo1_9zq3faQp7IeRsFXFCvnrb3zZFXpU0,9803
@@ -210,7 +210,7 @@ model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9H
210
210
  model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
211
211
  model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
212
212
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
213
- model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
213
+ model_compression_toolkit/core/pytorch/constants.py,sha256=AguUnAsNlj41gwuKIP_7nos3FcJHsIAjewLXSQdrDQM,2624
214
214
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
215
215
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
216
216
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=7CFt1Y3fiDaKkEVvlDd76ZmucCuVp6OZNQwwqJezKbU,27547
@@ -222,7 +222,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
222
222
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
223
223
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
224
224
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
225
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Zw4gi-wjJNV8-qGv79YBWVAHmy27f7iW0c2JGNWAKD0,18199
225
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Pvpkirt3ziWEXDEspgOhR8ALf-XAZUh-78IkXg9YMWs,18830
226
226
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
227
227
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
228
228
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -240,7 +240,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/linear_
240
240
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=VNg-VgzCxSyqy2J3neEPl6U0SPO8UIVU_T47bGhz4FE,38459
241
241
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py,sha256=q1a3HieQtaOmWG2WGXp6GHYAvxa3CZ9dJUx9dqMAsS8,5695
242
242
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py,sha256=joHjwiUxccypMHkTy46rI91VyapLn9yJ2YRo5ISnOH4,1987
243
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=jOqlelGhADEZiYUEyYj9oJZ5YLXx8jWNUlVTG6Td79Y,4919
243
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
244
244
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
245
245
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
246
246
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=lOPl5zDU3FoR9WmlxO04Pfi65MimK0gmnuHzQJodQdY,10668
@@ -261,7 +261,7 @@ model_compression_toolkit/core/pytorch/quantizer/__init__.py,sha256=Rf1RcYmelmdZ
261
261
  model_compression_toolkit/core/pytorch/quantizer/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
262
262
  model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtNokyDUikk-YkDP_mN_2DX0J5oPm3kSfdSUT2Ck,4420
263
263
  model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
264
- model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=x5n8KHBqvutqS5l5AillA_FQfhf-2ibP813ixK3Gvy8,12627
264
+ model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=LiGV-ZqlhxN1evpM-ur2dDVPowhrLwO7JZa7AGPftSk,12913
265
265
  model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=TaolORuwBZEddWe-q0Mg79Nmswz-Sq3-9-4o8UxFQ50,1028
266
266
  model_compression_toolkit/core/pytorch/reader/reader.py,sha256=GEJE0QX8XJFWbYCkbRBtzttZtmmuoACLx8gw9KyAQCE,6015
267
267
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -342,8 +342,8 @@ model_compression_toolkit/gptq/common/gptq_training.py,sha256=efnwgKSGk9wtnirlLR
342
342
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
343
343
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
344
344
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
345
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=zyVcEQzdnNsrIz32U1pqqoi08hzxRdJ2CumaPFGwbDM,19123
346
- model_compression_toolkit/gptq/keras/graph_info.py,sha256=5IvgGlJlgOmQYmldjdCBv7tuzAoY0HazatG5Pedrg0Q,4639
345
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=RAUZvve-kUMTfXY-aXQWEM4IejaeVedrKejBNrO6szI,19156
346
+ model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
347
347
  model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=SjmBTuSwki4JTPVhxvJMFK9uAsmEm2c6VV11NnM6eEo,15117
348
348
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
349
349
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
@@ -359,8 +359,8 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
359
359
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
360
360
  model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
361
361
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
362
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=xkDa62AdIRwv8dEshffALW9Ri66eseEpyUF9taMUKns,16509
363
- model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=yXJzDd24zfGs2_vfMovxD1WSh1RxXoPxN4GztOf3P5c,3967
362
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=2pe_caivE7Fr9zCvmZENKbFTS6AUFbSjHN-TODEhbSY,16631
363
+ model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
364
364
  model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=PqVF1T0unY7V6jB1qUnwBQntLN5lEob83_3NkJE0hG0,13558
365
365
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
366
366
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
@@ -456,11 +456,11 @@ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_
456
456
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py,sha256=gAeebYCKyIXH9-Qwze7FwvTihudzAHk_Qsg94fQbkjQ,717
457
457
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py,sha256=edMH4lM7Bq7FaPAFZLU5UMX-bWSWiaaAIXnQE7lZ7rI,11844
458
458
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py,sha256=T5YMv-RzgYlzBaagnMO7WnKgbZ7PrOvm29Nn4vUhCHI,6587
459
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py,sha256=-q6Tnn7diPCCoATmLDzJwWwviQcbMMISqgpLu2n42JY,5726
459
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py,sha256=HRo0W5l4IJesr_np4ZhXoMk_xfdiV53LgamquQIryJA,5800
460
460
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py,sha256=C2kwyDE1-rtukkbNSoKRv9q8Nt2GOCaBbl0BdOr3goA,721
461
461
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py,sha256=HoGjDwoSx2Y4dQua5v1qzzlnSl_HfDMK6bGWuZhPOzQ,11577
462
462
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py,sha256=LvqUkvpJKXBb9QETcHsmp9OGDwl9KWr457deag8GVuM,6595
463
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py,sha256=4Y2D14rE0SnWIkBTYsVqCryB-gkHU1ZlbdkWF864mPU,5733
463
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py,sha256=nP05jqvh6uaj30a3W7zEkJfKtqfP0Nz5bobwRqbYrdM,5807
464
464
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
465
465
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py,sha256=7KVcuz0LfngRKOsfcvBysxGVb9fqgoAO6MVTl1CmB5c,2082
466
466
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py,sha256=UUvUCcTots_sehdRnDfgkaE8WPQ7dPbeuhDF4Qy2nzw,1510
@@ -491,8 +491,34 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
491
491
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
492
492
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
493
493
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
494
- mct_nightly-2.1.0.20240616.65727.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
495
- mct_nightly-2.1.0.20240616.65727.dist-info/METADATA,sha256=PKQDHwJ9J3IdkJTncdUeqbFVeZ5p97EE1uJIFzb8I-E,19723
496
- mct_nightly-2.1.0.20240616.65727.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
497
- mct_nightly-2.1.0.20240616.65727.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
498
- mct_nightly-2.1.0.20240616.65727.dist-info/RECORD,,
494
+ model_compression_toolkit/xquant/__init__.py,sha256=vdmr8sQw3jIBLF9ck7qrskPoXzDKtksHWlMOkU1JUnQ,1003
495
+ model_compression_toolkit/xquant/common/__init__.py,sha256=ycb1Xt7PtixY2Uabr94JGSwBMcct66O8ZMVf3Qa3ud8,719
496
+ model_compression_toolkit/xquant/common/constants.py,sha256=LRh7q0GtyLTSwOc-XL5yNcPKVq68RvKnORYEC4KK-Ss,1513
497
+ model_compression_toolkit/xquant/common/core_report_generator.py,sha256=LQ9QUST9xyvm4B5sp68rjVPnpnxyosn_9jDBcyRciLs,4951
498
+ model_compression_toolkit/xquant/common/dataset_utils.py,sha256=91uXF9UwxdY7BvUT0FNkFm8a69c8oK8Xdl-y7lbuJxk,1649
499
+ model_compression_toolkit/xquant/common/framework_report_utils.py,sha256=3hzTg5xqdcxHnxmxO8B06o5sW8R-NH1Ixa75U0kie-o,3891
500
+ model_compression_toolkit/xquant/common/model_analyzer.py,sha256=T_8OetIQNqR0nkfSatWsEceXSPYpHfYjboBPIyR03-w,3953
501
+ model_compression_toolkit/xquant/common/model_folding_utils.py,sha256=y5Vmc-hJ2rJhzWdM53HdY-PrT5LlspejTUNlXaCrq9Q,4720
502
+ model_compression_toolkit/xquant/common/similarity_calculator.py,sha256=yCs_vlOThLzq7z-u2PkcEErLj7N7qCBPpRa6_5h34J8,10460
503
+ model_compression_toolkit/xquant/common/similarity_functions.py,sha256=Atah1otdX9oUUch2JK-p-e291QHtkP_c4DfLG9WWo1Y,2935
504
+ model_compression_toolkit/xquant/common/tensorboard_utils.py,sha256=YWvTvp7DyZDhybLnjte1Em90lev-NAa-hYp445BX-Y4,4473
505
+ model_compression_toolkit/xquant/common/xquant_config.py,sha256=Qt56cra2tU1PeHlLx_Cqztf5q-ED8MPelhb8coSumFw,1675
506
+ model_compression_toolkit/xquant/keras/__init__.py,sha256=zbtceCVRsi-Gvl_pOmq5laqVqu55vAU1ie2FR2RK1Po,709
507
+ model_compression_toolkit/xquant/keras/dataset_utils.py,sha256=quvVymhvpcPIOneCu5J6K_QAqBHOCIj8IxZxSN2fItA,2258
508
+ model_compression_toolkit/xquant/keras/facade_xquant_report.py,sha256=ZBwu1PwBgMbhQK-GvVCmn8CE6a1joKxZPluNNt9RqSw,3375
509
+ model_compression_toolkit/xquant/keras/keras_report_utils.py,sha256=Yk-VpyNYi5NWKTVYz-alfLK0JvM9CZDwGXBLu6HNJtI,2987
510
+ model_compression_toolkit/xquant/keras/model_analyzer.py,sha256=WXi9BPI9_TzRWn50lM1i-6cwPPRW0p43Shg_xpHFclU,6521
511
+ model_compression_toolkit/xquant/keras/similarity_functions.py,sha256=P2qMJAo94Sz_BCao-bnhEeewKtjeLLDDH2r9luDXJ04,2710
512
+ model_compression_toolkit/xquant/keras/tensorboard_utils.py,sha256=I1JMlSYe8eoYBpvHmc7H08iC9jdwgAWT4O5c7SMFOfc,4230
513
+ model_compression_toolkit/xquant/pytorch/__init__.py,sha256=ycb1Xt7PtixY2Uabr94JGSwBMcct66O8ZMVf3Qa3ud8,719
514
+ model_compression_toolkit/xquant/pytorch/dataset_utils.py,sha256=KFKiFkhIPpEr1ZH5jekZFrgs20VzzKVxSV9YMgH68yI,2894
515
+ model_compression_toolkit/xquant/pytorch/facade_xquant_report.py,sha256=g5uHlFW9vECkTsrgUs8iohbCCQ4_9tPUcoUv1QZH9uI,3146
516
+ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-ihJBLy5Cic-MQiUM_ZGV6SCXoNdscE,5549
517
+ model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
518
+ model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
519
+ model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
520
+ mct_nightly-2.1.0.20240618.432.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
521
+ mct_nightly-2.1.0.20240618.432.dist-info/METADATA,sha256=oQDb0iDkegJzq1J15CZ59NnILw3BnPrgFuRFki8h95Y,19726
522
+ mct_nightly-2.1.0.20240618.432.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
523
+ mct_nightly-2.1.0.20240618.432.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
524
+ mct_nightly-2.1.0.20240618.432.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.1.0.20240616.065727"
30
+ __version__ = "2.1.0.20240618.000432"
@@ -25,7 +25,7 @@ class FunctionalNode(BaseNode):
25
25
  functional_op: Any = None,
26
26
  inputs_as_list: bool = False,
27
27
  has_activation: bool = True,
28
- tensor_input_indices = None):
28
+ tensor_input_allocs = None):
29
29
  """
30
30
  Init a FunctionalNode object.
31
31
 
@@ -44,7 +44,7 @@ class FunctionalNode(BaseNode):
44
44
  functional_op: The op the node implements.
45
45
  inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
46
46
  has_activation: Whether the node has activations that we might want to quantize.
47
- tensor_input_indices: A list of indices for activation tensors in the node's input tensor list
47
+ tensor_input_allocs: A list of indices for activation tensors in the node's input tensor list
48
48
 
49
49
  """
50
50
 
@@ -63,7 +63,7 @@ class FunctionalNode(BaseNode):
63
63
  self.op_call_args = op_call_args
64
64
  self.functional_op = functional_op
65
65
  self.inputs_as_list = inputs_as_list
66
- self.tensor_input_indices = [] if tensor_input_indices is None else tensor_input_indices
66
+ self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
67
67
 
68
68
  @property
69
69
  def type(self):
@@ -22,6 +22,7 @@ from networkx import topological_sort
22
22
 
23
23
  from model_compression_toolkit.core import FrameworkInfo
24
24
  from model_compression_toolkit.core import common
25
+ from model_compression_toolkit.logger import Logger
25
26
  from model_compression_toolkit.core.common import BaseNode, Graph
26
27
  from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
27
28
  from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
@@ -66,8 +67,8 @@ def _build_input_tensors_list(node: BaseNode,
66
67
  return input_tensors
67
68
 
68
69
 
69
- def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
70
- tensor_input_indices: List = None) -> List:
70
+ def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List, op_call_kwargs: Dict,
71
+ tensor_input_allocs: List = None) -> Tuple[List, Dict]:
71
72
  """
72
73
  Merge input tensors list with positional weights and op_call_args, according to correct order.
73
74
 
@@ -75,22 +76,30 @@ def _merge_inputs(_node: BaseNode, input_tensors: List, op_call_args: List,
75
76
  _node: The node the inputs are for.
76
77
  input_tensors: activation input tensors to node.
77
78
  op_call_args: framework node call args.
79
+ op_call_kwargs: framework node call kwargs.
80
+ tensor_input_allocs: List of input allocations to node.
78
81
 
79
82
  Returns:
80
83
  Combined list of input_tensors and op_call_args.
81
84
  """
82
- if isinstance(_node, FunctionalNode) and _node.tensor_input_indices:
85
+ if isinstance(_node, FunctionalNode) and _node.tensor_input_allocs:
83
86
  _input_list = op_call_args.copy()
84
- if tensor_input_indices is None:
85
- tensor_input_indices = _node.tensor_input_indices
86
- assert len(tensor_input_indices) == len(input_tensors), \
87
- f'Mismatch between input tensors ({len(tensor_input_indices)}) and indices {len(input_tensors)}'
88
- for i, t in zip(tensor_input_indices, input_tensors):
89
- _input_list.insert(i, t)
87
+ if tensor_input_allocs is None:
88
+ tensor_input_allocs = _node.tensor_input_allocs
89
+ if len(tensor_input_allocs) != len(input_tensors):
90
+ Logger.error(f'Mismatch between input tensors ({len(tensor_input_allocs)}) '
91
+ f'and indices {len(input_tensors)} in node {_node.name}.') # pragma: no cover
92
+ for i, t in zip(tensor_input_allocs, input_tensors):
93
+ # insert input tensors in either args or kwargs, according to tensor_input_allocs
94
+ if isinstance(i, str):
95
+ assert i not in op_call_kwargs
96
+ op_call_kwargs.update({i: t})
97
+ else:
98
+ _input_list.insert(i, t)
90
99
  else:
91
100
  _input_list = input_tensors + op_call_args
92
101
 
93
- return _input_list
102
+ return _input_list, op_call_kwargs
94
103
 
95
104
 
96
105
  def _run_operation(n: BaseNode,
@@ -125,14 +134,15 @@ def _run_operation(n: BaseNode,
125
134
  # list separately, because in FX the tensors are FX objects and fail to_torch_tensor
126
135
  input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
127
136
  for t in input_tensors]
128
- _tensor_input_indices = None
137
+ _tensor_input_allocs = None
129
138
  else:
130
- _tensor_input_indices = [i for i in n.tensor_input_indices if i not in n.weights]
139
+ _tensor_input_allocs = [i for i in n.tensor_input_allocs if i not in n.weights]
131
140
 
132
141
  if isinstance(n, FunctionalNode) and n.inputs_as_list:
133
142
  out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
134
143
  else:
135
- merged_inputs = _merge_inputs(n, input_tensors, op_call_args, tensor_input_indices=_tensor_input_indices)
144
+ merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
145
+ tensor_input_allocs=_tensor_input_allocs)
136
146
  out_tensors_of_n_float = op_func(*merged_inputs, **functional_kwargs)
137
147
 
138
148
  # Add a fake quant node if the node has an activation threshold.
@@ -40,7 +40,7 @@ FUNCTIONAL_OP = 'functional_op'
40
40
  OP_CALL_ARGS = 'op_call_args'
41
41
  OP_CALL_KWARGS = 'op_call_kwargs'
42
42
  INPUTS_AS_LIST = 'inputs_as_list'
43
- TENSOR_INPUT_INDICES = 'tensor_input_indices'
43
+ TENSOR_INPUT_ALLOCS = 'tensor_input_allocs'
44
44
  INPLACE = 'inplace'
45
45
  HARDTANH_MIN_VAL = 'min_val'
46
46
  HARDTANH_MAX_VAL = 'max_val'
@@ -65,10 +65,10 @@ class ReshapeWithStaticShapes(common.BaseSubstitution):
65
65
 
66
66
  # When a "reshape" is called with multiple arguments (e.g. x.reshape(-1, channels, height, width)
67
67
  # this substitution converts it x.reshape((-1, channels, height, width)), so need to update the
68
- # tensor_input_indices attribute.
69
- # scalar argument's shape is [1] so remove those indices from tensor_input_indices
68
+ # tensor_input_allocs attribute.
69
+ # scalar argument's shape is [1] so remove those indices from tensor_input_allocs
70
70
  # node.input_shape example: [[1, 32, 4, 32], [1], [1], [1]]
71
- node.tensor_input_indices = node.tensor_input_indices[:sum([i != [1] for i in node.input_shape])]
71
+ node.tensor_input_allocs = node.tensor_input_allocs[:sum([i != [1] for i in node.input_shape])]
72
72
 
73
73
  # modify the node input info
74
74
  node.input_shape = [node.input_shape[0]]
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.graph.base_graph import OutTensor
23
23
  from model_compression_toolkit.core.common.graph.edge import Edge
24
24
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
25
25
  from model_compression_toolkit.core.pytorch.constants import OUTPUT, PLACEHOLDER, TENSOR_META, CALL_FUNCTION, TYPE, \
26
- CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_INDICES, GET_ATTR
26
+ CALL_METHOD, BIAS, FUNCTIONAL_OP, OP_CALL_KWARGS, OP_CALL_ARGS, INPUTS_AS_LIST, TENSOR_INPUT_ALLOCS, GET_ATTR
27
27
  from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
28
28
  from model_compression_toolkit.logger import Logger
29
29
 
@@ -140,7 +140,7 @@ def nodes_builder(model: GraphModule,
140
140
  weights.update({i: consts_dict[input_node]})
141
141
 
142
142
  tensor_meta = input_node.meta
143
- if tensor_meta[TYPE] == torch.Tensor:
143
+ if tensor_meta[TYPE] in [torch.Tensor, torch.nn.parameter.Parameter]:
144
144
  input_shape += [list(tensor_meta[TENSOR_META].shape)]
145
145
  elif tensor_meta[TYPE] == tuple:
146
146
  input_shape += [list(n.shape) for n in tensor_meta[TENSOR_META]]
@@ -159,8 +159,11 @@ def nodes_builder(model: GraphModule,
159
159
 
160
160
  # filter Nodes from framework attributes, we replace these attributes with nx graph nodes
161
161
  framework_attr_filtered = {}
162
+ framework_attr_nodes = {}
162
163
  for k, v in framework_attr.items():
163
- if not isinstance(v, torch.fx.node.Node):
164
+ if isinstance(v, torch.fx.node.Node):
165
+ framework_attr_nodes[k] = v
166
+ else:
164
167
  framework_attr_filtered[k] = v
165
168
  framework_attr = framework_attr_filtered
166
169
 
@@ -177,7 +180,7 @@ def nodes_builder(model: GraphModule,
177
180
  [isinstance(n, torch.fx.node.Node) for n in node.args[0]])
178
181
  inputs_as_list = inputs_as_list1 or (len(node.args) > 0 and isinstance(node.args[0], Node) and
179
182
  node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple))
180
- tensor_input_index = []
183
+ tensor_input_alloc = []
181
184
  op_call_args = list(node.args)
182
185
  if inputs_as_list:
183
186
  op_call_args.pop(0)
@@ -185,7 +188,10 @@ def nodes_builder(model: GraphModule,
185
188
  for in_node in node.all_input_nodes:
186
189
  for i, arg in enumerate(node.args):
187
190
  if arg == in_node:
188
- tensor_input_index.append(i)
191
+ tensor_input_alloc.append(i)
192
+ for k, arg in framework_attr_nodes.items():
193
+ if arg == in_node:
194
+ tensor_input_alloc.append(k)
189
195
 
190
196
  # remove torch.fx.node.Node from inputs to graph_node_type
191
197
  op_call_args = [arg for arg in op_call_args if not isinstance(arg, Node)]
@@ -197,7 +203,7 @@ def nodes_builder(model: GraphModule,
197
203
  OP_CALL_ARGS: op_call_args,
198
204
  OP_CALL_KWARGS: node_kwargs,
199
205
  INPUTS_AS_LIST: inputs_as_list,
200
- TENSOR_INPUT_INDICES: tensor_input_index}
206
+ TENSOR_INPUT_ALLOCS: tensor_input_alloc}
201
207
  else:
202
208
  graph_node_type = BaseNode
203
209
  kwargs = {}
@@ -353,7 +353,7 @@ class KerasGPTQTrainer(GPTQTrainer):
353
353
  node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
354
354
  if self.gptq_config.train_bias:
355
355
  use_bias = layer.layer.get_config().get(USE_BIAS)
356
- if use_bias is not None and use_bias:
356
+ if use_bias is not None and use_bias and layer.layer.bias is not None:
357
357
  new_bias = layer.layer.bias.numpy()
358
358
  node.set_weights_by_keys(BIAS, new_bias)
359
359
 
@@ -63,7 +63,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,
63
63
  kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))
64
64
  use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
65
65
  and layer.layer.get_config().get(USE_BIAS)
66
- if use_bias is not None and use_bias:
66
+ if use_bias is not None and use_bias and layer.layer.bias is not None:
67
67
  bias_weights.append([layer.layer.bias])
68
68
 
69
69
  return trainable_weights, bias_weights, trainable_threshold
@@ -299,7 +299,9 @@ class PytorchGPTQTrainer(GPTQTrainer):
299
299
  for config_attr, config_value in activation_quant_config.items():
300
300
  node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
301
301
  if self.gptq_config.train_bias and hasattr(layer.layer, BIAS):
302
- node.set_weights_by_keys(BIAS, self.fw_impl.to_numpy(getattr(layer.layer, BIAS)))
302
+ bias = getattr(layer.layer, BIAS)
303
+ if bias is not None:
304
+ node.set_weights_by_keys(BIAS, self.fw_impl.to_numpy(bias))
303
305
 
304
306
  return graph_quant
305
307
 
@@ -316,4 +318,5 @@ class PytorchGPTQTrainer(GPTQTrainer):
316
318
  if isinstance(layer, PytorchQuantizationWrapper):
317
319
  if hasattr(layer.layer, BIAS):
318
320
  bias = getattr(layer.layer, BIAS)
319
- bias.requires_grad = self.gptq_config.train_bias
321
+ if bias is not None:
322
+ bias.requires_grad = self.gptq_config.train_bias
@@ -56,7 +56,8 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,
56
56
 
57
57
  if add_bias and hasattr(layer.layer, BIAS):
58
58
  bias = getattr(layer.layer, BIAS)
59
- trainable_bias.append(bias)
59
+ if bias is not None:
60
+ trainable_bias.append(bias)
60
61
 
61
62
  return trainable_aux_weights, trainable_bias, trainable_threshold
62
63
 
@@ -18,7 +18,7 @@ import operator
18
18
  import torch
19
19
  from torch import add, sub, mul, div, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, chunk, unbind, topk, \
20
20
  gather, equal, transpose, permute, argmax, squeeze
21
- from torch.nn import Conv2d, Linear, ConvTranspose2d
21
+ from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d
22
22
  from torch.nn import Dropout, Flatten, Hardtanh, Identity
23
23
  from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
24
24
  from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu
@@ -83,7 +83,8 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
83
83
  argmax,
84
84
  gather,
85
85
  topk,
86
- squeeze])
86
+ squeeze,
87
+ MaxPool2d])
87
88
 
88
89
  tp.OperationsSetToLayers("Conv", [Conv2d, ConvTranspose2d],
89
90
  attr_mapping=pytorch_linear_attr_mapping)
@@ -18,7 +18,7 @@ import operator
18
18
  import torch
19
19
  from torch import add, sub, mul, div, flatten, reshape, split, unsqueeze, dropout, sigmoid, tanh, chunk, unbind, topk, \
20
20
  gather, equal, transpose, permute, argmax, squeeze
21
- from torch.nn import Conv2d, Linear, ConvTranspose2d
21
+ from torch.nn import Conv2d, Linear, ConvTranspose2d, MaxPool2d
22
22
  from torch.nn import Dropout, Flatten, Hardtanh, Identity
23
23
  from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
24
24
  from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu
@@ -82,7 +82,8 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel):
82
82
  argmax,
83
83
  gather,
84
84
  topk,
85
- squeeze])
85
+ squeeze,
86
+ MaxPool2d])
86
87
 
87
88
  tp.OperationsSetToLayers("Conv", [Conv2d, ConvTranspose2d],
88
89
  attr_mapping=pytorch_linear_attr_mapping)
@@ -0,0 +1,19 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from model_compression_toolkit.xquant.common.xquant_config import XQuantConfig
17
+ from model_compression_toolkit.xquant.keras.facade_xquant_report import xquant_report_keras_experimental
18
+ from model_compression_toolkit.xquant.pytorch.facade_xquant_report import xquant_report_pytorch_experimental
19
+
@@ -0,0 +1,15 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ # #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ # #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ # #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # Default similarity metric names:
17
+ CS_SIMILARITY_METRIC_NAME = 'cs'
18
+ SQNR_SIMILARITY_METRIC_NAME = 'sqnr'
19
+ MSE_SIMILARITY_METRIC_NAME = 'mse'
20
+
21
+ # Report components names:
22
+ OUTPUT_SIMILARITY_METRICS_REPR = 'output_similarity_metrics_repr'
23
+ OUTPUT_SIMILARITY_METRICS_VAL = 'output_similarity_metrics_val'
24
+ INTERMEDIATE_SIMILARITY_METRICS_REPR = 'intermediate_similarity_metrics_repr'
25
+ INTERMEDIATE_SIMILARITY_METRICS_VAL = 'intermediate_similarity_metrics_val'
26
+
27
+ # Graph attribute names:
28
+ XQUANT_REPR = 'xquant_repr'
29
+ XQUANT_VAL = 'xquant_val'
30
+
31
+ # Report file name:
32
+ REPORT_FILENAME = 'quant_report.json'
33
+
34
+ # Tag to use in tensorboard for the graph we plot:
35
+ TENSORBOARD_DEFAULT_TAG = 'xquant'
36
+
37
+ # When extracting the activations of a model we hold the output using a dedicated key:
38
+ MODEL_OUTPUT_KEY = 'model_output_key'
@@ -0,0 +1,83 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ from tqdm import tqdm
16
+ from typing import Callable, Any, Dict
17
+
18
+ from model_compression_toolkit.core.common.model_collector import ModelCollector
19
+ from model_compression_toolkit.xquant import XQuantConfig
20
+ from model_compression_toolkit.xquant.common.constants import OUTPUT_SIMILARITY_METRICS_REPR, OUTPUT_SIMILARITY_METRICS_VAL, INTERMEDIATE_SIMILARITY_METRICS_REPR, \
21
+ INTERMEDIATE_SIMILARITY_METRICS_VAL
22
+ from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
23
+
24
+
25
+ def core_report_generator(float_model: Any,
26
+ quantized_model: Any,
27
+ repr_dataset: Callable,
28
+ validation_dataset: Callable,
29
+ fw_report_utils: FrameworkReportUtils,
30
+ xquant_config: XQuantConfig) -> Dict[str, Any]:
31
+ """
32
+ Generate report in tensorboard with a graph of the quantized model and similarity metrics that
33
+ have been measured when comparing to the float model (or any other two models).
34
+ The report also contains histograms that are collected on the baseline model (usually, the float
35
+ model).
36
+
37
+ Args:
38
+ float_model (Any): The original floating-point model.
39
+ quantized_model (Any): The model after quantization.
40
+ repr_dataset (Callable): Representative dataset used for similarity metrics computation.
41
+ validation_dataset (Callable): Validation dataset used for similarity metrics computation.
42
+ fw_report_utils (FrameworkReportUtils): Utilities for generating framework-specific reports.
43
+ xquant_config (XQuantConfig): Configuration settings for explainable quantization.
44
+
45
+ Returns:
46
+ Dict[str, Any]: A dictionary containing the collected similarity metrics and report data.
47
+ """
48
+ # Collect histograms on the float model.
49
+ float_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
50
+ mi = ModelCollector(float_graph, fw_report_utils.fw_impl, fw_report_utils.fw_info)
51
+ for _data in tqdm(repr_dataset(), desc="Collecting Histograms"):
52
+ mi.infer(_data)
53
+
54
+ # Collect histograms and add them to Tensorboard.
55
+ fw_report_utils.tb_utils.add_histograms_to_tensorboard(graph=float_graph)
56
+
57
+ # Compute similarity metrics on representative dataset and validation set.
58
+ repr_similarity = fw_report_utils.similarity_calculator.compute_similarity_metrics(float_model=float_model,
59
+ quantized_model=quantized_model,
60
+ dataset=repr_dataset,
61
+ custom_similarity_metrics=xquant_config.custom_similarity_metrics)
62
+ val_similarity = fw_report_utils.similarity_calculator.compute_similarity_metrics(float_model=float_model,
63
+ quantized_model=quantized_model,
64
+ dataset=validation_dataset,
65
+ custom_similarity_metrics=xquant_config.custom_similarity_metrics,
66
+ is_validation=True)
67
+ similarity_metrics = {
68
+ OUTPUT_SIMILARITY_METRICS_REPR: repr_similarity[0],
69
+ OUTPUT_SIMILARITY_METRICS_VAL: val_similarity[0],
70
+ INTERMEDIATE_SIMILARITY_METRICS_REPR: repr_similarity[1],
71
+ INTERMEDIATE_SIMILARITY_METRICS_VAL: val_similarity[1]
72
+ }
73
+
74
+ # Add a graph of the quantized model with the similarity metrics to TensorBoard for visualization.
75
+ fw_report_utils.tb_utils.add_graph_to_tensorboard(quantized_model,
76
+ similarity_metrics,
77
+ repr_dataset)
78
+
79
+ # Save data to a json file.
80
+ fw_report_utils.dump_report_to_json(report_dir=xquant_config.report_dir,
81
+ collected_data=similarity_metrics)
82
+
83
+ return similarity_metrics