mct-nightly 2.4.0.20250706.701__py3-none-any.whl → 2.4.0.20250707.643__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/RECORD +36 -38
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/collectors/base_collector.py +4 -1
- model_compression_toolkit/core/common/collectors/mean_collector.py +7 -4
- model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +7 -4
- model_compression_toolkit/core/common/model_collector.py +11 -0
- model_compression_toolkit/core/common/pruning/memory_calculator.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +22 -87
- model_compression_toolkit/core/common/quantization/quantization_config.py +0 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +23 -17
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +26 -48
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +12 -7
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +4 -14
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +2 -1
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +4 -13
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +3 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +5 -7
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +7 -5
- model_compression_toolkit/core/graph_prep_runner.py +1 -11
- model_compression_toolkit/core/keras/default_framework_info.py +1 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +21 -11
- model_compression_toolkit/core/keras/keras_implementation.py +2 -2
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +8 -0
- model_compression_toolkit/core/pytorch/default_framework_info.py +1 -1
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +9 -1
- model_compression_toolkit/core/quantization_prep_runner.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +0 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +0 -3
- model_compression_toolkit/qat/keras/quantization_facade.py +0 -3
- model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +0 -2
- model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +0 -6
- model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +2 -4
- model_compression_toolkit/core/common/model_validation.py +0 -41
- model_compression_toolkit/core/keras/keras_model_validation.py +0 -37
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/top_level.txt +0 -0
{mct_nightly-2.4.0.20250706.701.dist-info → mct_nightly-2.4.0.20250707.643.dist-info}/RECORD
RENAMED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
mct_nightly-2.4.0.
|
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
|
1
|
+
mct_nightly-2.4.0.20250707.643.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
2
|
+
model_compression_toolkit/__init__.py,sha256=n67zJIaNzqdkUrpJl9e_iXZ7xM4vgLz7U7d06AHkmTU,1557
|
|
3
3
|
model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
|
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
|
@@ -7,8 +7,8 @@ model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_A
|
|
|
7
7
|
model_compression_toolkit/verify_packages.py,sha256=l0neIRr8q_QwxmuiTI4vyCMDISDedK0EihjEQUe66tE,1319
|
|
8
8
|
model_compression_toolkit/core/__init__.py,sha256=HNverPpoqEyFKTa7iEdOqqY2P0Gq-7GMejNOi6ZPcQs,2042
|
|
9
9
|
model_compression_toolkit/core/analyzer.py,sha256=5P03LbkFy-mu31TMAiQoIKcsA1-DNz7cTzkGvRaXtbw,3505
|
|
10
|
-
model_compression_toolkit/core/graph_prep_runner.py,sha256=
|
|
11
|
-
model_compression_toolkit/core/quantization_prep_runner.py,sha256=
|
|
10
|
+
model_compression_toolkit/core/graph_prep_runner.py,sha256=XvuR1lNsbaGG3HyROA8nF2n0oPNb1Tnw6rGatewAvE0,10563
|
|
11
|
+
model_compression_toolkit/core/quantization_prep_runner.py,sha256=W6CAJ_Euxrjm5jj__QqfB4xsD8NYx_-XWX22h0ax6B4,6262
|
|
12
12
|
model_compression_toolkit/core/runner.py,sha256=QpiJQmQXK6mWmnygNRdy6I8S48DHV-B0Kmr4TqOKbeA,12418
|
|
13
13
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
|
14
14
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
|
@@ -16,18 +16,17 @@ model_compression_toolkit/core/common/framework_implementation.py,sha256=jrTupZb
|
|
|
16
16
|
model_compression_toolkit/core/common/framework_info.py,sha256=vPGV28gm-kvNSkkWI6jY3YeKBUYmn6UQ98HVUnl_-tM,5449
|
|
17
17
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
|
18
18
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
|
19
|
-
model_compression_toolkit/core/common/model_collector.py,sha256=
|
|
20
|
-
model_compression_toolkit/core/common/model_validation.py,sha256=HRnYh2uY85yJ7Ijmt4tKRn8bMg60zbBSDRCgK246gUM,1067
|
|
19
|
+
model_compression_toolkit/core/common/model_collector.py,sha256=fgSLqbi1YRyvISJIP9WyZv3IlvuMplywJ-ffUnDWN1Q,13655
|
|
21
20
|
model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
|
|
22
21
|
model_compression_toolkit/core/common/similarity_analyzer.py,sha256=S3f6WgHyw62dGcxpX51FGKyfebe2zv9ABKbjtGyKRvY,9215
|
|
23
22
|
model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
|
|
24
23
|
model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
25
24
|
model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=yrIxT0ttDi9XViy8Zt8apnMCT8xDyVd5HZp0IttrGGQ,1775
|
|
26
25
|
model_compression_toolkit/core/common/collectors/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
|
27
|
-
model_compression_toolkit/core/common/collectors/base_collector.py,sha256=
|
|
26
|
+
model_compression_toolkit/core/common/collectors/base_collector.py,sha256=n3IIwfWVIl8TO_7MR9jBBxzf2foysWMyrGS-oaCuKiA,2664
|
|
28
27
|
model_compression_toolkit/core/common/collectors/histogram_collector.py,sha256=zra5V06Brpjc1cUNIMVVGqdoqAuro62_hGy2Zm5-XMQ,6754
|
|
29
|
-
model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=
|
|
30
|
-
model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=
|
|
28
|
+
model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=7-oCVU76e9ZWIZLiLte-z07FX2Rjmu5Go-KRj7b7oWY,3564
|
|
29
|
+
model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=pAdiLXNNm5sfIZ584m2MzJQ0PTNp6MKsrMxJ_onEN1s,5314
|
|
31
30
|
model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
|
|
32
31
|
model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
|
|
33
32
|
model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
@@ -87,7 +86,7 @@ model_compression_toolkit/core/common/network_editors/node_filters.py,sha256=Pc_
|
|
|
87
86
|
model_compression_toolkit/core/common/pruning/__init__.py,sha256=DGJybkDQtKMSMFoZ-nZ3ZifA8uJ6G_D20wHhKHNlmU0,699
|
|
88
87
|
model_compression_toolkit/core/common/pruning/channels_grouping.py,sha256=-zrq0TsfVE4ooxOcJCsL8H2DBau6vSkEKz1ot-x-Faw,3736
|
|
89
88
|
model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py,sha256=UZekrges7gZv17JFLX_AV2Kv0eBXXarMNInzuOTlyvA,7712
|
|
90
|
-
model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=
|
|
89
|
+
model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=7IdX9APlS17omWeF_y0ZFr8S2K7Cp-M0u8rFtfmrx8U,19023
|
|
91
90
|
model_compression_toolkit/core/common/pruning/prune_graph.py,sha256=eGvuqrxyADRSvhKz0R_7lLfIl7bKnn4bryElu3LsVcA,3158
|
|
92
91
|
model_compression_toolkit/core/common/pruning/pruner.py,sha256=Zl0IK0anorzagaSP8qXMN31Dtw5m-Le-JRy2baPLs6M,7262
|
|
93
92
|
model_compression_toolkit/core/common/pruning/pruning_config.py,sha256=fbqERt11FGVeuqPVA6nVbgGDh6Ox9mpEKdxVJT8eG4I,3681
|
|
@@ -107,8 +106,8 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
|
107
106
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
|
|
108
107
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
|
|
109
108
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=FyYCYbfkAofEWO2mAvFIppPeq2I10f1ScPNiVa9F7x4,7687
|
|
110
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
|
111
|
-
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=
|
|
109
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=te18Lv-pwAbING6CnURANZZVeThFn3tDmNBBSXwBcuM,21306
|
|
110
|
+
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=EMhXeY1qkvwlMAY5fpKRDuLEAyRY5yKqe2fOmAD_rVI,4362
|
|
112
111
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=VVq2cKjumlNWucUbaNw8s2J0IbI_vrQ-KR_eQPshGSg,3140
|
|
113
112
|
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
|
|
114
113
|
model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=WJ-lsT_R_pqjbrMzgcposugACDNz7yZ09vSlltTb78A,3001
|
|
@@ -118,10 +117,10 @@ model_compression_toolkit/core/common/quantization/quantization_params_generatio
|
|
|
118
117
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
|
|
119
118
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
|
|
120
119
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=-cghHF5S11qbjTDRruHlc__uaDoofZHl7QTl8hCeKW0,11141
|
|
121
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=
|
|
122
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=
|
|
120
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=KoPYo6il5Pa5Y54KUa1GiB5G_hGzN8yMTgLOsC6QElY,7650
|
|
121
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=bsVCs5KrVKtagHG4VLFzxN2f-YKs0-mhV9VmL57SE9E,7924
|
|
123
122
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=Nv_b3DECVjQnlrUet2kbuSvSKVnxcc-gf2zhFb2jSZk,43482
|
|
124
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=
|
|
123
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=saOQMtj1qYqgQoAFjq31p7xtiRDxmanGGdm0DE81_cg,4820
|
|
125
124
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=6tRNgWvn-4r8hiSHqND7Qms1Nje1DUR4MR0JeWCNyvI,12531
|
|
126
125
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=xiZgCkoIrJ9xsR17x9pSl_sUbiuSta67kf7bQ4quFUI,10804
|
|
127
126
|
model_compression_toolkit/core/common/quantization/quantizers/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
@@ -130,22 +129,22 @@ model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers
|
|
|
130
129
|
model_compression_toolkit/core/common/quantization/quantizers/uniform_quantizers.py,sha256=wXExWHf5-0He7L4bpvFpKlx7FG4u3DAfNZiXPpOs_SQ,5521
|
|
131
130
|
model_compression_toolkit/core/common/statistics_correction/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
|
132
131
|
model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py,sha256=oUa1Gv9jIICOoFljTiIaItFjJQPht7CBe-wEr3iBuLQ,4118
|
|
133
|
-
model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=
|
|
134
|
-
model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=
|
|
132
|
+
model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=fvvIH5uWtQbmPJQh_zjE-XuCFP9YrlOEs3mW9ysKhms,3652
|
|
133
|
+
model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=7HIoey1tqbTy1tnTF_KF5D-dSNankamBj8kWO9kUQYo,5200
|
|
135
134
|
model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=289b2iwzp2hjsgpEZotQKNB2aPKjAZopRaGnbzErHV8,9263
|
|
136
|
-
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=
|
|
137
|
-
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=
|
|
135
|
+
model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=p4u50-1CPuGTQtMDz5uaZqFSLhl2BLTbCLDQvplCuW4,9552
|
|
136
|
+
model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=F0xOl97p8Q0h4MXxOKo0kGY4iJIiJou06JbICatg8vQ,5881
|
|
138
137
|
model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
|
139
138
|
model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
|
|
140
139
|
model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
|
|
141
|
-
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=
|
|
140
|
+
model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=xa-6ACozqTu4xWXkC-7p6YTZLcLcH_AUYEROTl8USWc,8014
|
|
142
141
|
model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=eCbhbAzgXWoVymMLbrupJ1qAcdhZDwkjKeja0fCymnY,9746
|
|
143
142
|
model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
|
|
144
143
|
model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
|
|
145
144
|
model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=TKU1TIU52UIkVnl0EZvWnDhLV9nIVZ4hqi-w1i4NXMk,2637
|
|
146
145
|
model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=N82mso5j3EJQlKt9EMHjjEJ67FmdGQeCfN8U5grOFXo,4830
|
|
147
146
|
model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=2_NmmBmUBZZwXuF5Od2S919_FgQKYIf-nSyNPawr0e4,9840
|
|
148
|
-
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=
|
|
147
|
+
model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=_TliV1tBxizPsCZVbOhDhg3mSc_m0sY9gVqQMYTwQoc,32739
|
|
149
148
|
model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
|
|
150
149
|
model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=cokiYPZB7504oHTlgZy8u2Xv_S-RK_oDSnGvYRX3JK4,4136
|
|
151
150
|
model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=vafrJ6eA37PrIzOs7uOsiJKIBmAVmNJ-wXsoe332BIw,4683
|
|
@@ -157,9 +156,8 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
|
|
|
157
156
|
model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
|
|
158
157
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
|
159
158
|
model_compression_toolkit/core/keras/data_util.py,sha256=jm54o-SlI1DJ-sEvRuX9OyLN68tEt0VxcqrdIjR98Ag,8366
|
|
160
|
-
model_compression_toolkit/core/keras/default_framework_info.py,sha256=
|
|
161
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
|
162
|
-
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=dMS9cqaYmliyzVu2-MrKx4AIubqz3HW3RY4if2TV6U8,1581
|
|
159
|
+
model_compression_toolkit/core/keras/default_framework_info.py,sha256=wI-M_MsVIL0qOdBr8F-oyrRz4qdc2o6DDB7lGlizddw,6171
|
|
160
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=5oAc0nc21Uzfo94rKUGGpbYVTmEiKA3YGNhbOwD66ng,28550
|
|
163
161
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=k9cwu3S-OUGFaOHxH6cyYS2JjxAYHfBddz0laf6Quds,3311
|
|
164
162
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=xxZlHyruhLuP2iEgMrZhq_AyAGORTqzweVLARFfpaRw,5643
|
|
165
163
|
model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=jzD8FGEEa8ZD7w8IpTRdp-Udf1MwOTgjg2XTS1Givic,2696
|
|
@@ -179,7 +177,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm
|
|
|
179
177
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
|
|
180
178
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py,sha256=vZr8Agj-tFKSX7TM2nZjwbHElJqSIyMAaR7FH-lp4YM,11691
|
|
181
179
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=nJO-JUmOK1lLb560KMJgLFwY2IOI2Y3lpzUq7o2f7mQ,5707
|
|
182
|
-
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=
|
|
180
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=R3minGoiZqhjpex52LmBPGvALez9KKiuZOQzqOB7OK0,6558
|
|
183
181
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=AvquvVVVT8-ioeVn-gjqysK4L41L3I7TlNOEDfWjViY,8185
|
|
184
182
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/matmul_substitution.py,sha256=9MZJp4GNTLesWN5uQ5eOQyAHLzLYDAHAjRi-LpNppSc,4257
|
|
185
183
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=l9PUREBf4aRwWILiybdteveeUbh7js-i-hLt8Ma0e4c,26771
|
|
@@ -201,7 +199,7 @@ model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1
|
|
|
201
199
|
model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=1p2DlMRmgzBAOUP-NeOzldTemjNLQQ3uf1Rov5iY-l8,5430
|
|
202
200
|
model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=GtW0yars8PzqP9uL_vfXrtqHwKiStmOxPng20rYaIjU,6805
|
|
203
201
|
model_compression_toolkit/core/keras/pruning/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
|
|
204
|
-
model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=
|
|
202
|
+
model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=tE2-9KqDE2C9Ahs0FzwexCXfO7vviRCLgegn-l5eOTA,11528
|
|
205
203
|
model_compression_toolkit/core/keras/quantization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
206
204
|
model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py,sha256=RtQk5r-bZxUs10AFaJ813_rpkDmOwzWPv6zK6LbX4_8,1876
|
|
207
205
|
model_compression_toolkit/core/keras/quantization/fake_quant_builder.py,sha256=vfKwU0AfRH2KztmMF5bxcaBlGdnTePPGZsUqOHzED-U,6854
|
|
@@ -223,7 +221,7 @@ model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG
|
|
|
223
221
|
model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
224
222
|
model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3KqhVz5ORnA46S1uq9Tj5qhtHc,2828
|
|
225
223
|
model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
|
|
226
|
-
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=
|
|
224
|
+
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=HB9JjOBxa0JOLy2bOhDi2Jp2srK8M20g74jOmnUp54I,4305
|
|
227
225
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
|
228
226
|
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=cUQOBGwtG_DWpkrUEOcYSwXtNSmQgYVBCTxTpFiF4mo,27213
|
|
229
227
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=5hsp0nl6TewfrKsT133m9Z7DVpTFFftEv6DeZoryDZw,3009
|
|
@@ -272,7 +270,7 @@ model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcY
|
|
|
272
270
|
model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=MTH7WsTpP-cTeMwaqrJPnhV_XdFKO6bySNalTONmr0w,4991
|
|
273
271
|
model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=KDnwmbhvhJMfNg1IuTvvzBNEriPQH9bL9dJ5VvWTzpE,6631
|
|
274
272
|
model_compression_toolkit/core/pytorch/pruning/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
|
|
275
|
-
model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=
|
|
273
|
+
model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=jO-XbB6-ITnkBWMS4GO_s1-V3fhwfe4c183bMfucAsg,13367
|
|
276
274
|
model_compression_toolkit/core/pytorch/quantization/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
277
275
|
model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py,sha256=arslrOgJ1l-fScDlp6jNJ-JukKh0uBLcxAzjpDWRw94,1878
|
|
278
276
|
model_compression_toolkit/core/pytorch/quantization/fake_quant_builder.py,sha256=D8_CEuFqKAhbUgKaRw7Jlxo0zlqgPTMu6CIIIM4LfS0,7045
|
|
@@ -368,7 +366,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
|
|
|
368
366
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
|
|
369
367
|
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=_QwytOg1RQSg5Gvme089EME4trdTKGKM3JHIgT-b3n0,22841
|
|
370
368
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=xpjEqiDo4mqW42QGdmyW31n5eWd6HbYyP6EbarN-A8A,4283
|
|
371
|
-
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
|
|
369
|
+
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=TBMfJAg_yERdd-OGHFzpV0v435QKM3FvIniwKHjw5i4,18493
|
|
372
370
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
|
373
371
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
|
|
374
372
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
|
@@ -403,14 +401,14 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=YxRtJGzD6SjZ4
|
|
|
403
401
|
model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
|
|
404
402
|
model_compression_toolkit/ptq/runner.py,sha256=1tVx3Yj5X4ZjTH0REm6fuAmv4QZ4u_vixLsgjBwBzxc,2326
|
|
405
403
|
model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
406
|
-
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
|
|
404
|
+
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=QieIYkoVtxWgHILGYeMHvSeIZ3mdbWdTR76MuUJ08-I,11433
|
|
407
405
|
model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
408
406
|
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=RruQVxS4ylBjSH1KMh8ZCV8jk3OvtSrQl24m3Q4xs_8,10065
|
|
409
407
|
model_compression_toolkit/qat/__init__.py,sha256=AaC4KBha4jDW_tyg2SOxZaKh_idIz0gZtDK3_zxs64E,1241
|
|
410
408
|
model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
|
|
411
409
|
model_compression_toolkit/qat/common/qat_config.py,sha256=QNXj2OcKIJOGvGEGzR2GCifI5Ho7FS7zFc2fkj6PJAc,2750
|
|
412
410
|
model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
413
|
-
model_compression_toolkit/qat/keras/quantization_facade.py,sha256=
|
|
411
|
+
model_compression_toolkit/qat/keras/quantization_facade.py,sha256=kwPnkI8nHrUK650-MyJcLweOj7Q7IbFMbmMpP_lRwTg,17275
|
|
414
412
|
model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
|
|
415
413
|
model_compression_toolkit/qat/keras/quantizer/base_keras_qat_weight_quantizer.py,sha256=EbIt4lMlh6cU4awFLMBp0IlZ2zUUp-WtnlW5Wn19FDM,1793
|
|
416
414
|
model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
|
|
@@ -473,16 +471,16 @@ model_compression_toolkit/trainable_infrastructure/common/__init__.py,sha256=huH
|
|
|
473
471
|
model_compression_toolkit/trainable_infrastructure/common/annealing_schedulers.py,sha256=qm2_wa61nga08Jdcl3RkgTsJ0zyHNjZ_A6I2--oVOig,2455
|
|
474
472
|
model_compression_toolkit/trainable_infrastructure/common/base_trainable_quantizer.py,sha256=IF50ASBUvVrOVqlJ1nHNxZxKXSuCanjhUX0YjMB-rRg,7946
|
|
475
473
|
model_compression_toolkit/trainable_infrastructure/common/constants.py,sha256=HN120boJxAnEXNrLSj-o_s-VX4o6C-1ap_KZ4840sd0,875
|
|
476
|
-
model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py,sha256=
|
|
474
|
+
model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py,sha256=Wv333zLZH0z5N2Uio0beafrUVV-zZcxZvvSKr-fDChY,6839
|
|
477
475
|
model_compression_toolkit/trainable_infrastructure/common/get_quantizers.py,sha256=10edXuhu6F00EcMU7M29AlK7rF_uoLQjMjctrWqK5KU,3346
|
|
478
476
|
model_compression_toolkit/trainable_infrastructure/common/quant_utils.py,sha256=zdiew1jwR7tUKm9XWlHnAPxIZsAdKqbzzC2vH02j5wA,1505
|
|
479
|
-
model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py,sha256=
|
|
477
|
+
model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py,sha256=VLj3YWKJi6MMbEcPq7dDd4DMsXorM6F4Zi7tlFB4YL4,4394
|
|
480
478
|
model_compression_toolkit/trainable_infrastructure/common/training_method.py,sha256=LUoeJkloowhZKuHTiOfzjmSUn2G-4of11-rbnL-h0P4,1194
|
|
481
479
|
model_compression_toolkit/trainable_infrastructure/common/util.py,sha256=oKuWi7E07a8zv5x9auhBugYE2RUQ7ojDh2XCs5koYJY,1090
|
|
482
480
|
model_compression_toolkit/trainable_infrastructure/keras/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
483
481
|
model_compression_toolkit/trainable_infrastructure/keras/annealing_schedulers.py,sha256=sISNVxPsdm-Nd95PhoPSJ-2tFpINGlfrU7ZXaCByI-o,1278
|
|
484
482
|
model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py,sha256=LBc26z8pkpbcdKMTxpNBg5IyChLreHQ1lRgCVjNE37o,4202
|
|
485
|
-
model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py,sha256=
|
|
483
|
+
model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py,sha256=rsyY7mTev5qRI0ncxuBilgYiRMd5GcuZt6qFyADU4H4,4123
|
|
486
484
|
model_compression_toolkit/trainable_infrastructure/keras/load_model.py,sha256=DJHibcLo-UCuHV6UPLeVd7dKmPfkGXEiLqCCqvQrISM,3769
|
|
487
485
|
model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha256=eVB5FSE3OmTLrhfLUcP2knwN1z2_unQLM-xFEGwdafA,5587
|
|
488
486
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=r3CaPd4pyM1GDXU2--9NT3wwvl9H6y3QUrVT9spx5es,4189
|
|
@@ -532,7 +530,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
|
532
530
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=Y0oBl8qPFsdNrK49XczwmVacInJcOPHslVnFBs-iTCc,3742
|
|
533
531
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
|
534
532
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=n0HvWBzkBkUJZlS3WeynhpsRTps2qQkjlq7luliBHNU,9627
|
|
535
|
-
mct_nightly-2.4.0.
|
|
536
|
-
mct_nightly-2.4.0.
|
|
537
|
-
mct_nightly-2.4.0.
|
|
538
|
-
mct_nightly-2.4.0.
|
|
533
|
+
mct_nightly-2.4.0.20250707.643.dist-info/METADATA,sha256=ZQC6URpJO-t-tH6PFUmNg-09N-hh6EnJg-P8ykkesoA,25555
|
|
534
|
+
mct_nightly-2.4.0.20250707.643.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
535
|
+
mct_nightly-2.4.0.20250707.643.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
536
|
+
mct_nightly-2.4.0.20250707.643.dist-info/RECORD,,
|
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
|
27
27
|
from model_compression_toolkit import pruning
|
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
29
29
|
|
|
30
|
-
__version__ = "2.4.0.
|
|
30
|
+
__version__ = "2.4.0.20250707.000643"
|
|
@@ -13,11 +13,12 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
16
17
|
import numpy as np
|
|
17
18
|
from model_compression_toolkit.logger import Logger
|
|
18
19
|
|
|
19
20
|
|
|
20
|
-
class BaseCollector(
|
|
21
|
+
class BaseCollector(ABC):
|
|
21
22
|
"""
|
|
22
23
|
Base class for statistics collection object.
|
|
23
24
|
"""
|
|
@@ -26,6 +27,7 @@ class BaseCollector(object):
|
|
|
26
27
|
# When manipulation statistics in a granularity they were not collected by, the data is invalid.
|
|
27
28
|
self.is_legal = True
|
|
28
29
|
|
|
30
|
+
@abstractmethod
|
|
29
31
|
def scale(self, scale_factor: np.ndarray):
|
|
30
32
|
"""
|
|
31
33
|
Scale all statistics in collector by some factor.
|
|
@@ -37,6 +39,7 @@ class BaseCollector(object):
|
|
|
37
39
|
raise NotImplemented(
|
|
38
40
|
f'{self.__class__.__name__} needs to implement scale operation for its state.') # pragma: no cover
|
|
39
41
|
|
|
42
|
+
@abstractmethod
|
|
40
43
|
def shift(self, shift_value: np.ndarray):
|
|
41
44
|
"""
|
|
42
45
|
Shift all statistics in collector by some value.
|
|
@@ -87,10 +87,13 @@ class MeanCollector(BaseCollector):
|
|
|
87
87
|
x: Tensor that goes through the mean collector and needs to be considered in the mean computation.
|
|
88
88
|
"""
|
|
89
89
|
self.i += 1 # Update the iteration index
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
90
|
+
if self.axis is None:
|
|
91
|
+
mu = np.mean(np.reshape(x, [1, -1]), axis=-1) # mean per channel for a batch
|
|
92
|
+
else:
|
|
93
|
+
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
|
|
94
|
+
n = x.shape[axis]
|
|
95
|
+
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
|
|
96
|
+
mu = np.mean(np.reshape(np.transpose(x, transpose_index), [n, -1]), axis=-1) # mean per channel for a batch
|
|
94
97
|
self.current_sum += mu # sum of all batches
|
|
95
98
|
self.current_mean = self.current_sum / self.i # mean of all batches
|
|
96
99
|
|
|
@@ -130,10 +130,13 @@ class MinMaxPerChannelCollector(BaseCollector):
|
|
|
130
130
|
x: Tensor that goes through the collector and needs to be considered in the min/max computation.
|
|
131
131
|
"""
|
|
132
132
|
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
133
|
+
if self.axis is None:
|
|
134
|
+
x_reshape = np.reshape(x, [1, -1])
|
|
135
|
+
else:
|
|
136
|
+
axis = (len(x.shape) - 1) if self.axis == LAST_AXIS else self.axis
|
|
137
|
+
n = x.shape[axis]
|
|
138
|
+
transpose_index = [axis, *[i for i in range(len(x.shape)) if i != axis]]
|
|
139
|
+
x_reshape = np.reshape(np.transpose(x, transpose_index), [n, -1])
|
|
137
140
|
if self.state is None:
|
|
138
141
|
x_max = np.max(x_reshape, axis=-1)
|
|
139
142
|
x_min = np.min(x_reshape, axis=-1)
|
|
@@ -157,6 +157,17 @@ class ModelCollector:
|
|
|
157
157
|
for n in graph.get_topo_sorted_nodes():
|
|
158
158
|
quant_node_in_fln = n.is_fln_quantization() and graph.fusing_info.is_quantized_node_in_fln(n)
|
|
159
159
|
sc = create_stats_collector_for_node(n, quant_node_in_fln=quant_node_in_fln) # Get static collector for the node
|
|
160
|
+
if isinstance(sc, common.StatsCollector) and (sc.mc.axis is None or sc.mpcc.axis is None):
|
|
161
|
+
# Missing output channel axis info, so try to extract it from the output channel axes of the previous and next nodes.
|
|
162
|
+
possible_output_channel_axis_set = {nn.out_channel_axis for nn in graph.get_next_nodes(n) + graph.get_prev_nodes(n)}
|
|
163
|
+
# Filter out None values.
|
|
164
|
+
possible_output_channel_axis_list = list(filter(lambda x: x is not None, possible_output_channel_axis_set))
|
|
165
|
+
if len(possible_output_channel_axis_list) > 0:
|
|
166
|
+
if len(possible_output_channel_axis_list) > 1:
|
|
167
|
+
Logger.warning(f'Ambiguous input channel data from next nodes for {n.name}.')
|
|
168
|
+
sc.mc.axis = possible_output_channel_axis_list[0]
|
|
169
|
+
sc.mpcc.axis = possible_output_channel_axis_list[0]
|
|
170
|
+
|
|
160
171
|
# If we use bias correction, and the node has kernel weights to quantize, we need to make sure
|
|
161
172
|
# its previous nodes' tensors are consistent with this node.
|
|
162
173
|
if qc.weights_bias_correction and n.kernel_attr is not None and n.is_weights_quantization_enabled(
|
|
@@ -303,7 +303,7 @@ class MemoryCalculator:
|
|
|
303
303
|
num_oc = np.sum(output_mask)
|
|
304
304
|
else:
|
|
305
305
|
# Get the node channel axis from framework info
|
|
306
|
-
channel_axis = node.out_channel_axis
|
|
306
|
+
channel_axis = self.fw_impl.default_output_channel_axis if node.out_channel_axis is None else node.out_channel_axis
|
|
307
307
|
if channel_axis is None:
|
|
308
308
|
Logger.critical(f"The channel axis is undefined. Please ensure the channel axis is explicitly defined for node {node.type} in the framework info.")
|
|
309
309
|
|
|
@@ -18,7 +18,6 @@ from enum import Enum, auto
|
|
|
18
18
|
from model_compression_toolkit.core.common.framework_info import ChannelAxisMapping
|
|
19
19
|
from model_compression_toolkit.logger import Logger
|
|
20
20
|
|
|
21
|
-
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
|
|
22
21
|
from model_compression_toolkit.target_platform_capabilities.constants import POSITIONAL_ATTR
|
|
23
22
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import \
|
|
24
23
|
AttributeQuantizationConfig, OpQuantizationConfig
|
|
@@ -41,6 +40,7 @@ class ActivationQuantizationMode(Enum):
|
|
|
41
40
|
NO_QUANT = auto()
|
|
42
41
|
FLN_NO_QUANT = auto()
|
|
43
42
|
|
|
43
|
+
|
|
44
44
|
class BaseNodeQuantizationConfig(object):
|
|
45
45
|
"""
|
|
46
46
|
Base class for node quantization configuration
|
|
@@ -59,12 +59,11 @@ class BaseNodeQuantizationConfig(object):
|
|
|
59
59
|
kwargs: A dictionary with additional key arguments.
|
|
60
60
|
|
|
61
61
|
"""
|
|
62
|
-
|
|
63
62
|
if hasattr(self, config_parameter_name):
|
|
64
63
|
setattr(self, config_parameter_name, config_parameter_value)
|
|
65
64
|
else:
|
|
66
|
-
|
|
67
|
-
|
|
65
|
+
raise AttributeError(
|
|
66
|
+
f"Parameter {config_parameter_name} could not be found in the node quantization config.")
|
|
68
67
|
|
|
69
68
|
def __repr__(self) -> str:
|
|
70
69
|
"""
|
|
@@ -97,37 +96,9 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
97
96
|
self.signedness = op_cfg.signedness
|
|
98
97
|
|
|
99
98
|
self.activation_quantization_params = {}
|
|
100
|
-
# TODO
|
|
99
|
+
# TODO: computed by compute_activation_bias_correction. Probably shouldn't be here.
|
|
101
100
|
self.activation_bias_correction_term = None
|
|
102
101
|
|
|
103
|
-
# TODO irena remove along with set_qc. Keeping for eq and hash to work without set_qc being called
|
|
104
|
-
self.activation_error_method = None
|
|
105
|
-
self.relu_bound_to_power_of_2 = None
|
|
106
|
-
self.activation_channel_equalization = None
|
|
107
|
-
self.input_scaling = None
|
|
108
|
-
self.min_threshold = None
|
|
109
|
-
self.l_p_value = None
|
|
110
|
-
self.shift_negative_activation_correction = None
|
|
111
|
-
self.z_threshold = None
|
|
112
|
-
self.shift_negative_ratio = None
|
|
113
|
-
self.shift_negative_threshold_recalculation = None
|
|
114
|
-
self.concat_threshold_update = None
|
|
115
|
-
|
|
116
|
-
def set_qc(self, qc: QuantizationConfig):
|
|
117
|
-
""" TODO irena: temporary keep all the attributes as before not to break all code at once.
|
|
118
|
-
Eventually all of them should be removed from here. """
|
|
119
|
-
self.activation_error_method = qc.activation_error_method
|
|
120
|
-
self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
|
|
121
|
-
self.activation_channel_equalization = qc.activation_channel_equalization
|
|
122
|
-
self.input_scaling = qc.input_scaling
|
|
123
|
-
self.min_threshold = qc.min_threshold
|
|
124
|
-
self.l_p_value = qc.l_p_value
|
|
125
|
-
self.shift_negative_activation_correction = qc.shift_negative_activation_correction
|
|
126
|
-
self.z_threshold = qc.z_threshold
|
|
127
|
-
self.shift_negative_ratio = qc.shift_negative_ratio
|
|
128
|
-
self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
|
|
129
|
-
self.concat_threshold_update = qc.concat_threshold_update
|
|
130
|
-
|
|
131
102
|
@property
|
|
132
103
|
def enable_activation_quantization(self):
|
|
133
104
|
return self.quant_mode == ActivationQuantizationMode.QUANT
|
|
@@ -165,32 +136,16 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
165
136
|
if not isinstance(other, NodeActivationQuantizationConfig):
|
|
166
137
|
return False # pragma: no cover
|
|
167
138
|
|
|
168
|
-
return self.
|
|
169
|
-
self.activation_quantization_method == other.activation_quantization_method and \
|
|
139
|
+
return self.activation_quantization_method == other.activation_quantization_method and \
|
|
170
140
|
self.activation_n_bits == other.activation_n_bits and \
|
|
171
141
|
self.quant_mode == other.quant_mode and \
|
|
172
|
-
self.
|
|
173
|
-
self.input_scaling == other.input_scaling and \
|
|
174
|
-
self.min_threshold == other.min_threshold and \
|
|
175
|
-
self.l_p_value == other.l_p_value and \
|
|
176
|
-
self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
|
|
177
|
-
self.z_threshold == other.z_threshold and \
|
|
178
|
-
self.shift_negative_ratio == other.shift_negative_ratio and \
|
|
179
|
-
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
|
142
|
+
self.signedness == other.signedness
|
|
180
143
|
|
|
181
144
|
def __hash__(self):
|
|
182
|
-
return hash((self.
|
|
183
|
-
self.activation_quantization_method,
|
|
145
|
+
return hash((self.activation_quantization_method,
|
|
184
146
|
self.activation_n_bits,
|
|
185
147
|
self.quant_mode,
|
|
186
|
-
self.
|
|
187
|
-
self.input_scaling,
|
|
188
|
-
self.min_threshold,
|
|
189
|
-
self.l_p_value,
|
|
190
|
-
self.shift_negative_activation_correction,
|
|
191
|
-
self.z_threshold,
|
|
192
|
-
self.shift_negative_ratio,
|
|
193
|
-
self.shift_negative_threshold_recalculation))
|
|
148
|
+
self.signedness))
|
|
194
149
|
|
|
195
150
|
|
|
196
151
|
class WeightsAttrQuantizationConfig:
|
|
@@ -211,16 +166,8 @@ class WeightsAttrQuantizationConfig:
|
|
|
211
166
|
self.weights_n_bits = weights_attr_cfg.weights_n_bits
|
|
212
167
|
self.weights_per_channel_threshold = weights_attr_cfg.weights_per_channel_threshold
|
|
213
168
|
self.enable_weights_quantization = weights_attr_cfg.enable_weights_quantization
|
|
214
|
-
self.weights_quantization_params = {}
|
|
215
169
|
|
|
216
|
-
|
|
217
|
-
self.weights_error_method = None
|
|
218
|
-
self.l_p_value = None
|
|
219
|
-
|
|
220
|
-
def set_qc(self, qc: QuantizationConfig):
|
|
221
|
-
# TODO irena: temporary keep the fields to not break everything at once.
|
|
222
|
-
self.weights_error_method = qc.weights_error_method
|
|
223
|
-
self.l_p_value = qc.l_p_value
|
|
170
|
+
self.weights_quantization_params = {}
|
|
224
171
|
|
|
225
172
|
def set_weights_quantization_param(self,
|
|
226
173
|
weights_params: dict):
|
|
@@ -252,18 +199,14 @@ class WeightsAttrQuantizationConfig:
|
|
|
252
199
|
self.weights_quantization_method == other.weights_quantization_method and \
|
|
253
200
|
self.weights_n_bits == other.weights_n_bits and \
|
|
254
201
|
self.weights_per_channel_threshold == other.weights_per_channel_threshold and \
|
|
255
|
-
self.enable_weights_quantization == other.enable_weights_quantization
|
|
256
|
-
self.weights_error_method == other.weights_error_method and \
|
|
257
|
-
self.l_p_value == other.l_p_value
|
|
202
|
+
self.enable_weights_quantization == other.enable_weights_quantization
|
|
258
203
|
|
|
259
204
|
def __hash__(self):
|
|
260
205
|
return hash((self.weights_channels_axis,
|
|
261
|
-
self.weights_error_method,
|
|
262
206
|
self.weights_quantization_method,
|
|
263
207
|
self.weights_n_bits,
|
|
264
208
|
self.weights_per_channel_threshold,
|
|
265
|
-
self.enable_weights_quantization
|
|
266
|
-
self.l_p_value))
|
|
209
|
+
self.enable_weights_quantization))
|
|
267
210
|
|
|
268
211
|
|
|
269
212
|
class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
@@ -330,16 +273,14 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
330
273
|
|
|
331
274
|
self.attributes_config_mapping[attr] = WeightsAttrQuantizationConfig(weights_attr_cfg=attr_cfg,
|
|
332
275
|
weights_channels_axis=weights_channels_axis)
|
|
333
|
-
# TODO
|
|
334
|
-
|
|
276
|
+
# TODO this is set by batch norm reconstruction substitution when folded batch norms are added back, to mark
|
|
277
|
+
# the nodes that the correction should be applied to (for some nodes it gets disabled) and BNs removed.
|
|
278
|
+
# The actual correction is only computed when it's applied in ptq, so it seems that both substitutions could
|
|
279
|
+
# be unified, with no info needing to be passed between them.
|
|
335
280
|
self.weights_second_moment_correction = None
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
# TODO irena: temporary keep the fields to not break everything at once.
|
|
340
|
-
self.min_threshold = qc.min_threshold
|
|
341
|
-
self.weights_second_moment_correction = qc.weights_second_moment_correction
|
|
342
|
-
self.weights_bias_correction = qc.weights_bias_correction
|
|
281
|
+
# TODO: the computed corrected bias is injected into the node config. Probably shouldn't be here. Also, it can be
|
|
282
|
+
# computed on the final config, instead of all candidates and then there is no need to save it at all.
|
|
283
|
+
self.bias_corrected = None
|
|
343
284
|
|
|
344
285
|
def get_attr_config(self, attr_name: 'WeightAttrT') -> WeightsAttrQuantizationConfig:
|
|
345
286
|
"""
|
|
@@ -476,8 +417,8 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
476
417
|
if hasattr(attr_cfg, config_parameter_name):
|
|
477
418
|
setattr(attr_cfg, config_parameter_name, config_parameter_value)
|
|
478
419
|
else:
|
|
479
|
-
|
|
480
|
-
|
|
420
|
+
raise AttributeError(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
|
|
421
|
+
f"weights attribute {attr_name}.")
|
|
481
422
|
else: # pragma: no cover
|
|
482
423
|
Logger.critical(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
|
|
483
424
|
|
|
@@ -494,10 +435,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
494
435
|
if not isinstance(other, NodeWeightsQuantizationConfig):
|
|
495
436
|
return False # pragma: no cover
|
|
496
437
|
|
|
497
|
-
return self.
|
|
498
|
-
self.simd_size == other.simd_size and \
|
|
499
|
-
self.weights_second_moment_correction == other.weights_second_moment_correction and \
|
|
500
|
-
self.weights_bias_correction == other.weights_bias_correction and \
|
|
438
|
+
return self.simd_size == other.simd_size and \
|
|
501
439
|
self.attributes_config_mapping.keys() == other.attributes_config_mapping.keys() and \
|
|
502
440
|
all([self.attributes_config_mapping[k] == other.attributes_config_mapping[k]
|
|
503
441
|
for k in self.attributes_config_mapping.keys()]) and \
|
|
@@ -506,9 +444,6 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
506
444
|
for k in self.pos_attributes_config_mapping.keys()])
|
|
507
445
|
|
|
508
446
|
def __hash__(self):
|
|
509
|
-
return hash((self.
|
|
510
|
-
self.simd_size,
|
|
511
|
-
self.weights_second_moment_correction,
|
|
512
|
-
self.weights_bias_correction,
|
|
447
|
+
return hash((self.simd_size,
|
|
513
448
|
frozenset(self.attributes_config_mapping),
|
|
514
449
|
frozenset(self.pos_attributes_config_mapping)))
|
|
@@ -90,7 +90,6 @@ class QuantizationConfig:
|
|
|
90
90
|
shift_negative_activation_correction: bool = True
|
|
91
91
|
activation_channel_equalization: bool = False
|
|
92
92
|
z_threshold: float = math.inf
|
|
93
|
-
min_threshold: float = MIN_THRESHOLD
|
|
94
93
|
l_p_value: int = 2
|
|
95
94
|
linear_collapsing: bool = True
|
|
96
95
|
residual_collapsing: bool = True
|