mct-nightly 2.4.0.20250706.701__py3-none-any.whl → 2.4.0.20250708.612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/RECORD +36 -38
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/collectors/base_collector.py +4 -1
  5. model_compression_toolkit/core/common/collectors/mean_collector.py +7 -4
  6. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +7 -4
  7. model_compression_toolkit/core/common/model_collector.py +17 -3
  8. model_compression_toolkit/core/common/pruning/memory_calculator.py +1 -1
  9. model_compression_toolkit/core/common/quantization/node_quantization_config.py +25 -87
  10. model_compression_toolkit/core/common/quantization/quantization_config.py +0 -1
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +26 -17
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +27 -49
  13. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +12 -7
  14. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -14
  15. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -1
  16. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -13
  17. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +3 -3
  18. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +5 -7
  19. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +7 -5
  20. model_compression_toolkit/core/graph_prep_runner.py +1 -11
  21. model_compression_toolkit/core/keras/default_framework_info.py +1 -1
  22. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +21 -11
  23. model_compression_toolkit/core/keras/keras_implementation.py +2 -2
  24. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +8 -0
  25. model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
  26. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +9 -1
  27. model_compression_toolkit/core/quantization_prep_runner.py +2 -2
  28. model_compression_toolkit/gptq/keras/quantization_facade.py +0 -3
  29. model_compression_toolkit/ptq/keras/quantization_facade.py +0 -3
  30. model_compression_toolkit/qat/keras/quantization_facade.py +0 -3
  31. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +0 -2
  32. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +0 -6
  33. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +2 -4
  34. model_compression_toolkit/core/common/model_validation.py +0 -41
  35. model_compression_toolkit/core/keras/keras_model_validation.py +0 -37
  36. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/WHEEL +0 -0
  37. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/licenses/LICENSE.md +0 -0
  38. {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.4.0.20250706.701
3
+ Version: 2.4.0.20250708.612
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Author-email: ssi-dnn-dev@sony.com
6
6
  Classifier: Programming Language :: Python :: 3
@@ -1,5 +1,5 @@
1
- mct_nightly-2.4.0.20250706.701.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=oMuUjzzPKDPdfhr1AF-A55kyQJBPEDqlrqfoPxw3HKA,1557
1
+ mct_nightly-2.4.0.20250708.612.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=TrwEn4n4YKNNqTg96Ud45uv2qh_Ob-Q1IzRMrsqjJG4,1557
3
3
  model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -7,8 +7,8 @@ model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_A
7
7
  model_compression_toolkit/verify_packages.py,sha256=l0neIRr8q_QwxmuiTI4vyCMDISDedK0EihjEQUe66tE,1319
8
8
  model_compression_toolkit/core/__init__.py,sha256=HNverPpoqEyFKTa7iEdOqqY2P0Gq-7GMejNOi6ZPcQs,2042
9
9
  model_compression_toolkit/core/analyzer.py,sha256=5P03LbkFy-mu31TMAiQoIKcsA1-DNz7cTzkGvRaXtbw,3505
10
- model_compression_toolkit/core/graph_prep_runner.py,sha256=naZWayASraZ9PgmqCBFgFWWfDV3zLgPaIo6JLbInZc4,11361
11
- model_compression_toolkit/core/quantization_prep_runner.py,sha256=tz91E1BaNc_K0lvVZGB8oS6ya5N4Z5TJLG4pSM3hx30,6229
10
+ model_compression_toolkit/core/graph_prep_runner.py,sha256=XvuR1lNsbaGG3HyROA8nF2n0oPNb1Tnw6rGatewAvE0,10563
11
+ model_compression_toolkit/core/quantization_prep_runner.py,sha256=W6CAJ_Euxrjm5jj__QqfB4xsD8NYx_-XWX22h0ax6B4,6262
12
12
  model_compression_toolkit/core/runner.py,sha256=QpiJQmQXK6mWmnygNRdy6I8S48DHV-B0Kmr4TqOKbeA,12418
13
13
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
14
14
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
@@ -16,18 +16,17 @@ model_compression_toolkit/core/common/framework_implementation.py,sha256=jrTupZb
16
16
  model_compression_toolkit/core/common/framework_info.py,sha256=vPGV28gm-kvNSkkWI6jY3YeKBUYmn6UQ98HVUnl_-tM,5449
17
17
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
18
18
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
19
- model_compression_toolkit/core/common/model_collector.py,sha256=A1uaGmxqj-392lMtE-F020FHFAyyKDJDdeJeZYtkv3Y,12755
20
- model_compression_toolkit/core/common/model_validation.py,sha256=HRnYh2uY85yJ7Ijmt4tKRn8bMg60zbBSDRCgK246gUM,1067
19
+ model_compression_toolkit/core/common/model_collector.py,sha256=I0O5SoUrwB045AGLleOGYVvZyR7HJx_t6vfbECahfoU,13914
21
20
  model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
22
21
  model_compression_toolkit/core/common/similarity_analyzer.py,sha256=S3f6WgHyw62dGcxpX51FGKyfebe2zv9ABKbjtGyKRvY,9215
23
22
  model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
24
23
  model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
25
24
  model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=yrIxT0ttDi9XViy8Zt8apnMCT8xDyVd5HZp0IttrGGQ,1775
26
25
  model_compression_toolkit/core/common/collectors/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
27
- model_compression_toolkit/core/common/collectors/base_collector.py,sha256=JoBTX3rRcRnUF3_Azjg848aiJt9drCJ5TsR9RahVI0Y,2591
26
+ model_compression_toolkit/core/common/collectors/base_collector.py,sha256=n3IIwfWVIl8TO_7MR9jBBxzf2foysWMyrGS-oaCuKiA,2664
28
27
  model_compression_toolkit/core/common/collectors/histogram_collector.py,sha256=zra5V06Brpjc1cUNIMVVGqdoqAuro62_hGy2Zm5-XMQ,6754
29
- model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=mjr3U_z7vn8rrqpkHnfErUOflToIYl4ozBVzP2awqDQ,3414
30
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=5oKsJEKdVmj4C7fKdHhmrFN5k4G2BaFETpmf_xKNs7s,5207
28
+ model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=7-oCVU76e9ZWIZLiLte-z07FX2Rjmu5Go-KRj7b7oWY,3564
29
+ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=pAdiLXNNm5sfIZ584m2MzJQ0PTNp6MKsrMxJ_onEN1s,5314
31
30
  model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
32
31
  model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
33
32
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -87,7 +86,7 @@ model_compression_toolkit/core/common/network_editors/node_filters.py,sha256=Pc_
87
86
  model_compression_toolkit/core/common/pruning/__init__.py,sha256=DGJybkDQtKMSMFoZ-nZ3ZifA8uJ6G_D20wHhKHNlmU0,699
88
87
  model_compression_toolkit/core/common/pruning/channels_grouping.py,sha256=-zrq0TsfVE4ooxOcJCsL8H2DBau6vSkEKz1ot-x-Faw,3736
89
88
  model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py,sha256=UZekrges7gZv17JFLX_AV2Kv0eBXXarMNInzuOTlyvA,7712
90
- model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=bQtkMuxm9RajIztN4m88ZT0zCeN_bRcm4H2VGqE36lg,18944
89
+ model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=7IdX9APlS17omWeF_y0ZFr8S2K7Cp-M0u8rFtfmrx8U,19023
91
90
  model_compression_toolkit/core/common/pruning/prune_graph.py,sha256=eGvuqrxyADRSvhKz0R_7lLfIl7bKnn4bryElu3LsVcA,3158
92
91
  model_compression_toolkit/core/common/pruning/pruner.py,sha256=Zl0IK0anorzagaSP8qXMN31Dtw5m-Le-JRy2baPLs6M,7262
93
92
  model_compression_toolkit/core/common/pruning/pruning_config.py,sha256=fbqERt11FGVeuqPVA6nVbgGDh6Ox9mpEKdxVJT8eG4I,3681
@@ -107,8 +106,8 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
107
106
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
108
107
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
109
108
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=FyYCYbfkAofEWO2mAvFIppPeq2I10f1ScPNiVa9F7x4,7687
110
- model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=fj1ebZgnK6xH-9LIAu93rOEU7siXK86U_VyAtUwu9nA,24869
111
- model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=UkSVW7d1OF_Px9gAjsqqK65aYhIBFWaBO-_IH6_AFfg,4403
109
+ model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=VRY_HxZl9D77eq6oJ61eBnL4QJTG5pVLzJtAcaZATRQ,21636
110
+ model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=EMhXeY1qkvwlMAY5fpKRDuLEAyRY5yKqe2fOmAD_rVI,4362
112
111
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=VVq2cKjumlNWucUbaNw8s2J0IbI_vrQ-KR_eQPshGSg,3140
113
112
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
114
113
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=WJ-lsT_R_pqjbrMzgcposugACDNz7yZ09vSlltTb78A,3001
@@ -118,10 +117,10 @@ model_compression_toolkit/core/common/quantization/quantization_params_generatio
118
117
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
119
118
  model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
120
119
  model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=-cghHF5S11qbjTDRruHlc__uaDoofZHl7QTl8hCeKW0,11141
121
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=3EAbtLHOgTJIMbGlfAzeki7xxjipAsMyAaVRFXqF228,7243
122
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=27We8-tLL0dkDPYSDlhXe6ZKSO-kw2s5sD4q9I_ADmE,8401
120
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=vOFquQ3_h0jyG8EEeHM8M57uIKMZQ7iobUgbvWbiXh4,7798
121
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=TPSkKujdJ6jB7VOVp0kRgMnthmeuaBZgN1HQuJ2pqR0,7951
123
122
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=Nv_b3DECVjQnlrUet2kbuSvSKVnxcc-gf2zhFb2jSZk,43482
124
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=jb9Q2WgjmMc6i8j3TXr850tWCdI0a8598bkTmMYfdAY,4529
123
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=saOQMtj1qYqgQoAFjq31p7xtiRDxmanGGdm0DE81_cg,4820
125
124
  model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=6tRNgWvn-4r8hiSHqND7Qms1Nje1DUR4MR0JeWCNyvI,12531
126
125
  model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=xiZgCkoIrJ9xsR17x9pSl_sUbiuSta67kf7bQ4quFUI,10804
127
126
  model_compression_toolkit/core/common/quantization/quantizers/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
@@ -130,22 +129,22 @@ model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers
130
129
  model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py,sha256=wXExWHf5-0He7L4bpvFpKlx7FG4u3DAfNZiXPpOs_SQ,5521
131
130
  model_compression_toolkit/core/common/statistics_correction/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
132
131
  model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py,sha256=oUa1Gv9jIICOoFljTiIaItFjJQPht7CBe-wEr3iBuLQ,4118
133
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=eGd0gaPz1K9tzfQf1UMBeshoydFwwZ4Ha2JKFCJ2eZc,4474
134
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=w9VkX0_XyE64zaYJrZqGEtVxaox7MwY-c8Ie1C0f6ZU,5093
132
+ model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=fvvIH5uWtQbmPJQh_zjE-XuCFP9YrlOEs3mW9ysKhms,3652
133
+ model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=7HIoey1tqbTy1tnTF_KF5D-dSNankamBj8kWO9kUQYo,5200
135
134
  model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=289b2iwzp2hjsgpEZotQKNB2aPKjAZopRaGnbzErHV8,9263
136
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=08k7sqOLIya7Vvg2WMFdaSzLJ2FsgQlcKk0H_KoFoUg,10068
137
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=yB5Kxk74RAzcXxguFRVpvjFSWFrGrqL3JoU1qLst4PQ,5881
135
+ model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=p4u50-1CPuGTQtMDz5uaZqFSLhl2BLTbCLDQvplCuW4,9552
136
+ model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=F0xOl97p8Q0h4MXxOKo0kGY4iJIiJou06JbICatg8vQ,5881
138
137
  model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
139
138
  model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
140
139
  model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
141
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=Qe-MYKL2GRQ3PX1Q-zpws5mEW3vrs2h19kjiUZTkKwI,8327
140
+ model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=xa-6ACozqTu4xWXkC-7p6YTZLcLcH_AUYEROTl8USWc,8014
142
141
  model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=eCbhbAzgXWoVymMLbrupJ1qAcdhZDwkjKeja0fCymnY,9746
143
142
  model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
144
143
  model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
145
144
  model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=TKU1TIU52UIkVnl0EZvWnDhLV9nIVZ4hqi-w1i4NXMk,2637
146
145
  model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=N82mso5j3EJQlKt9EMHjjEJ67FmdGQeCfN8U5grOFXo,4830
147
146
  model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=2_NmmBmUBZZwXuF5Od2S919_FgQKYIf-nSyNPawr0e4,9840
148
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=Q9dQPLIKVtCp23yj-BmQmYkH94OBvAfV-19CYgqWSw0,32572
147
+ model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=_TliV1tBxizPsCZVbOhDhg3mSc_m0sY9gVqQMYTwQoc,32739
149
148
  model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
150
149
  model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=cokiYPZB7504oHTlgZy8u2Xv_S-RK_oDSnGvYRX3JK4,4136
151
150
  model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=vafrJ6eA37PrIzOs7uOsiJKIBmAVmNJ-wXsoe332BIw,4683
@@ -157,9 +156,8 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
157
156
  model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
158
157
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
159
158
  model_compression_toolkit/core/keras/data_util.py,sha256=jm54o-SlI1DJ-sEvRuX9OyLN68tEt0VxcqrdIjR98Ag,8366
160
- model_compression_toolkit/core/keras/default_framework_info.py,sha256=YhPSp153YcESp1Ho3GyvoEmxf2CpY9rjTnHAfN7Cpns,6175
161
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=x5EOYBrg2chC9-OUlrd0laLpnnHCFhYYAFNKRhVh6aQ,28526
162
- model_compression_toolkit/core/keras/keras_model_validation.py,sha256=dMS9cqaYmliyzVu2-MrKx4AIubqz3HW3RY4if2TV6U8,1581
159
+ model_compression_toolkit/core/keras/default_framework_info.py,sha256=wI-M_MsVIL0qOdBr8F-oyrRz4qdc2o6DDB7lGlizddw,6171
160
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=5oAc0nc21Uzfo94rKUGGpbYVTmEiKA3YGNhbOwD66ng,28550
163
161
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=k9cwu3S-OUGFaOHxH6cyYS2JjxAYHfBddz0laf6Quds,3311
164
162
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=xxZlHyruhLuP2iEgMrZhq_AyAGORTqzweVLARFfpaRw,5643
165
163
  model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=jzD8FGEEa8ZD7w8IpTRdp-Udf1MwOTgjg2XTS1Givic,2696
@@ -179,7 +177,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm
179
177
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
180
178
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py,sha256=vZr8Agj-tFKSX7TM2nZjwbHElJqSIyMAaR7FH-lp4YM,11691
181
179
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=nJO-JUmOK1lLb560KMJgLFwY2IOI2Y3lpzUq7o2f7mQ,5707
182
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=OwHoCLA-upKUnRpyVWrO_E6QmZcxk6-pOKNpiI7kYzI,6044
180
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=R3minGoiZqhjpex52LmBPGvALez9KKiuZOQzqOB7OK0,6558
183
181
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=AvquvVVVT8-ioeVn-gjqysK4L41L3I7TlNOEDfWjViY,8185
184
182
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=9MZJp4GNTLesWN5uQ5eOQyAHLzLYDAHAjRi-LpNppSc,4257
185
183
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=l9PUREBf4aRwWILiybdteveeUbh7js-i-hLt8Ma0e4c,26771
@@ -201,7 +199,7 @@ model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1
201
199
  model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=1p2DlMRmgzBAOUP-NeOzldTemjNLQQ3uf1Rov5iY-l8,5430
202
200
  model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=GtW0yars8PzqP9uL_vfXrtqHwKiStmOxPng20rYaIjU,6805
203
201
  model_compression_toolkit/core/keras/pruning/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
204
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=gqlssgSMN3TUzHD_Ple02m6rJHfcW9KpF2ZdTKlH4JM,11312
202
+ model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=tE2-9KqDE2C9Ahs0FzwexCXfO7vviRCLgegn-l5eOTA,11528
205
203
  model_compression_toolkit/core/keras/quantization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
206
204
  model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py,sha256=RtQk5r-bZxUs10AFaJ813_rpkDmOwzWPv6zK6LbX4_8,1876
207
205
  model_compression_toolkit/core/keras/quantization/fake_quant_builder.py,sha256=vfKwU0AfRH2KztmMF5bxcaBlGdnTePPGZsUqOHzED-U,6854
@@ -223,7 +221,7 @@ model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG
223
221
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
224
222
  model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3KqhVz5ORnA46S1uq9Tj5qhtHc,2828
225
223
  model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
226
- model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=pDUE-rwMhm1V1Y19_gwuZDfDCwKAu1ypBvU6XdURVjQ,4308
224
+ model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=HB9JjOBxa0JOLy2bOhDi2Jp2srK8M20g74jOmnUp54I,4305
227
225
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
228
226
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=cUQOBGwtG_DWpkrUEOcYSwXtNSmQgYVBCTxTpFiF4mo,27213
229
227
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=5hsp0nl6TewfrKsT133m9Z7DVpTFFftEv6DeZoryDZw,3009
@@ -272,7 +270,7 @@ model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcY
272
270
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=MTH7WsTpP-cTeMwaqrJPnhV_XdFKO6bySNalTONmr0w,4991
273
271
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KDnwmbhvhJMfNg1IuTvvzBNEriPQH9bL9dJ5VvWTzpE,6631
274
272
  model_compression_toolkit/core/pytorch/pruning/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
275
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=axcG6BKC8gALjjrgOFpiB8b1VbySUyXZIHmzRxQYDoc,13085
273
+ model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=jO-XbB6-ITnkBWMS4GO_s1-V3fhwfe4c183bMfucAsg,13367
276
274
  model_compression_toolkit/core/pytorch/quantization/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
277
275
  model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py,sha256=arslrOgJ1l-fScDlp6jNJ-JukKh0uBLcxAzjpDWRw94,1878
278
276
  model_compression_toolkit/core/pytorch/quantization/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
@@ -368,7 +366,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
368
366
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
369
367
  model_compression_toolkit/gptq/keras/gptq_training.py,sha256=_QwytOg1RQSg5Gvme089EME4trdTKGKM3JHIgT-b3n0,22841
370
368
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=xpjEqiDo4mqW42QGdmyW31n5eWd6HbYyP6EbarN-A8A,4283
371
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=kA6omL9PoW1hAS2WHN2QoRR1pg2FZQTSyP3qMjDFEJ4,18647
369
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=TBMfJAg_yERdd-OGHFzpV0v435QKM3FvIniwKHjw5i4,18493
372
370
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
373
371
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
374
372
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -403,14 +401,14 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=YxRtJGzD6SjZ4
403
401
  model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
404
402
  model_compression_toolkit/ptq/runner.py,sha256=1tVx3Yj5X4ZjTH0REm6fuAmv4QZ4u_vixLsgjBwBzxc,2326
405
403
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
406
- model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=_Do07apQ091WCOnVkgJcvnOX812AtXlW0HWx6q3SeRE,11587
404
+ model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=QieIYkoVtxWgHILGYeMHvSeIZ3mdbWdTR76MuUJ08-I,11433
407
405
  model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
408
406
  model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=RruQVxS4ylBjSH1KMh8ZCV8jk3OvtSrQl24m3Q4xs_8,10065
409
407
  model_compression_toolkit/qat/__init__.py,sha256=AaC4KBha4jDW_tyg2SOxZaKh_idIz0gZtDK3_zxs64E,1241
410
408
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
411
409
  model_compression_toolkit/qat/common/qat_config.py,sha256=QNXj2OcKIJOGvGEGzR2GCifI5Ho7FS7zFc2fkj6PJAc,2750
412
410
  model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
413
- model_compression_toolkit/qat/keras/quantization_facade.py,sha256=V3-hAO9olSrLCDVezmH1WI8sLrg7q9OrPribL6wn7vI,17429
411
+ model_compression_toolkit/qat/keras/quantization_facade.py,sha256=kwPnkI8nHrUK650-MyJcLweOj7Q7IbFMbmMpP_lRwTg,17275
414
412
  model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
415
413
  model_compression_toolkit/qat/keras/quantizer/base_keras_qat_weight_quantizer.py,sha256=EbIt4lMlh6cU4awFLMBp0IlZ2zUUp-WtnlW5Wn19FDM,1793
416
414
  model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
@@ -473,16 +471,16 @@ model_compression_toolkit/trainable_infrastructure/common/__init__.py,sha256=huH
473
471
  model_compression_toolkit/trainable_infrastructure/common/annealing_schedulers.py,sha256=qm2_wa61nga08Jdcl3RkgTsJ0zyHNjZ_A6I2--oVOig,2455
474
472
  model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py,sha256=IF50ASBUvVrOVqlJ1nHNxZxKXSuCanjhUX0YjMB-rRg,7946
475
473
  model_compression_toolkit/trainable_infrastructure/common/constants.py,sha256=HN120boJxAnEXNrLSj-o_s-VX4o6C-1ap_KZ4840sd0,875
476
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py,sha256=Jxd4IjS_t0FwnA_S_WmZeVbh4VM6Da9ahKGPLp6ZhQo,6983
474
+ model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py,sha256=Wv333zLZH0z5N2Uio0beafrUVV-zZcxZvvSKr-fDChY,6839
477
475
  model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py,sha256=10edXuhu6F00EcMU7M29AlK7rF_uoLQjMjctrWqK5KU,3346
478
476
  model_compression_toolkit/trainable_infrastructure/common/quant_utils.py,sha256=zdiew1jwR7tUKm9XWlHnAPxIZsAdKqbzzC2vH02j5wA,1505
479
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py,sha256=UXeQpLKYus1BuAc6xKkDMq2iLQUR45s6ATJBa7z4el0,4736
477
+ model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py,sha256=VLj3YWKJi6MMbEcPq7dDd4DMsXorM6F4Zi7tlFB4YL4,4394
480
478
  model_compression_toolkit/trainable_infrastructure/common/training_method.py,sha256=LUoeJkloowhZKuHTiOfzjmSUn2G-4of11-rbnL-h0P4,1194
481
479
  model_compression_toolkit/trainable_infrastructure/common/util.py,sha256=oKuWi7E07a8zv5x9auhBugYE2RUQ7ojDh2XCs5koYJY,1090
482
480
  model_compression_toolkit/trainable_infrastructure/keras/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
483
481
  model_compression_toolkit/trainable_infrastructure/keras/annealing_schedulers.py,sha256=sISNVxPsdm-Nd95PhoPSJ-2tFpINGlfrU7ZXaCByI-o,1278
484
482
  model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py,sha256=LBc26z8pkpbcdKMTxpNBg5IyChLreHQ1lRgCVjNE37o,4202
485
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py,sha256=zgGP7G5jXYFe7_hKw9jC2K0bnknKF3LiXpPpBtx-tVM,4304
483
+ model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py,sha256=rsyY7mTev5qRI0ncxuBilgYiRMd5GcuZt6qFyADU4H4,4123
486
484
  model_compression_toolkit/trainable_infrastructure/keras/load_model.py,sha256=DJHibcLo-UCuHV6UPLeVd7dKmPfkGXEiLqCCqvQrISM,3769
487
485
  model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha256=eVB5FSE3OmTLrhfLUcP2knwN1z2_unQLM-xFEGwdafA,5587
488
486
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=r3CaPd4pyM1GDXU2--9NT3wwvl9H6y3QUrVT9spx5es,4189
@@ -532,7 +530,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
532
530
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=Y0oBl8qPFsdNrK49XczwmVacInJcOPHslVnFBs-iTCc,3742
533
531
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
534
532
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=n0HvWBzkBkUJZlS3WeynhpsRTps2qQkjlq7luliBHNU,9627
535
- mct_nightly-2.4.0.20250706.701.dist-info/METADATA,sha256=xTdPfpQvxs-8MOaZMaXgKFzbpdzM6z9JhMkBbkKQRzE,25555
536
- mct_nightly-2.4.0.20250706.701.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
537
- mct_nightly-2.4.0.20250706.701.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
538
- mct_nightly-2.4.0.20250706.701.dist-info/RECORD,,
533
+ mct_nightly-2.4.0.20250708.612.dist-info/METADATA,sha256=DXcxuSC6OdHkGc28g74YdpfdWU7tRPQY-SPtmkfmJvw,25555
534
+ mct_nightly-2.4.0.20250708.612.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
535
+ mct_nightly-2.4.0.20250708.612.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
536
+ mct_nightly-2.4.0.20250708.612.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.4.0.20250706.000701"
30
+ __version__ = "2.4.0.20250708.000612"
@@ -13,11 +13,12 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from abc import ABC, abstractmethod
16
17
  import numpy as np
17
18
  from model_compression_toolkit.logger import Logger
18
19
 
19
20
 
20
- class BaseCollector(object):
21
+ class BaseCollector(ABC):
21
22
  """
22
23
  Base class for statistics collection object.
23
24
  """
@@ -26,6 +27,7 @@ class BaseCollector(object):
26
27
  # When manipulation statistics in a granularity they were not collected by, the data is invalid.
27
28
  self.is_legal = True
28
29
 
30
+ @abstractmethod
29
31
  def scale(self, scale_factor: np.ndarray):
30
32
  """
31
33
  Scale all statistics in collector by some factor.
@@ -37,6 +39,7 @@ class BaseCollector(object):
37
39
  raise NotImplemented(
38
40
  f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover
39
41
 
42
+ @abstractmethod
40
43
  def shift(self, shift_value: np.ndarray):
41
44
  """
42
45
  Shift all statistics in collector by some value.
@@ -87,10 +87,13 @@ class MeanCollector(BaseCollector):
87
87
  x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
88
88
  """
89
89
  self.i += 1 # Update the iteration index
90
- axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
91
- n = x.shape[axis]
92
- transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
93
- mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
90
+ if self.axis is None:
91
+ mu = np.mean(np.reshape(x, [1, -1]), axis=-1) # mean per channel for a batch
92
+ else:
93
+ axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
94
+ n = x.shape[axis]
95
+ transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
96
+ mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
94
97
  self.current_sum += mu # sum of all batches
95
98
  self.current_mean = self.current_sum / self.i # mean of all batches
96
99
 
@@ -130,10 +130,13 @@ class MinMaxPerChannelCollector(BaseCollector):
130
130
  x: Tensor that goes through the collector and needs to be considered in the min/max computation.
131
131
  """
132
132
 
133
- axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
134
- n = x.shape[axis]
135
- transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
136
- x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
133
+ if self.axis is None:
134
+ x_reshape = np.reshape(x, [1, -1])
135
+ else:
136
+ axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
137
+ n = x.shape[axis]
138
+ transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
139
+ x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
137
140
  if self.state is None:
138
141
  x_max = np.max(x_reshape, axis=-1)
139
142
  x_min = np.min(x_reshape, axis=-1)
@@ -57,19 +57,21 @@ def create_stats_collector_for_node(node: common.BaseNode,
57
57
 
58
58
 
59
59
  def create_tensor2node(graph: common.Graph,
60
- node: common.BaseNode):
60
+ node: common.BaseNode,
61
+ next_node_output_channel_axis: int):
61
62
  """
62
63
  Force statistic collector creation and assignment for a node.
63
64
  Args:
64
65
  graph: Graph of the node (for retrieving the current tensor).
65
66
  node: Node to create a tensor for.
67
+ next_node_output_channel_axis: channel output axis of next node.
66
68
 
67
69
  """
68
70
  current_sc = graph.get_out_stats_collector(node)
69
71
  is_list_nostat_collectors = isinstance(current_sc, list) and len(
70
72
  [sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
71
73
  if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
72
- stats_collector = common.StatsCollector(node.out_channel_axis)
74
+ stats_collector = common.StatsCollector(next_node_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis)
73
75
  graph.set_out_stats_collector_to_node(node, stats_collector)
74
76
 
75
77
 
@@ -157,6 +159,17 @@ class ModelCollector:
157
159
  for n in graph.get_topo_sorted_nodes():
158
160
  quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
159
161
  sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
162
+ if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None):
163
+ # Missing output channel axis info, so try to extract it from previous and next nodes output channel axis.
164
+ possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)}
165
+ # Filter out None values.
166
+ possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set))
167
+ if len(possible_output_channel_axis_list) > 0:
168
+ if len(possible_output_channel_axis_list) > 1:
169
+ Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.')
170
+ sc.mc.axis = possible_output_channel_axis_list[0]
171
+ sc.mpcc.axis = possible_output_channel_axis_list[0]
172
+
160
173
  # If we use bias correction, and the node has kernel weights to quantize, we need to make sure
