mct-nightly 2.4.0.20250706.701__py3-none-any.whl → 2.4.0.20250708.612__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/RECORD +36 -38
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +4 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +7 -4
- model_compression_toolkit/core/common/model_collector.py +17 -3
- model_compression_toolkit/core/common/pruning/memory_calculator.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +25 -87
- model_compression_toolkit/core/common/quantization/quantization_config.py +0 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +26 -17
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +27 -49
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +12 -7
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -14
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -1
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -13
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +5 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +7 -5
- model_compression_toolkit/core/graph_prep_runner.py +1 -11
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +21 -11
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +8 -0
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +9 -1
- model_compression_toolkit/core/quantization_prep_runner.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +0 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +0 -3
- model_compression_toolkit/qat/keras/quantization_facade.py +0 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +0 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +0 -6
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +2 -4
- model_compression_toolkit/core/common/model_validation.py +0 -41
- model_compression_toolkit/core/keras/keras_model_validation.py +0 -37
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/top_level.txt +0 -0
{mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250708.612.dist-info}/RECORD
RENAMED
```diff
@@ -1,5 +1,5 @@
-mct_nightly-2.4.0.
-model_compression_toolkit/__init__.py,sha256=
+mct_nightly-2.4.0.20250708.612.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+model_compression_toolkit/__init__.py,sha256=TrwEn4n4YKNNqTg96Ud45uv2qh_Ob-Q1IzRMrsqjJG4,1557
 model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -7,8 +7,8 @@ model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_A
 model_compression_toolkit/verify_packages.py,sha256=l0neIRr8q_QwxmuiTI4vyCMDISDedK0EihjEQUe66tE,1319
 model_compression_toolkit/core/__init__.py,sha256=HNverPpoqEyFKTa7iEdOqqY2P0Gq-7GMejNOi6ZPcQs,2042
 model_compression_toolkit/core/analyzer.py,sha256=5P03LbkFy-mu31TMAiQoIKcsA1-DNz7cTzkGvRaXtbw,3505
-model_compression_toolkit/core/graph_prep_runner.py,sha256=
-model_compression_toolkit/core/quantization_prep_runner.py,sha256=
+model_compression_toolkit/core/graph_prep_runner.py,sha256=XvuR1lNsbaGG3HyROA8nF2n0oPNb1Tnw6rGatewAvE0,10563
+model_compression_toolkit/core/quantization_prep_runner.py,sha256=W6CAJ_Euxrjm5jj__QqfB4xsD8NYx_-XWX22h0ax6B4,6262
 model_compression_toolkit/core/runner.py,sha256=QpiJQmQXK6mWmnygNRdy6I8S48DHV-B0Kmr4TqOKbeA,12418
 model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
 model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
@@ -16,18 +16,17 @@ model_compression_toolkit/core/common/framework_implementation.py,sha256=jrTupZb
 model_compression_toolkit/core/common/framework_info.py,sha256=vPGV28gm-kvNSkkWI6jY3YeKBUYmn6UQ98HVUnl_-tM,5449
 model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
 model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
-model_compression_toolkit/core/common/model_collector.py,sha256=
-model_compression_toolkit/core/common/model_validation.py,sha256=HRnYh2uY85yJ7Ijmt4tKRn8bMg60zbBSDRCgK246gUM,1067
+model_compression_toolkit/core/common/model_collector.py,sha256=I0O5SoUrwB045AGLleOGYVvZyR7HJx_t6vfbECahfoU,13914
 model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
 model_compression_toolkit/core/common/similarity_analyzer.py,sha256=S3f6WgHyw62dGcxpX51FGKyfebe2zv9ABKbjtGyKRvY,9215
 model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
 model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=yrIxT0ttDi9XViy8Zt8apnMCT8xDyVd5HZp0IttrGGQ,1775
 model_compression_toolkit/core/common/collectors/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
-model_compression_toolkit/core/common/collectors/base_collector.py,sha256=
+model_compression_toolkit/core/common/collectors/base_collector.py,sha256=n3IIwfWVIl8TO_7MR9jBBxzf2foysWMyrGS-oaCuKiA,2664
 model_compression_toolkit/core/common/collectors/histogram_collector.py,sha256=zra5V06Brpjc1cUNIMVVGqdoqAuro62_hGy2Zm5-XMQ,6754
-model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=
-model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=
+model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=7-oCVU76e9ZWIZLiLte-z07FX2Rjmu5Go-KRj7b7oWY,3564
+model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=pAdiLXNNm5sfIZ584m2MzJQ0PTNp6MKsrMxJ_onEN1s,5314
 model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
 model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
 model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -87,7 +86,7 @@ model_compression_toolkit/core/common/network_editors/node_filters.py,sha256=Pc_
 model_compression_toolkit/core/common/pruning/__init__.py,sha256=DGJybkDQtKMSMFoZ-nZ3ZifA8uJ6G_D20wHhKHNlmU0,699
 model_compression_toolkit/core/common/pruning/channels_grouping.py,sha256=-zrq0TsfVE4ooxOcJCsL8H2DBau6vSkEKz1ot-x-Faw,3736
 model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py,sha256=UZekrges7gZv17JFLX_AV2Kv0eBXXarMNInzuOTlyvA,7712
-model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=
+model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=7IdX9APlS17omWeF_y0ZFr8S2K7Cp-M0u8rFtfmrx8U,19023
 model_compression_toolkit/core/common/pruning/prune_graph.py,sha256=eGvuqrxyADRSvhKz0R_7lLfIl7bKnn4bryElu3LsVcA,3158
 model_compression_toolkit/core/common/pruning/pruner.py,sha256=Zl0IK0anorzagaSP8qXMN31Dtw5m-Le-JRy2baPLs6M,7262
 model_compression_toolkit/core/common/pruning/pruning_config.py,sha256=fbqERt11FGVeuqPVA6nVbgGDh6Ox9mpEKdxVJT8eG4I,3681
@@ -107,8 +106,8 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
 model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
 model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
 model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=FyYCYbfkAofEWO2mAvFIppPeq2I10f1ScPNiVa9F7x4,7687
-model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
-model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=
+model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=VRY_HxZl9D77eq6oJ61eBnL4QJTG5pVLzJtAcaZATRQ,21636
+model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=EMhXeY1qkvwlMAY5fpKRDuLEAyRY5yKqe2fOmAD_rVI,4362
 model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=VVq2cKjumlNWucUbaNw8s2J0IbI_vrQ-KR_eQPshGSg,3140
 model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
 model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=WJ-lsT_R_pqjbrMzgcposugACDNz7yZ09vSlltTb78A,3001
@@ -118,10 +117,10 @@ model_compression_toolkit/core/common/quantization/quantization_params_generatio
 model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
 model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
 model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=-cghHF5S11qbjTDRruHlc__uaDoofZHl7QTl8hCeKW0,11141
-model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=
-model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=
+model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=vOFquQ3_h0jyG8EEeHM8M57uIKMZQ7iobUgbvWbiXh4,7798
+model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=TPSkKujdJ6jB7VOVp0kRgMnthmeuaBZgN1HQuJ2pqR0,7951
 model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=Nv_b3DECVjQnlrUet2kbuSvSKVnxcc-gf2zhFb2jSZk,43482
-model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=
+model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=saOQMtj1qYqgQoAFjq31p7xtiRDxmanGGdm0DE81_cg,4820
 model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=6tRNgWvn-4r8hiSHqND7Qms1Nje1DUR4MR0JeWCNyvI,12531
 model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=xiZgCkoIrJ9xsR17x9pSl_sUbiuSta67kf7bQ4quFUI,10804
 model_compression_toolkit/core/common/quantization/quantizers/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
@@ -130,22 +129,22 @@ model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers
 model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py,sha256=wXExWHf5-0He7L4bpvFpKlx7FG4u3DAfNZiXPpOs_SQ,5521
 model_compression_toolkit/core/common/statistics_correction/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
 model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py,sha256=oUa1Gv9jIICOoFljTiIaItFjJQPht7CBe-wEr3iBuLQ,4118
-model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=
-model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=
+model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=fvvIH5uWtQbmPJQh_zjE-XuCFP9YrlOEs3mW9ysKhms,3652
+model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=7HIoey1tqbTy1tnTF_KF5D-dSNankamBj8kWO9kUQYo,5200
 model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=289b2iwzp2hjsgpEZotQKNB2aPKjAZopRaGnbzErHV8,9263
-model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=
-model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=
+model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=p4u50-1CPuGTQtMDz5uaZqFSLhl2BLTbCLDQvplCuW4,9552
+model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=F0xOl97p8Q0h4MXxOKo0kGY4iJIiJou06JbICatg8vQ,5881
 model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
 model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
 model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
-model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=
+model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=xa-6ACozqTu4xWXkC-7p6YTZLcLcH_AUYEROTl8USWc,8014
 model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=eCbhbAzgXWoVymMLbrupJ1qAcdhZDwkjKeja0fCymnY,9746
 model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
 model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
 model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=TKU1TIU52UIkVnl0EZvWnDhLV9nIVZ4hqi-w1i4NXMk,2637
 model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=N82mso5j3EJQlKt9EMHjjEJ67FmdGQeCfN8U5grOFXo,4830
 model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=2_NmmBmUBZZwXuF5Od2S919_FgQKYIf-nSyNPawr0e4,9840
-model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=
+model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=_TliV1tBxizPsCZVbOhDhg3mSc_m0sY9gVqQMYTwQoc,32739
 model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
 model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=cokiYPZB7504oHTlgZy8u2Xv_S-RK_oDSnGvYRX3JK4,4136
 model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=vafrJ6eA37PrIzOs7uOsiJKIBmAVmNJ-wXsoe332BIw,4683
@@ -157,9 +156,8 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
 model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
 model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
 model_compression_toolkit/core/keras/data_util.py,sha256=jm54o-SlI1DJ-sEvRuX9OyLN68tEt0VxcqrdIjR98Ag,8366
-model_compression_toolkit/core/keras/default_framework_info.py,sha256=
-model_compression_toolkit/core/keras/keras_implementation.py,sha256=
-model_compression_toolkit/core/keras/keras_model_validation.py,sha256=dMS9cqaYmliyzVu2-MrKx4AIubqz3HW3RY4if2TV6U8,1581
+model_compression_toolkit/core/keras/default_framework_info.py,sha256=wI-M_MsVIL0qOdBr8F-oyrRz4qdc2o6DDB7lGlizddw,6171
+model_compression_toolkit/core/keras/keras_implementation.py,sha256=5oAc0nc21Uzfo94rKUGGpbYVTmEiKA3YGNhbOwD66ng,28550
 model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=k9cwu3S-OUGFaOHxH6cyYS2JjxAYHfBddz0laf6Quds,3311
 model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=xxZlHyruhLuP2iEgMrZhq_AyAGORTqzweVLARFfpaRw,5643
 model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=jzD8FGEEa8ZD7w8IpTRdp-Udf1MwOTgjg2XTS1Givic,2696
@@ -179,7 +177,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py,sha256=vZr8Agj-tFKSX7TM2nZjwbHElJqSIyMAaR7FH-lp4YM,11691
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=nJO-JUmOK1lLb560KMJgLFwY2IOI2Y3lpzUq7o2f7mQ,5707
-model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=
+model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=R3minGoiZqhjpex52LmBPGvALez9KKiuZOQzqOB7OK0,6558
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=AvquvVVVT8-ioeVn-gjqysK4L41L3I7TlNOEDfWjViY,8185
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=9MZJp4GNTLesWN5uQ5eOQyAHLzLYDAHAjRi-LpNppSc,4257
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=l9PUREBf4aRwWILiybdteveeUbh7js-i-hLt8Ma0e4c,26771
@@ -201,7 +199,7 @@ model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1
 model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=1p2DlMRmgzBAOUP-NeOzldTemjNLQQ3uf1Rov5iY-l8,5430
 model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=GtW0yars8PzqP9uL_vfXrtqHwKiStmOxPng20rYaIjU,6805
 model_compression_toolkit/core/keras/pruning/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
-model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=
+model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=tE2-9KqDE2C9Ahs0FzwexCXfO7vviRCLgegn-l5eOTA,11528
 model_compression_toolkit/core/keras/quantization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
 model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py,sha256=RtQk5r-bZxUs10AFaJ813_rpkDmOwzWPv6zK6LbX4_8,1876
 model_compression_toolkit/core/keras/quantization/fake_quant_builder.py,sha256=vfKwU0AfRH2KztmMF5bxcaBlGdnTePPGZsUqOHzED-U,6854
@@ -223,7 +221,7 @@ model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG
 model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3KqhVz5ORnA46S1uq9Tj5qhtHc,2828
 model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
-model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=
+model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=HB9JjOBxa0JOLy2bOhDi2Jp2srK8M20g74jOmnUp54I,4305
 model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
 model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=cUQOBGwtG_DWpkrUEOcYSwXtNSmQgYVBCTxTpFiF4mo,27213
 model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=5hsp0nl6TewfrKsT133m9Z7DVpTFFftEv6DeZoryDZw,3009
@@ -272,7 +270,7 @@ model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcY
 model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=MTH7WsTpP-cTeMwaqrJPnhV_XdFKO6bySNalTONmr0w,4991
 model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KDnwmbhvhJMfNg1IuTvvzBNEriPQH9bL9dJ5VvWTzpE,6631
 model_compression_toolkit/core/pytorch/pruning/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
-model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=
+model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=jO-XbB6-ITnkBWMS4GO_s1-V3fhwfe4c183bMfucAsg,13367
 model_compression_toolkit/core/pytorch/quantization/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py,sha256=arslrOgJ1l-fScDlp6jNJ-JukKh0uBLcxAzjpDWRw94,1878
 model_compression_toolkit/core/pytorch/quantization/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
@@ -368,7 +366,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
 model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
 model_compression_toolkit/gptq/keras/gptq_training.py,sha256=_QwytOg1RQSg5Gvme089EME4trdTKGKM3JHIgT-b3n0,22841
 model_compression_toolkit/gptq/keras/graph_info.py,sha256=xpjEqiDo4mqW42QGdmyW31n5eWd6HbYyP6EbarN-A8A,4283
-model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
+model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=TBMfJAg_yERdd-OGHFzpV0v435QKM3FvIniwKHjw5i4,18493
 model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
 model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
 model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -403,14 +401,14 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=YxRtJGzD6SjZ4
 model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
 model_compression_toolkit/ptq/runner.py,sha256=1tVx3Yj5X4ZjTH0REm6fuAmv4QZ4u_vixLsgjBwBzxc,2326
 model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
+model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=QieIYkoVtxWgHILGYeMHvSeIZ3mdbWdTR76MuUJ08-I,11433
 model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=RruQVxS4ylBjSH1KMh8ZCV8jk3OvtSrQl24m3Q4xs_8,10065
 model_compression_toolkit/qat/__init__.py,sha256=AaC4KBha4jDW_tyg2SOxZaKh_idIz0gZtDK3_zxs64E,1241
 model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
 model_compression_toolkit/qat/common/qat_config.py,sha256=QNXj2OcKIJOGvGEGzR2GCifI5Ho7FS7zFc2fkj6PJAc,2750
 model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/qat/keras/quantization_facade.py,sha256=
+model_compression_toolkit/qat/keras/quantization_facade.py,sha256=kwPnkI8nHrUK650-MyJcLweOj7Q7IbFMbmMpP_lRwTg,17275
 model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
 model_compression_toolkit/qat/keras/quantizer/base_keras_qat_weight_quantizer.py,sha256=EbIt4lMlh6cU4awFLMBp0IlZ2zUUp-WtnlW5Wn19FDM,1793
 model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
@@ -473,16 +471,16 @@ model_compression_toolkit/trainable_infrastructure/common/__init__.py,sha256=huH
 model_compression_toolkit/trainable_infrastructure/common/annealing_schedulers.py,sha256=qm2_wa61nga08Jdcl3RkgTsJ0zyHNjZ_A6I2--oVOig,2455
 model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py,sha256=IF50ASBUvVrOVqlJ1nHNxZxKXSuCanjhUX0YjMB-rRg,7946
 model_compression_toolkit/trainable_infrastructure/common/constants.py,sha256=HN120boJxAnEXNrLSj-o_s-VX4o6C-1ap_KZ4840sd0,875
-model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py,sha256=
+model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py,sha256=Wv333zLZH0z5N2Uio0beafrUVV-zZcxZvvSKr-fDChY,6839
 model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py,sha256=10edXuhu6F00EcMU7M29AlK7rF_uoLQjMjctrWqK5KU,3346
 model_compression_toolkit/trainable_infrastructure/common/quant_utils.py,sha256=zdiew1jwR7tUKm9XWlHnAPxIZsAdKqbzzC2vH02j5wA,1505
-model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py,sha256=
+model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py,sha256=VLj3YWKJi6MMbEcPq7dDd4DMsXorM6F4Zi7tlFB4YL4,4394
 model_compression_toolkit/trainable_infrastructure/common/training_method.py,sha256=LUoeJkloowhZKuHTiOfzjmSUn2G-4of11-rbnL-h0P4,1194
 model_compression_toolkit/trainable_infrastructure/common/util.py,sha256=oKuWi7E07a8zv5x9auhBugYE2RUQ7ojDh2XCs5koYJY,1090
 model_compression_toolkit/trainable_infrastructure/keras/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
 model_compression_toolkit/trainable_infrastructure/keras/annealing_schedulers.py,sha256=sISNVxPsdm-Nd95PhoPSJ-2tFpINGlfrU7ZXaCByI-o,1278
 model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py,sha256=LBc26z8pkpbcdKMTxpNBg5IyChLreHQ1lRgCVjNE37o,4202
-model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py,sha256=
+model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py,sha256=rsyY7mTev5qRI0ncxuBilgYiRMd5GcuZt6qFyADU4H4,4123
 model_compression_toolkit/trainable_infrastructure/keras/load_model.py,sha256=DJHibcLo-UCuHV6UPLeVd7dKmPfkGXEiLqCCqvQrISM,3769
 model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha256=eVB5FSE3OmTLrhfLUcP2knwN1z2_unQLM-xFEGwdafA,5587
 model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=r3CaPd4pyM1GDXU2--9NT3wwvl9H6y3QUrVT9spx5es,4189
@@ -532,7 +530,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
 model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=Y0oBl8qPFsdNrK49XczwmVacInJcOPHslVnFBs-iTCc,3742
 model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
 model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=n0HvWBzkBkUJZlS3WeynhpsRTps2qQkjlq7luliBHNU,9627
-mct_nightly-2.4.0.
-mct_nightly-2.4.0.
-mct_nightly-2.4.0.
-mct_nightly-2.4.0.
+mct_nightly-2.4.0.20250708.612.dist-info/METADATA,sha256=DXcxuSC6OdHkGc28g74YdpfdWU7tRPQY-SPtmkfmJvw,25555
+mct_nightly-2.4.0.20250708.612.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+mct_nightly-2.4.0.20250708.612.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-2.4.0.20250708.612.dist-info/RECORD,,
```
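Most of the churn in RECORD is mechanical: every entry has the form `path,sha256=<hash>,<size>`, where the hash is the urlsafe base64 of the file's SHA-256 digest with the `=` padding stripped (the standard wheel RECORD format). A minimal sketch, not part of the package, for reproducing one of the listed values against an installed file:

```python
# Minimal sketch: recompute a RECORD entry for an installed file and compare it
# to the corresponding line in the new RECORD above.
import base64
import hashlib
import os

def record_entry(path: str) -> str:
    data = open(path, "rb").read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{os.path.getsize(path)}"

# e.g. record_entry("model_compression_toolkit/__init__.py") should match the new RECORD line.
```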
model_compression_toolkit/__init__.py
CHANGED

```diff
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
-__version__ = "2.4.0.
+__version__ = "2.4.0.20250708.000612"
```
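The only change in `model_compression_toolkit/__init__.py` is the nightly version string. An illustrative check of which build is installed:

```python
# Illustrative only: confirm the installed nightly build.
import model_compression_toolkit as mct

print(mct.__version__)  # "2.4.0.20250708.000612" for the new wheel
```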
model_compression_toolkit/core/common/collectors/base_collector.py
CHANGED

```diff
@@ -13,11 +13,12 @@
 # limitations under the License.
 # ==============================================================================
 
+from abc import ABC, abstractmethod
 import numpy as np
 from model_compression_toolkit.logger import Logger
 
 
-class BaseCollector(
+class BaseCollector(ABC):
     """
     Base class for statistics collection object.
     """
@@ -26,6 +27,7 @@ class BaseCollector(object):
         # When manipulation statistics in a granularity they were not collected by, the data is invalid.
         self.is_legal = True
 
+    @abstractmethod
     def scale(self, scale_factor: np.ndarray):
         """
         Scale all statistics in collector by some factor.
@@ -37,6 +39,7 @@ class BaseCollector(object):
         raise NotImplemented(
             f'{self.__class__.__name__} needs to implement scale operation for its state.')  # pragma: no cover
 
+    @abstractmethod
    def shift(self, shift_value: np.ndarray):
         """
         Shift all statistics in collector by some value.
```
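`BaseCollector` now inherits from `ABC` and marks `scale`/`shift` as abstract, so a collector that forgets to implement one of them fails at instantiation instead of when the missing method is first called. A simplified sketch of the effect, using hypothetical stand-in classes rather than the toolkit's own:

```python
# Simplified sketch of the new contract; BaseCollectorSketch and NoShiftCollector are
# hypothetical stand-ins, not classes from the toolkit.
from abc import ABC, abstractmethod
import numpy as np

class BaseCollectorSketch(ABC):
    @abstractmethod
    def scale(self, scale_factor: np.ndarray): ...

    @abstractmethod
    def shift(self, shift_value: np.ndarray): ...

class NoShiftCollector(BaseCollectorSketch):
    def scale(self, scale_factor: np.ndarray):
        return scale_factor  # shift() is intentionally missing

try:
    NoShiftCollector()
except TypeError as err:
    print(err)  # instantiation fails because shift() is not implemented
```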
model_compression_toolkit/core/common/collectors/mean_collector.py
CHANGED

```diff
@@ -87,10 +87,13 @@ class MeanCollector(BaseCollector):
             x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
         """
         self.i += 1  # Update the iteration index
-
-
-
-
+        if self.axis is None:
+            mu = np.mean(np.reshape(x, [1, -1]), axis=-1)  # mean per channel for a batch
+        else:
+            axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
+            n = x.shape[axis]
+            transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
+            mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1)  # mean per channel for a batch
         self.current_sum += mu  # sum of all batches
         self.current_mean = self.current_sum / self.i  # mean of all batches
 
```
model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py
CHANGED

```diff
@@ -130,10 +130,13 @@ class MinMaxPerChannelCollector(BaseCollector):
             x: Tensor that goes through the collector and needs to be considered in the min/max computation.
         """
 
-
-
-
-
+        if self.axis is None:
+            x_reshape = np.reshape(x, [1, -1])
+        else:
+            axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
+            n = x.shape[axis]
+            transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
+            x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
         if self.state is None:
             x_max = np.max(x_reshape, axis=-1)
             x_min = np.min(x_reshape, axis=-1)
```
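Both collector changes follow the same pattern: when `self.axis` is `None` the tensor is flattened into a single pseudo-channel, otherwise the channel axis is moved to the front and everything else is flattened, so the statistic is computed per channel. A standalone sketch of that reshaping, where `LAST_AXIS = -1` is an assumption standing in for MCT's "last dimension" sentinel:

```python
# Standalone sketch of the axis handling added to MeanCollector and
# MinMaxPerChannelCollector; LAST_AXIS = -1 is an assumption for this example.
import numpy as np

LAST_AXIS = -1

def per_channel_view(x: np.ndarray, axis):
    if axis is None:
        return np.reshape(x, [1, -1])  # whole tensor treated as one channel
    axis = (x.ndim - 1) if axis == LAST_AXIS else axis
    n = x.shape[axis]
    order = [axis, *[i for i in range(x.ndim) if i != axis]]
    return np.reshape(np.transpose(x, order), [n, -1])  # one row per channel

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)  # channels-last toy batch
v = per_channel_view(x, LAST_AXIS)                    # shape (4, 6)
print(np.mean(v, axis=-1), np.min(v, axis=-1), np.max(v, axis=-1))
```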
model_compression_toolkit/core/common/model_collector.py
CHANGED

```diff
@@ -57,19 +57,21 @@ def create_stats_collector_for_node(node: common.BaseNode,
 
 
 def create_tensor2node(graph: common.Graph,
-                       node: common.BaseNode
+                       node: common.BaseNode,
+                       next_node_output_channel_axis: int):
     """
     Force statistic collector creation and assignment for a node.
     Args:
         graph: Graph of the node (for retrieving the current tensor).
         node: Node to create a tensor for.
+        next_node_output_channel_axis: channel output axis of next node.
 
     """
     current_sc = graph.get_out_stats_collector(node)
     is_list_nostat_collectors = isinstance(current_sc, list) and len(
         [sc for sc in current_sc if not isinstance(sc, common.NoStatsCollector)]) == 0
     if isinstance(current_sc, common.NoStatsCollector) or current_sc is None or is_list_nostat_collectors:
-        stats_collector = common.StatsCollector(node.out_channel_axis)
+        stats_collector = common.StatsCollector(next_node_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis)
         graph.set_out_stats_collector_to_node(node, stats_collector)
 
 
@@ -157,6 +159,17 @@ class ModelCollector:
         for n in graph.get_topo_sorted_nodes():
             quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
             sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln)  # Get static collector for the node
+            if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None):
+                # Missing output channel axis info, so try to extract it from previous and next nodes output channel axis.
+                possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)}
+                # Filter out None values.
+                possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set))
+                if len(possible_output_channel_axis_list) > 0:
+                    if len(possible_output_channel_axis_list) > 1:
+                        Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.')
+                    sc.mc.axis = possible_output_channel_axis_list[0]
+                    sc.mpcc.axis = possible_output_channel_axis_list[0]
+
             # If we use bias correction, and the node has kernel weights to quantize, we need to make sure
             # its previous nodes' tensors are consistent with this node.
             if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(
@@ -164,7 +177,8 @@ class ModelCollector:
                 for ie in graph.incoming_edges(n):
                     input_node = ie.source_node
                     create_tensor2node(graph,
-                                       input_node
+                                       input_node,
+                                       n.out_channel_axis)
             if sc is not None:
                 graph.set_out_stats_collector_to_node(n, sc)
 
```
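`ModelCollector` now patches a missing output-channel axis by borrowing it from the node's neighbors, warning when those neighbors disagree, and `create_tensor2node` receives the consuming node's axis as a fallback for the same reason. A hedged sketch of that inference on plain values; the helper below is hypothetical, while the real code operates on graph nodes and uses `Logger.warning`:

```python
# Hypothetical helper mirroring the neighbor-based axis fallback added to ModelCollector.
import warnings
from typing import Optional, Sequence

def infer_out_channel_axis(neighbor_axes: Sequence[Optional[int]]) -> Optional[int]:
    candidates = [a for a in set(neighbor_axes) if a is not None]
    if not candidates:
        return None  # nothing to infer from; the collector keeps axis=None
    if len(candidates) > 1:
        warnings.warn("Ambiguous channel axis inferred from neighboring nodes.")
    return candidates[0]

print(infer_out_channel_axis([None, 3, 3]))  # -> 3
print(infer_out_channel_axis([None, None]))  # -> None
```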
model_compression_toolkit/core/common/pruning/memory_calculator.py
CHANGED

```diff
@@ -303,7 +303,7 @@ class MemoryCalculator:
             num_oc = np.sum(output_mask)
         else:
             # Get the node channel axis from framework info
-            channel_axis = node.out_channel_axis
+            channel_axis = self.fw_impl.default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
             if channel_axis is None:
                 Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")
 
```
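The pruning memory calculator applies a simpler version of the same idea: if the node does not define an output channel axis, it falls back to a framework-level default before giving up. A tiny sketch of that resolution order, where `fw_default` stands in for `self.fw_impl.default_output_channel_axis`:

```python
# Sketch of the fallback used when counting output channels for pruning memory.
def resolve_channel_axis(node_axis, fw_default):
    axis = fw_default if node_axis is None else node_axis
    if axis is None:
        raise ValueError("The channel axis is undefined for this node.")
    return axis

print(resolve_channel_axis(None, fw_default=-1))  # -> -1 (framework default wins)
print(resolve_channel_axis(1, fw_default=-1))     # -> 1  (node axis takes precedence)
```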
model_compression_toolkit/core/common/quantization/node_quantization_config.py
CHANGED

```diff
@@ -18,7 +18,6 @@ from enum import Enum, auto
 from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
 from model_compression_toolkit.logger import Logger
 
-from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
 from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
 from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
     AttributeQuantizationConfig, OpQuantizationConfig
@@ -41,6 +40,7 @@ class ActivationQuantizationMode(Enum):
     NO_QUANT = auto()
     FLN_NO_QUANT = auto()
 
+
 class BaseNodeQuantizationConfig(object):
     """
     Base class for node quantization configuration
@@ -59,12 +59,11 @@ class BaseNodeQuantizationConfig(object):
             kwargs: A dictionary with additional key arguments.
 
         """
-
         if hasattr(self, config_parameter_name):
             setattr(self, config_parameter_name, config_parameter_value)
         else:
-
-
+            raise AttributeError(
+                f"Parameter {config_parameter_name} could not be found in the node quantization config.")
 
     def __repr__(self) -> str:
         """
@@ -97,36 +96,11 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
         self.signedness = op_cfg.signedness
 
         self.activation_quantization_params = {}
-        # TODO
+        # TODO: computed by compute_activation_bias_correction. Probably shouldnt be here.
         self.activation_bias_correction_term = None
-
-        #
-        self.activation_error_method = None
-        self.relu_bound_to_power_of_2 = None
-        self.activation_channel_equalization = None
-        self.input_scaling = None
-        self.min_threshold = None
-        self.l_p_value = None
-        self.shift_negative_activation_correction = None
+        # Z-threshold is a global param from QuantizationConfig, however it can be overridden per node by NetworkEditor.
+        # Since activation qparams are re-computed in several places, it's easier to keep it here and update it once.
         self.z_threshold = None
-        self.shift_negative_ratio = None
-        self.shift_negative_threshold_recalculation = None
-        self.concat_threshold_update = None
-
-    def set_qc(self, qc: QuantizationConfig):
-        """ TODO irena: temporary keep all the attributes as before not to break all code at once.
-        Eventually all of them should be removed from here. """
-        self.activation_error_method = qc.activation_error_method
-        self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
-        self.activation_channel_equalization = qc.activation_channel_equalization
-        self.input_scaling = qc.input_scaling
-        self.min_threshold = qc.min_threshold
-        self.l_p_value = qc.l_p_value
-        self.shift_negative_activation_correction = qc.shift_negative_activation_correction
-        self.z_threshold = qc.z_threshold
-        self.shift_negative_ratio = qc.shift_negative_ratio
-        self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
-        self.concat_threshold_update = qc.concat_threshold_update
 
     @property
     def enable_activation_quantization(self):
@@ -148,7 +122,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
             activation_params: Dictionary that contains weight quantization params.
 
         """
-        assert self.quant_mode == ActivationQuantizationMode.QUANT
+        assert self.quant_mode == ActivationQuantizationMode.QUANT or self.quant_mode == ActivationQuantizationMode.FLN_QUANT
         for param_name, param_value in activation_params.items():
             self.activation_quantization_params[param_name] = param_value
 
@@ -165,32 +139,16 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
         if not isinstance(other, NodeActivationQuantizationConfig):
             return False  # pragma: no cover
 
-        return self.
-            self.activation_quantization_method == other.activation_quantization_method and \
+        return self.activation_quantization_method == other.activation_quantization_method and \
             self.activation_n_bits == other.activation_n_bits and \
             self.quant_mode == other.quant_mode and \
-            self.
-            self.input_scaling == other.input_scaling and \
-            self.min_threshold == other.min_threshold and \
-            self.l_p_value == other.l_p_value and \
-            self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
-            self.z_threshold == other.z_threshold and \
-            self.shift_negative_ratio == other.shift_negative_ratio and \
-            self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
+            self.signedness == other.signedness
 
     def __hash__(self):
-        return hash((self.
-                     self.activation_quantization_method,
+        return hash((self.activation_quantization_method,
                      self.activation_n_bits,
                      self.quant_mode,
-                     self.
-                     self.input_scaling,
-                     self.min_threshold,
-                     self.l_p_value,
-                     self.shift_negative_activation_correction,
-                     self.z_threshold,
-                     self.shift_negative_ratio,
-                     self.shift_negative_threshold_recalculation))
+                     self.signedness))
 
 
 class WeightsAttrQuantizationConfig:
@@ -211,16 +169,8 @@ class WeightsAttrQuantizationConfig:
         self.weights_n_bits = weights_attr_cfg.weights_n_bits
         self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
         self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
-        self.weights_quantization_params = {}
 
-
-        self.weights_error_method = None
-        self.l_p_value = None
-
-    def set_qc(self, qc: QuantizationConfig):
-        # TODO irena: temporary keep the fields to not break everything at once.
-        self.weights_error_method = qc.weights_error_method
-        self.l_p_value = qc.l_p_value
+        self.weights_quantization_params = {}
 
     def set_weights_quantization_param(self,
                                        weights_params: dict):
@@ -252,18 +202,14 @@ class WeightsAttrQuantizationConfig:
             self.weights_quantization_method == other.weights_quantization_method and \
             self.weights_n_bits == other.weights_n_bits and \
             self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
-            self.enable_weights_quantization == other.enable_weights_quantization
-            self.weights_error_method == other.weights_error_method and \
-            self.l_p_value == other.l_p_value
+            self.enable_weights_quantization == other.enable_weights_quantization
 
     def __hash__(self):
         return hash((self.weights_channels_axis,
-                     self.weights_error_method,
                      self.weights_quantization_method,
                      self.weights_n_bits,
                      self.weights_per_channel_threshold,
-                     self.enable_weights_quantization
-                     self.l_p_value))
+                     self.enable_weights_quantization))
 
 
 class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
@@ -330,16 +276,14 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
 
         self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
                                                                              weights_channels_axis=weights_channels_axis)
-        # TODO
-
+        # TODO this is set by batch norm reconstruction substitution when folded batch norms are added back, to mark
+        # the nodes that the correction should be applied to (for some nodes it gets disabled) and BNs removed.
+        # The actual correction is only computed when it's applied in ptq, so it seems that both substitutions could
+        # be unified, and no info need to pass between.
         self.weights_second_moment_correction = None
-
-
-
-        # TODO irena: temporary keep the fields to not break everything at once.
-        self.min_threshold = qc.min_threshold
-        self.weights_second_moment_correction = qc.weights_second_moment_correction
-        self.weights_bias_correction = qc.weights_bias_correction
+        # TODO: computed corrected bias is injected to the node config. Probably shouldn't be here. Also it can be
+        # computed on the final config, instead of all candidates and then there is no need to save it at all.
+        self.bias_corrected = None
 
     def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
         """
@@ -476,8 +420,8 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
             if hasattr(attr_cfg, config_parameter_name):
                 setattr(attr_cfg, config_parameter_name, config_parameter_value)
             else:
-
-
+                raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
+                                     f"weights attribute {attr_name}.")
         else:  # pragma: no cover
             Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
 
@@ -494,10 +438,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
         if not isinstance(other, NodeWeightsQuantizationConfig):
             return False  # pragma: no cover
 
-        return self.
-            self.simd_size == other.simd_size and \
-            self.weights_second_moment_correction == other.weights_second_moment_correction and \
-            self.weights_bias_correction == other.weights_bias_correction and \
+        return self.simd_size == other.simd_size and \
             self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
             all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
                  for k in self.attributes_config_mapping.keys()]) and \
@@ -506,9 +447,6 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
                  for k in self.pos_attributes_config_mapping.keys()])
 
     def __hash__(self):
-        return hash((self.
-                     self.simd_size,
-                     self.weights_second_moment_correction,
-                     self.weights_bias_correction,
+        return hash((self.simd_size,
                      frozenset(self.attributes_config_mapping),
                      frozenset(self.pos_attributes_config_mapping)))
```
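With the global `QuantizationConfig` fields and the temporary `set_qc` helpers removed, the node configs' `__eq__` and `__hash__` now cover only the node-intrinsic quantization settings. Keeping the two methods over the same fields matters because candidate configurations are deduplicated by hashing; an illustrative sketch (not the MCT classes):

```python
# Illustrative only: a frozen dataclass generates __eq__/__hash__ over the same fields,
# so equal activation configs collapse to a single entry when placed in a set.
from dataclasses import dataclass

@dataclass(frozen=True)
class ActQuantCfgSketch:
    quantization_method: str
    n_bits: int
    signedness: str

a = ActQuantCfgSketch("POWER_OF_TWO", 8, "AUTO")
b = ActQuantCfgSketch("POWER_OF_TWO", 8, "AUTO")
print(a == b, len({a, b}))  # True 1 -> duplicate candidates are merged
```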
model_compression_toolkit/core/common/quantization/quantization_config.py
CHANGED

```diff
@@ -90,7 +90,6 @@ class QuantizationConfig:
     shift_negative_activation_correction: bool = True
     activation_channel_equalization: bool = False
     z_threshold: float = math.inf
-    min_threshold: float = MIN_THRESHOLD
     l_p_value: int = 2
     linear_collapsing: bool = True
     residual_collapsing: bool = True
```