161
174
  # its previous nodes' tensors are consistent with this node.
162
175
  if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(
@@ -164,7 +177,8 @@ class ModelCollector:
164
177
  for ie in graph.incoming_edges(n):
165
178
  input_node = ie.source_node
166
179
  create_tensor2node(graph,
167
- input_node)
180
+ input_node,
181
+ n.out_channel_axis)
168
182
  if sc is not None:
169
183
  graph.set_out_stats_collector_to_node(n, sc)
170
184
 
@@ -303,7 +303,7 @@ class MemoryCalculator:
303
303
  num_oc = np.sum(output_mask)
304
304
  else:
305
305
  # Get the node channel axis from framework info
306
- channel_axis = node.out_channel_axis
306
+ channel_axis = self.fw_impl.default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
307
307
  if channel_axis is None:
308
308
  Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")
309
309
 
@@ -18,7 +18,6 @@ from enum import Enum, auto
18
18
  from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
19
19
  from model_compression_toolkit.logger import Logger
20
20
 
21
- from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
22
21
  from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
23
22
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
24
23
  AttributeQuantizationConfig, OpQuantizationConfig
@@ -41,6 +40,7 @@ class ActivationQuantizationMode(Enum):
41
40
  NO_QUANT = auto()
42
41
  FLN_NO_QUANT = auto()
43
42
 
43
+
44
44
  class BaseNodeQuantizationConfig(object):
45
45
  """
46
46
  Base class for node quantization configuration
@@ -59,12 +59,11 @@ class BaseNodeQuantizationConfig(object):
59
59
  kwargs: A dictionary with additional key arguments.
60
60
 
61
61
  """
62
-
63
62
  if hasattr(self, config_parameter_name):
64
63
  setattr(self, config_parameter_name, config_parameter_value)
65
64
  else:
66
- Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
67
- f"was not updated!")
65
+ raise AttributeError(
66
+ f"Parameter {config_parameter_name} could not be found in the node quantization config.")
68
67
 
69
68
  def __repr__(self) -> str:
70
69
  """
@@ -97,36 +96,11 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
97
96
  self.signedness = op_cfg.signedness
98
97
 
99
98
  self.activation_quantization_params = {}
100
- # TODO irena: computed by compute_activation_bias_correction. shouldnt really be here
99
+ # TODO: computed by compute_activation_bias_correction. Probably shouldnt be here.
101
100
  self.activation_bias_correction_term = None
102
-
103
- # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
104
- self.activation_error_method = None
105
- self.relu_bound_to_power_of_2 = None
106
- self.activation_channel_equalization = None
107
- self.input_scaling = None
108
- self.min_threshold = None
109
- self.l_p_value = None
110
- self.shift_negative_activation_correction = None
101
+ # Z-threshold is a global param from QuantizationConfig, however it can be overridden per node by NetworkEditor.
102
+ # Since activation qparams are re-computed in several places, it's easier to keep it here and update it once.
111
103
  self.z_threshold = None
112
- self.shift_negative_ratio = None
113
- self.shift_negative_threshold_recalculation = None
114
- self.concat_threshold_update = None
115
-
116
- def set_qc(self, qc: QuantizationConfig):
117
- """ TODO irena: temporary keep all the attributes as before not to break all code at once.
118
- Eventually all of them should be removed from here. """
119
- self.activation_error_method = qc.activation_error_method
120
- self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
121
- self.activation_channel_equalization = qc.activation_channel_equalization
122
- self.input_scaling = qc.input_scaling
123
- self.min_threshold = qc.min_threshold
124
- self.l_p_value = qc.l_p_value
125
- self.shift_negative_activation_correction = qc.shift_negative_activation_correction
126
- self.z_threshold = qc.z_threshold
127
- self.shift_negative_ratio = qc.shift_negative_ratio
128
- self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
129
- self.concat_threshold_update = qc.concat_threshold_update
130
104
 
131
105
  @property
132
106
  def enable_activation_quantization(self):
@@ -148,7 +122,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
148
122
  activation_params: Dictionary that contains weight quantization params.
149
123
 
150
124
  """
151
- assert self.quant_mode == ActivationQuantizationMode.QUANT
125
+ assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT
152
126
  for param_name, param_value in activation_params.items():
153
127
  self.activation_quantization_params[param_name] = param_value
154
128
 
@@ -165,32 +139,16 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
165
139
  if not isinstance(other, NodeActivationQuantizationConfig):
166
140
  return False # pragma: no cover
167
141
 
168
- return self.activation_error_method == other.activation_error_method and \
169
- self.activation_quantization_method == other.activation_quantization_method and \
142
+ return self.activation_quantization_method == other.activation_quantization_method and \
170
143
  self.activation_n_bits == other.activation_n_bits and \
171
144
  self.quant_mode == other.quant_mode and \
172
- self.activation_channel_equalization == other.activation_channel_equalization and \
173
- self.input_scaling == other.input_scaling and \
174
- self.min_threshold == other.min_threshold and \
175
- self.l_p_value == other.l_p_value and \
176
- self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
177
- self.z_threshold == other.z_threshold and \
178
- self.shift_negative_ratio == other.shift_negative_ratio and \
179
- self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
145
+ self.signedness == other.signedness
180
146
 
181
147
  def __hash__(self):
182
- return hash((self.activation_error_method,
183
- self.activation_quantization_method,
148
+ return hash((self.activation_quantization_method,
184
149
  self.activation_n_bits,
185
150
  self.quant_mode,
186
- self.activation_channel_equalization,
187
- self.input_scaling,
188
- self.min_threshold,
189
- self.l_p_value,
190
- self.shift_negative_activation_correction,
191
- self.z_threshold,
192
- self.shift_negative_ratio,
193
- self.shift_negative_threshold_recalculation))
151
+ self.signedness))
194
152
 
195
153
 
196
154
  class WeightsAttrQuantizationConfig:
@@ -211,16 +169,8 @@ class WeightsAttrQuantizationConfig:
211
169
  self.weights_n_bits = weights_attr_cfg.weights_n_bits
212
170
  self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
213
171
  self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
214
- self.weights_quantization_params = {}
215
172
 
216
- # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
217
- self.weights_error_method = None
218
- self.l_p_value = None
219
-
220
- def set_qc(self, qc: QuantizationConfig):
221
- # TODO irena: temporary keep the fields to not break everything at once.
222
- self.weights_error_method = qc.weights_error_method
223
- self.l_p_value = qc.l_p_value
173
+ self.weights_quantization_params = {}
224
174
 
225
175
  def set_weights_quantization_param(self,
226
176
  weights_params: dict):
@@ -252,18 +202,14 @@ class WeightsAttrQuantizationConfig:
252
202
  self.weights_quantization_method == other.weights_quantization_method and \
253
203
  self.weights_n_bits == other.weights_n_bits and \
254
204
  self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
255
- self.enable_weights_quantization == other.enable_weights_quantization and \
256
- self.weights_error_method == other.weights_error_method and \
257
- self.l_p_value == other.l_p_value
205
+ self.enable_weights_quantization == other.enable_weights_quantization
258
206
 
259
207
  def __hash__(self):
260
208
  return hash((self.weights_channels_axis,
261
- self.weights_error_method,
262
209
  self.weights_quantization_method,
263
210
  self.weights_n_bits,
264
211
  self.weights_per_channel_threshold,
265
- self.enable_weights_quantization,
266
- self.l_p_value))
212
+ self.enable_weights_quantization))
267
213
 
268
214
 
269
215
  class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
@@ -330,16 +276,14 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
330
276
 
331
277
  self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
332
278
  weights_channels_axis=weights_channels_axis)
333
- # TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
334
- self.min_threshold = None
279
+ # TODO this is set by batch norm reconstruction substitution when folded batch norms are added back, to mark
280
+ # the nodes that the correction should be applied to (for some nodes it gets disabled) and BNs removed.
281
+ # The actual correction is only computed when it's applied in ptq, so it seems that both substitutions could
282
+ # be unified, and no info need to pass between.
335
283
  self.weights_second_moment_correction = None
336
- self.weights_bias_correction = None
337
-
338
- def set_qc(self, qc: QuantizationConfig):
339
- # TODO irena: temporary keep the fields to not break everything at once.
340
- self.min_threshold = qc.min_threshold
341
- self.weights_second_moment_correction = qc.weights_second_moment_correction
342
- self.weights_bias_correction = qc.weights_bias_correction
284
+ # TODO: computed corrected bias is injected to the node config. Probably shouldn't be here. Also it can be
285
+ # computed on the final config, instead of all candidates and then there is no need to save it at all.
286
+ self.bias_corrected = None
343
287
 
344
288
  def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
345
289
  """
@@ -476,8 +420,8 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
476
420
  if hasattr(attr_cfg, config_parameter_name):
477
421
  setattr(attr_cfg, config_parameter_name, config_parameter_value)
478
422
  else:
479
- Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
480
- f"weights attribute {attr_name} and was not updated!")
423
+ raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
424
+ f"weights attribute {attr_name}.")
481
425
  else: # pragma: no cover
482
426
  Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
483
427
 
@@ -494,10 +438,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
494
438
  if not isinstance(other, NodeWeightsQuantizationConfig):
495
439
  return False # pragma: no cover
496
440
 
497
- return self.min_threshold == other.min_threshold and \
498
- self.simd_size == other.simd_size and \
499
- self.weights_second_moment_correction == other.weights_second_moment_correction and \
500
- self.weights_bias_correction == other.weights_bias_correction and \
441
+ return self.simd_size == other.simd_size and \
501
442
  self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
502
443
  all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
503
444
  for k in self.attributes_config_mapping.keys()]) and \
@@ -506,9 +447,6 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
506
447
  for k in self.pos_attributes_config_mapping.keys()])
507
448
 
508
449
  def __hash__(self):
509
- return hash((self.min_threshold,
510
- self.simd_size,
511
- self.weights_second_moment_correction,
512
- self.weights_bias_correction,
450
+ return hash((self.simd_size,
513
451
  frozenset(self.attributes_config_mapping),
514
452
  frozenset(self.pos_attributes_config_mapping)))
@@ -90,7 +90,6 @@ class QuantizationConfig:
90
90
  shift_negative_activation_correction: bool = True
91
91
  activation_channel_equalization: bool = False
92
92
  z_threshold: float = math.inf
93
- min_threshold: float = MIN_THRESHOLD
94
93
  l_p_value: int = 2
95
94
  linear_collapsing: bool = True
96
95
  residual_collapsing: bool = True