mct-nightly 2.3.0.20250408.522__py3-none-any.whl → 2.3.0.20250410.526__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {mct_nightly-2.3.0.20250408.522.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/METADATA +2 -1
  2. {mct_nightly-2.3.0.20250408.522.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/RECORD +33 -33
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/framework_implementation.py +11 -0
  5. model_compression_toolkit/core/common/fusion/fusing_info.py +4 -5
  6. model_compression_toolkit/core/common/graph/base_graph.py +2 -1
  7. model_compression_toolkit/core/common/graph/base_node.py +15 -19
  8. model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py +4 -2
  9. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +2 -2
  10. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +2 -1
  11. model_compression_toolkit/core/common/quantization/debug_config.py +2 -0
  12. model_compression_toolkit/core/common/quantization/node_quantization_config.py +31 -6
  13. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +5 -3
  14. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +1 -2
  15. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +3 -2
  16. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +3 -2
  17. model_compression_toolkit/core/keras/keras_implementation.py +14 -1
  18. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  19. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py +0 -1
  20. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  21. model_compression_toolkit/core/pytorch/pytorch_implementation.py +16 -3
  22. model_compression_toolkit/core/pytorch/reader/reader.py +28 -7
  23. model_compression_toolkit/core/pytorch/utils.py +2 -2
  24. model_compression_toolkit/gptq/keras/quantization_facade.py +6 -2
  25. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -2
  26. model_compression_toolkit/ptq/keras/quantization_facade.py +7 -2
  27. model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -2
  28. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py +2 -10
  29. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +2 -1
  30. model_compression_toolkit/verify_packages.py +0 -1
  31. {mct_nightly-2.3.0.20250408.522.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/WHEEL +0 -0
  32. {mct_nightly-2.3.0.20250408.522.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/licenses/LICENSE.md +0 -0
  33. {mct_nightly-2.3.0.20250408.522.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250408.522
3
+ Version: 2.3.0.20250410.526
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: Apache Software License
@@ -22,6 +22,7 @@ Requires-Dist: scipy
22
22
  Requires-Dist: protobuf
23
23
  Requires-Dist: mct-quantizers-nightly
24
24
  Requires-Dist: pydantic<2.0
25
+ Requires-Dist: sony-custom-layers-dev==0.4.0.dev6
25
26
  Dynamic: classifier
26
27
  Dynamic: description
27
28
  Dynamic: description-content-type
@@ -1,10 +1,10 @@
1
- mct_nightly-2.3.0.20250408.522.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=Nxtw-bh_Op7j74mOVtyUSxB8W47zj0_P8k7LJLmHpwU,1557
1
+ mct_nightly-2.3.0.20250410.526.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=hH7K5n9ZDkBMIcujI3umeAh3pUxoyxZu2pWu83zoGgk,1557
3
3
  model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
6
6
  model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
7
- model_compression_toolkit/verify_packages.py,sha256=TlS-K1EP-QsghqWUW7SDPkAJiUf7ryw4tvhFDe6rCUk,1405
7
+ model_compression_toolkit/verify_packages.py,sha256=l0neIRr8q_QwxmuiTI4vyCMDISDedK0EihjEQUe66tE,1319
8
8
  model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZG1FUxxJlAV8L8,2096
9
9
  model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
10
10
  model_compression_toolkit/core/graph_prep_runner.py,sha256=C6eUTd-fcgxk0LUbt51gFZwmyDDDEB8-9Q4kr9ujYvI,11555
@@ -12,7 +12,7 @@ model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8
12
12
  model_compression_toolkit/core/runner.py,sha256=_r6cieb7Ur2BeHQK5XxTZHogjyA0utybvIVbH06CBHY,13056
13
13
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
14
14
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
15
- model_compression_toolkit/core/common/framework_implementation.py,sha256=s3yiqnbWkwfnAB1sSal_KAuqVg27rLhAJ2O8LHUbSHE,22494
15
+ model_compression_toolkit/core/common/framework_implementation.py,sha256=L88uv_sfYM_56FSmxXP--emjv01_lk7IPqOI7QBZEt0,22939
16
16
  model_compression_toolkit/core/common/framework_info.py,sha256=RWeZfQOPiBroU2v4AeZoquVunNtZ4UORjOr2aRAPu8o,6279
17
17
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
18
18
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
@@ -31,11 +31,11 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
31
31
  model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
32
32
  model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
33
33
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
34
- model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=LfzVS9B6r2KCwf8rcCUdepEQhWkt287SoXfwoudpfFo,15496
34
+ model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=W8qZejLwbm-lkvNF3GepNL3ypO10vFRxOxbq-o_rt_I,15479
35
35
  model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
36
36
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
37
- model_compression_toolkit/core/common/graph/base_graph.py,sha256=hedhjVula5rPv0vN0CLBDtPYM8SH3cM6FAL62aFfF7U,41767
38
- model_compression_toolkit/core/common/graph/base_node.py,sha256=CJu8_r80MGVnYmlAUGOGKGRsD9xShMyaRNb3VMeRC0s,34523
37
+ model_compression_toolkit/core/common/graph/base_graph.py,sha256=3OhaMHW01okwFY4mSy0ERFCJk8AZPDs8bCKAmjvmJEI,41893
38
+ model_compression_toolkit/core/common/graph/base_node.py,sha256=Yl6GdjnP_Rt9w1lQUm00CJI0JUAffQF7wr6mur_YfbA,34124
39
39
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
40
40
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
41
41
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
@@ -51,7 +51,7 @@ model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=
51
51
  model_compression_toolkit/core/common/hessian/__init__.py,sha256=E7LK3K_1AwMCQokanNc1JODMwUKNOKmwXQiGQ7GO10I,1033
52
52
  model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=8NDC_WLe3ZnY_v3e_Vz_lseF22lrbvhFmArihpeWfuI,14291
53
53
  model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
54
- model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=NHC2WQcTK4MLOuKlmELR8XoDTt_h8KwvpNy2o94azrI,4238
54
+ model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=wqKPfAJgXiV7zD2DufbOU5HcOLi-44Fv9PWdVgFMGaw,4354
55
55
  model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=ZNdwDzW7QF2A-w1Ye4P2xn5erTQnoTXk5z_b17HDGH4,3391
56
56
  model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
57
57
  model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
@@ -70,12 +70,12 @@ model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantizati
70
70
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=2Pp4hiYvGW2I9YhloDxQNT0sZRg3TDp9CXObloF8IFU,4971
71
71
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=GGrp7QngrWvWtPN8cQnL4IEbNwcVRc-hAUqfnxjjMmk,5998
72
72
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=NBzzhkVI407S9cIiw7t7nsP3MrkOdSnweKQdPBXb8to,38180
73
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=gsigifJ-ykWNafF4t7UMEC_-nd6YPERAk1_z0kT-Y88,27172
73
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4bkM8pYKvk18cxHbx973Dz6qWrNT0MRm44cuk__qVaI,27297
74
74
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
75
75
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=fk7PWiZ6Na5O_Z_dymk_UfDCTqW_X_4EROU7DZknQnc,9444
76
76
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
77
77
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
78
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=MP4Q5lThvEIhfa1iBajQQM3nCUNgK-2yseqQQ8Rgiog,40624
78
+ model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=cjFnpDvxZDE4K2sgt26DhosA2XqhxHDs0eW5Qe7AwAQ,40668
79
79
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
80
80
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
81
81
  model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=TaK5NqVdmygsHw9_x5JsJ-BPvlbKA9cRyTno1R8gbnU,7269
@@ -104,15 +104,15 @@ model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM
104
104
  model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=0HA3CIZW-ZrA55ra-yJXRvAYnoR8i1SjpbnMDKcWYNQ,12819
105
105
  model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=lyWPvnoX8BmulhLKR20r5gT2_Yan7P40d8EcgDhErPk,4905
106
106
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
107
- model_compression_toolkit/core/common/quantization/debug_config.py,sha256=zJP2W9apUPX9RstpPWWK71wr9xJsg7j-s7lGV4_bQdc,1510
107
+ model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
108
108
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=n2A8pO7_DMMae4o69U0I00iW6mzeRlRfKHDxlQUBBuI,7204
109
- model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=tAVQlDp7Zt9xncFFU39JCjDBarACRwz-Do_-6pUGMB0,28530
109
+ model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=gL3XUm85FBLvtF60jmWkPxITOBw7cs66scNtC7QHW-M,29471
110
110
  model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=UkSVW7d1OF_Px9gAjsqqK65aYhIBFWaBO-_IH6_AFfg,4403
111
111
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=HfBkSiRTOf9mNF-TNQHTCCs3xSg66F20no0O6vl5v1Y,2154
112
112
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=7eG7dl1TcbdnHwgmvyjarxLs0o6Lw_9VAjXAm4rsiBk,3791
113
113
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
114
114
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
115
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=u0pVJawyUTgatn2L8qMNBac2Cut3HSPZSytBGDuBB0k,21341
115
+ model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=3jyOBaRFwoZQsiyB1nF7ayox1XSo6jf2fDc9V30wKkg,21431
116
116
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
117
117
  model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=_m-XkEMJMHf0gYwVIXAoHVjdRa2NXt_gYdwBlw76ZR8,24031
118
118
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=RL-PklAjGyC-26anSt8fU07a6pB_LBQFQy9o4e9giN0,8739
@@ -132,20 +132,20 @@ model_compression_toolkit/core/common/statistics_correction/__init__.py,sha256=s
132
132
  model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py,sha256=b05ZwQ2CwG0Q-yqs9A1uHfP8o17aGEZFCeJNP1p4IWk,4450
133
133
  model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py,sha256=b5clhUWGoDaQLn2pDCeYkV0FomVebcKS8pMXtQTTzIg,4679
134
134
  model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py,sha256=C_nwhhitTd1pCto0nHZPn3fjIMOeDD7VIciumTR3s6k,5641
135
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=ov9-WYktWKqRquibwyARR81QVT9TfPWAoTTfnKOQSd0,9273
135
+ model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py,sha256=F8kK8yoYCGeTdXUsHGcM3T2tRdjSlcWg3UToGtovNOs,9196
136
136
  model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py,sha256=LaGhYES7HgIDf9Bi2KAG_mBzAWuum0J6AGmAFPC8wwo,10478
137
137
  model_compression_toolkit/core/common/statistics_correction/statistics_correction.py,sha256=E0ZA4edimJwpHh9twI5gafcoJ9fX5F1JX2QUOkUOKEw,6250
138
138
  model_compression_toolkit/core/common/substitutions/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
139
139
  model_compression_toolkit/core/common/substitutions/apply_substitutions.py,sha256=k-bifmakHIYZeZS-4T1QpZ1Et6AwAijMRgAKs7hmMKc,1390
140
140
  model_compression_toolkit/core/common/substitutions/batchnorm_folding.py,sha256=wLlTT7sqUffKHwOrMG2VV5SktQkkP54l8taW1Fq0mh0,13392
141
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=1389z4NbTKIHYGr-FB-fV1YP1Gcfta0tOu60DwfNVlI,8452
141
+ model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py,sha256=kz1Xg2OMNXyRbCW3K-wfZpbv6jmLShJjHYUoziOUNv4,8496
142
142
  model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py,sha256=dWJpVfomF4Ppeeor3VzS23TXHyBm85QI7snyLOYP_ko,9972
143
143
  model_compression_toolkit/core/common/substitutions/linear_collapsing.py,sha256=iEtzbWCDXP6EDkTZCtREQ0rpMxhQ2kM9zlcP_0KLq9I,12367
144
144
  model_compression_toolkit/core/common/substitutions/linear_collapsing_substitution.py,sha256=uoauhmncQqUBNvD-qCLIXsIbl_IzrbxSKdxiMig-5W4,2406
145
145
  model_compression_toolkit/core/common/substitutions/remove_identity.py,sha256=TKU1TIU52UIkVnl0EZvWnDhLV9nIVZ4hqi-w1i4NXMk,2637
146
146
  model_compression_toolkit/core/common/substitutions/residual_collapsing.py,sha256=N82mso5j3EJQlKt9EMHjjEJ67FmdGQeCfN8U5grOFXo,4830
147
147
  model_compression_toolkit/core/common/substitutions/scale_equalization.py,sha256=p57u25qdW2pimxzGwgMXEBV4S-LzXuTVAlIM7830WfU,10966
148
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=oiiN16OqDrax4FPP5VeyTz0rhb0-eZJACKznTBlKkio,30013
148
+ model_compression_toolkit/core/common/substitutions/shift_negative_activation.py,sha256=1utreR5CkJYwaJS1LOCQi_EYkedsnxXzyJCnJ-ZeqQ0,30057
149
149
  model_compression_toolkit/core/common/substitutions/softmax_shift.py,sha256=R-0ZqhYAuZLEFWHvB2UTPm52L6gWHGdRdEnwGxKSeGI,2625
150
150
  model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py,sha256=w43dRmaG96a8SNECgghxoFCTSoZ-vUb33dXGm2PbomE,4251
151
151
  model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=gt07lXRUvYunJKiwv_w20zfXhcplSW4oT2C1dqiNNXc,4719
@@ -158,7 +158,7 @@ model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiO
158
158
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
159
159
  model_compression_toolkit/core/keras/data_util.py,sha256=jm54o-SlI1DJ-sEvRuX9OyLN68tEt0VxcqrdIjR98Ag,8366
160
160
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=IGEHKH3IcmpRfyHuEBJTpEXu2-TDFfqQzpm8kHuj8QY,4974
161
- model_compression_toolkit/core/keras/keras_implementation.py,sha256=e9cVe_TJ_6h4OPgnAVX9T9wgvNDdGh5y_4Hprxa6Ths,32104
161
+ model_compression_toolkit/core/keras/keras_implementation.py,sha256=_15BrSGTRSSp_8ayuo2x-hdKanew1xuIPSumP46IGSA,32545
162
162
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
163
163
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
164
164
  model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=XBCmUrHy_fNQCfSjnXCpwuEtc7cda4hXySuiIzhFGqc,5696
@@ -198,7 +198,7 @@ model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculato
198
198
  model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py,sha256=1o7X9GXSfpEmuB5ee2AaBQ2sN2xzX4-smbrq_0qOGRU,4454
199
199
  model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py,sha256=Rl6NNGkHMV0ioEM5bbM4XX7yHDqG6mMp4ifN2VQBDxE,12168
200
200
  model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
201
- model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aW8wR13fK6P6xzbU9XGU60IO1yYzXSo_Hk4qeq486kg,5137
201
+ model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=WFwPtCcXR3qY86OML_jyzasvdd2DGhy4-GveAGpDOt0,5075
202
202
  model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=38Lvwux9L35oT6muck6_FH7nDdH2N8_kuGDMj4-QNpE,6647
203
203
  model_compression_toolkit/core/keras/pruning/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
204
204
  model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py,sha256=EJkblZ4OAjI5l29GKsUraam5Jn58Sogld47_rFFyr3k,12777
@@ -224,10 +224,10 @@ model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3Kqh
224
224
  model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
225
225
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=NLdmiig5a2EBxutJeDHjp8px4g_2EKt3zmntmK-NrT4,4309
226
226
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
227
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=QBCKYimTbHGFmXGz84Ioni5C9qKntp9FMEBLMUrIKkY,30771
227
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=c_QFo4e7t6b21CDakGhjVpqy5aXFxxqkdJ-s54HEOfs,31207
228
228
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
229
229
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=aIHl-dTAC4ISnWSKLD99c-1W3827vfRGyLjMBib-l3s,5618
230
- model_compression_toolkit/core/pytorch/utils.py,sha256=7VbgcLwtQvdEEc_AJgSOQ3U3KRKCICFPaBirN1fIQxg,3940
230
+ model_compression_toolkit/core/pytorch/utils.py,sha256=xNVE7YMtHupLEimIJcxmfcMGM4XKB9I1v0-K8lDeLB8,3936
231
231
  model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
232
232
  model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=bwppTPRs6gL96nm7qPiKrNcBj4Krr0yEsOWjRF0aXmQ,2339
233
233
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
@@ -257,7 +257,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_
257
257
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
258
258
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
259
259
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
260
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py,sha256=WG7MyYTP5JhMZHYxj4PB-7TTuvUDjFQScG4_Ce1mQDY,12476
260
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py,sha256=D1hxN3pZ5-_FLJSS30ZJUo-v8TqUWFcMjhMijFa9aSo,12407
261
261
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=3WCLvPyx7tVkM0rwYhYq-gntCzW9R_DcImR1ucKlPac,10772
262
262
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/softmax_shift.py,sha256=05lV4pIL3hJkZl4JQPV4wk_EFD0eYLG5b8cdzvZk4P8,1588
263
263
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transform_function_call_method.py,sha256=EC9Dvp-_UlpDWnipnf8ds65wh_Y-T8pXAFIwRScWpiY,2044
@@ -268,7 +268,7 @@ model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calcula
268
268
  model_compression_toolkit/core/pytorch/hessian/hessian_scores_calculator_pytorch.py,sha256=8f_XlM8ZFVQPNGr1iECr1hv8QusYDrNU_vTkLQZE9RU,2477
269
269
  model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py,sha256=UzWxWDbr8koKZatEcPn8RCb0Zjm_7fKTvIGb98sp18k,8487
270
270
  model_compression_toolkit/core/pytorch/mixed_precision/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
271
- model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=aEjqqj96iK_G_ebXEiJ8kcHLJWs9NFUevSJTipLux1s,4815
271
+ model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py,sha256=mcY_KOQgABIqGIMh0x6mNxaKp7SFNbkEIYavR2X7SQ4,4754
272
272
  model_compression_toolkit/core/pytorch/mixed_precision/configurable_weights_quantizer.py,sha256=zp1Xp75IDf9LN5YGO2UzeDbms_6ICQ_pSE1ORQr-SA8,6281
273
273
  model_compression_toolkit/core/pytorch/pruning/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
274
274
  model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py,sha256=VfEEVwWEXKpVlZFnr7N6mvEjcpq85ROLg05ZvXfD1Pg,14764
@@ -278,7 +278,7 @@ model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtN
278
278
  model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
279
279
  model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=ZASzWbYYojFYIx-ynqMTkg6mCpTrJg2oWYT-xXki4Mw,19763
280
280
  model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=7XNc7-l1MZPJGcOESvtAwfIMxrU6kvt3YjF5B7qOqK4,1048
281
- model_compression_toolkit/core/pytorch/reader/reader.py,sha256=Me6nqJpmQBg13dXYiUsmfYr148BYySBZqxHRDba5Tuk,6228
281
+ model_compression_toolkit/core/pytorch/reader/reader.py,sha256=OKlSkGXI-5fKULPEcBnGM6dxwUlWGQEq7ZWdUIhovMU,7440
282
282
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
283
283
  model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py,sha256=VgU24J3jf7QComHH7jonOXSkg6mO4TOch3uFkOthZvM,3261
284
284
  model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py,sha256=N-9QaEaQYUsIoya9Lc0ZDoMZ0fkiT2gFpOd4zXHKP34,3096
@@ -366,7 +366,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
366
366
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
367
367
  model_compression_toolkit/gptq/keras/gptq_training.py,sha256=km9tcuugOkRvprGXQZrsq_GPtA3-7Du_-rnbR_Gyups,23228
368
368
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=zwoeHX67nJJ5-zYLjzvMXS9TLsy9BsizARbZiDVjVSA,4473
369
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=onQSR1YPjQ6IZdqzeeqFMs3IeBT-nWLbI0yXuOkdpKs,18827
369
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=-goXDz-ACJ4QQH55XTA5n4eGVRXcYAWtqJ4dq6tWq1o,18927
370
370
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
371
371
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
372
372
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -382,7 +382,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phI
382
382
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
383
383
  model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=WtehnyiYXdUXf8-uNpV0mdsalF7YF7eKnL7tcFrzZoE,19549
384
384
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
385
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=ciBrdTZqTNFw-5RleEAM6o5GJq5zNhym2GmAmf6U0_I,17179
385
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=V_T3EbFiHO3SkN0kvppsEB9IFW8Q68_GMtUf3xjHnXU,17445
386
386
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
387
387
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
388
388
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -401,9 +401,9 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=FmUQvT0T247Xa
401
401
  model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
402
402
  model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
403
403
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
404
- model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=f8sa46eUNHmeaVs3huhZv14DHm5j1X-VInCYdI7nXAY,11567
404
+ model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=QAQ7Pegk26fARDQg2ZNzcYY8aYKmb2hnUY8FiAdcuy0,11824
405
405
  model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
406
- model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=p5FwojKaybYdsOUVI7qBNa7R8Nge3EXdu38Jf2jHr84,10021
406
+ model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=Du3CBhp7HXam-GSkv9VPcBoaIBydKjdXsnhFjsemT3E,10282
407
407
  model_compression_toolkit/qat/__init__.py,sha256=AaC4KBha4jDW_tyg2SOxZaKh_idIz0gZtDK3_zxs64E,1241
408
408
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
409
409
  model_compression_toolkit/qat/common/qat_config.py,sha256=xtfVSoyELGXynHNrw86dB9FU3Inu0zwehc3wLrh7JvY,2918
@@ -441,8 +441,8 @@ model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=4CGpW
441
441
  model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=yg0ZrsaqaS69lmDvxRrz636CRARzx_eZbokTMVHNEXc,4555
442
442
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
443
443
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
444
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=mxc3DBbUi-HDFgSx8Nmnyxr8SIdbx8lmtcRMsQl1BLE,7578
445
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=8spnpqxVUv8WF9-PTukOLvJAFiNi01wNowUVIDqSj5I,6321
444
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=Ehwpd_sL6zxmJFpJugOdN9uNxNX05nijvOCilNfHnFs,7162
445
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=RY7STxTqYG1umFJEbWFRuGXk32eGi1iYuDFKgyVFo-8,6408
446
446
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
447
447
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
448
448
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=UKzckLYLdBcFAptyKnVMwpPpfRkmF0SK1Kl0g0eGjQA,9710
@@ -527,7 +527,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
527
527
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
528
528
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
529
529
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
530
- mct_nightly-2.3.0.20250408.522.dist-info/METADATA,sha256=2MJ3qHYwl3E_RVRhIodPb36F7-YSbYHo-rcmjqOYblE,27098
531
- mct_nightly-2.3.0.20250408.522.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
532
- mct_nightly-2.3.0.20250408.522.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
533
- mct_nightly-2.3.0.20250408.522.dist-info/RECORD,,
530
+ mct_nightly-2.3.0.20250410.526.dist-info/METADATA,sha256=lnLhgKNdIybbtKdxLN61inSjhX0CQulfk_9gDUF387o,27148
531
+ mct_nightly-2.3.0.20250410.526.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
532
+ mct_nightly-2.3.0.20250410.526.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
533
+ mct_nightly-2.3.0.20250410.526.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250408.000522"
30
+ __version__ = "2.3.0.20250410.000526"
@@ -93,6 +93,17 @@ class FrameworkImplementation(ABC):
93
93
  raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
94
94
  f'framework\'s to_tensor method.') # pragma: no cover
95
95
 
96
+ @abstractmethod
97
+ def is_tuple_of_tensors(self, obj: Any) -> bool:
98
+ """
99
+ Check if a given object is a tuple of tensors
100
+ :param obj: Object to check its type
101
+ :return: True if obj is a tuple of tensors, False otherwise
102
+ """
103
+ raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
104
+ f'framework\'s is_tuple_of_tensors method.') # pragma: no cover
105
+
106
+
96
107
  @abstractmethod
97
108
  def model_reader(self,
98
109
  model: Any,
@@ -150,7 +150,6 @@ class FusingInfo:
150
150
  """
151
151
  return self.fusing_data
152
152
 
153
-
154
153
  @staticmethod
155
154
  def generate_fused_op_id(nodes: List['BaseNode']) -> str:
156
155
  """
@@ -166,7 +165,7 @@ class FusingInfo:
166
165
  id = FUSED_OP_ID_PREFIX + '_'.join([node.name for node in nodes])
167
166
  return id
168
167
 
169
- def validate(self, graph) -> None:
168
+ def validate(self, graph: 'Graph') -> None:
170
169
  """
171
170
  Validate that the fusing information is consistent with the given graph and generation logic.
172
171
 
@@ -267,7 +266,7 @@ class FusingInfoGenerator:
267
266
  def __init__(self, fusing_patterns):
268
267
  self._fusing_patterns = fusing_patterns
269
268
 
270
- def generate_fusing_info(self, graph) -> FusingInfo:
269
+ def generate_fusing_info(self, graph: 'Graph') -> FusingInfo:
271
270
  """
272
271
  Generate fusing information based on the graph and fusing patterns.
273
272
 
@@ -289,7 +288,7 @@ class FusingInfoGenerator:
289
288
  return FusingInfo(fusing_patterns=self._fusing_patterns)
290
289
 
291
290
  # Find max fusion
292
- max_layers_fusing = 0 if len(self._fusing_patterns) == 0 else max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns])
291
+ max_layers_fusing = max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns])
293
292
 
294
293
  # Travel along the graph to find layers for fusing
295
294
  nodes = graph.get_topo_sorted_nodes()
@@ -331,7 +330,7 @@ def get_valid_fusing_patterns_for_node(fusing_patterns: List[List[Any]],
331
330
  Returns only the fusing patterns where a specific layer (at index idx) matches the given node — either by type or filter params.
332
331
 
333
332
  Args:
334
- fusing_patterns: supported fusings
333
+ fusing_patterns: supported fusing patterns
335
334
  node: node to decide if it can be a part of fusion
336
335
  idx: index of layer in the fusion
337
336
 
@@ -33,6 +33,7 @@ from model_compression_toolkit.core.common.collectors.statistics_collector impor
33
33
  from model_compression_toolkit.core.common.collectors.statistics_collector import scale_statistics, shift_statistics
34
34
  from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
35
35
  from model_compression_toolkit.core.common.user_info import UserInformation
36
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
36
37
  from model_compression_toolkit.logger import Logger
37
38
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
38
39
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
@@ -920,7 +921,7 @@ class Graph(nx.MultiDiGraph, GraphSearches):
920
921
  nodes_to_disable = [node for nodes in self.fusing_info.get_all_fused_operations().values() for node in nodes[:-1]]
921
922
  for node in nodes_to_disable:
922
923
  for qc in node.candidates_quantization_cfg:
923
- qc.activation_quantization_cfg.enable_activation_quantization = False
924
+ qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.FLN_QUANT
924
925
 
925
926
  def validate(self):
926
927
  """
@@ -20,7 +20,8 @@ import numpy as np
20
20
 
21
21
  from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
22
22
  ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
23
- from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
23
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
24
+ ActivationQuantizationMode
24
25
  from model_compression_toolkit.logger import Logger
25
26
  from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
26
27
  OpQuantizationConfig
@@ -116,33 +117,28 @@ class BaseNode:
116
117
  """
117
118
  return any(isinstance(key, int) for key in self.weights.keys())
118
119
 
120
+ def _is_single_quant_mode(self, q_mode: ActivationQuantizationMode) -> bool:
121
+ """ Check whether all candidates have the same unique quantization mode, and if it is 'q_mode'. """
122
+
123
+ if self.final_activation_quantization_cfg:
124
+ # if we have a final configuration, only its quantization mode matters, so compare it against q_mode.
125
+ return self.final_activation_quantization_cfg.quant_mode == q_mode
126
+
127
+ q_modes = {qc.activation_quantization_cfg.quant_mode for qc in self.candidates_quantization_cfg}
128
+ assert len(q_modes) == 1
129
+ return q_modes.pop() == q_mode
130
+
119
131
  def is_activation_quantization_enabled(self) -> bool:
120
132
  """
121
-
122
133
  Returns: Whether node activation quantization is enabled or not.
123
-
124
134
  """
125
- if self.final_activation_quantization_cfg:
126
- # if we have a final configuration, then we only care to check if it enables activation quantization
127
- return self.final_activation_quantization_cfg.enable_activation_quantization
128
-
129
- for qc in self.candidates_quantization_cfg:
130
- assert self.candidates_quantization_cfg[0].activation_quantization_cfg.enable_activation_quantization == \
131
- qc.activation_quantization_cfg.enable_activation_quantization
132
- return self.candidates_quantization_cfg[0].activation_quantization_cfg.enable_activation_quantization
135
+ return self._is_single_quant_mode(ActivationQuantizationMode.QUANT)
133
136
 
134
137
  def is_quantization_preserving(self) -> bool:
135
138
  """
136
139
  Returns: Whether node activation quantization information is preserved from its inputs.
137
140
  """
138
- if self.final_activation_quantization_cfg:
139
- # if we have a final configuration, then we only care to check if it enables activation quantization.
140
- return self.final_activation_quantization_cfg.quantization_preserving
141
-
142
- for qc in self.candidates_quantization_cfg:
143
- assert self.candidates_quantization_cfg[0].activation_quantization_cfg.quantization_preserving == \
144
- qc.activation_quantization_cfg.quantization_preserving
145
- return self.candidates_quantization_cfg[0].activation_quantization_cfg.quantization_preserving
141
+ return self._is_single_quant_mode(ActivationQuantizationMode.PRESERVE_QUANT)
146
142
 
147
143
  def is_weights_quantization_enabled(self, attr_name: str) -> bool:
148
144
  """
@@ -72,8 +72,7 @@ class HessianScoresCalculator(ABC):
72
72
  """
73
73
  raise NotImplemented(f'{self.__class__.__name__} have to implement compute method.') # pragma: no cover
74
74
 
75
- @staticmethod
76
- def unfold_tensors_list(tensors_to_unfold: Any) -> List[Any]:
75
+ def unfold_tensors_list(self, tensors_to_unfold: Any) -> List[Any]:
77
76
  """
78
77
  Unfold (flatten) a nested tensors list.
79
78
  Given a mixed list of single tensors and nested tensor lists,
@@ -85,6 +84,9 @@ class HessianScoresCalculator(ABC):
85
84
  """
86
85
  unfold_tensors = []
87
86
  for tensor in tensors_to_unfold:
87
+ if self.fw_impl.is_tuple_of_tensors(tensor):
88
+ tensor = list(tensor) # converts named tuple to list
89
+
88
90
  if isinstance(tensor, List):
89
91
  unfold_tensors += tensor
90
92
  else:
@@ -31,7 +31,7 @@ from model_compression_toolkit.core.common.graph.virtual_activation_weights_node
31
31
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
32
32
  RUTarget, ResourceUtilization
33
33
  from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeWeightsQuantizationConfig, \
34
- NodeActivationQuantizationConfig, BaseNodeQuantizationConfig
34
+ NodeActivationQuantizationConfig, BaseNodeQuantizationConfig, ActivationQuantizationMode
35
35
  from model_compression_toolkit.core.common.substitutions.virtual_activation_weights_composition import \
36
36
  get_input_activation_if_composable
37
37
 
@@ -710,7 +710,7 @@ class ResourceUtilizationCalculator:
710
710
  """
711
711
  if act_qc:
712
712
  assert bitwidth_mode == BitwidthMode.QCustom
713
- return act_qc.activation_n_bits if act_qc.enable_activation_quantization else FLOAT_BITWIDTH
713
+ return act_qc.activation_n_bits if act_qc.quant_mode == ActivationQuantizationMode.QUANT else FLOAT_BITWIDTH
714
714
 
715
715
  if bitwidth_mode == BitwidthMode.Float or not (n.is_activation_quantization_enabled() or
716
716
  n.is_quantization_preserving()):
@@ -20,6 +20,7 @@ from typing import Callable, Any, List, Tuple
20
20
  from model_compression_toolkit.constants import AXIS
21
21
  from model_compression_toolkit.core import FrameworkInfo, MixedPrecisionQuantizationConfig
22
22
  from model_compression_toolkit.core.common import Graph, BaseNode
23
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
23
24
  from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
24
25
  from model_compression_toolkit.core.common.similarity_analyzer import compute_kl_divergence
25
26
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
@@ -207,7 +208,7 @@ class SensitivityEvaluation:
207
208
  if self.disable_activation_for_metric:
208
209
  for n in evaluation_graph.get_topo_sorted_nodes():
209
210
  for c in n.candidates_quantization_cfg:
210
- c.activation_quantization_cfg.enable_activation_quantization = False
211
+ c.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
211
212
 
212
213
  model_mp, _, conf_node2layers = self.fw_impl.model_builder(evaluation_graph,
213
214
  mode=ModelBuilderMode.MIXEDPRECISION,
@@ -29,8 +29,10 @@ class DebugConfig:
29
29
  enabled) or not. Can be used to pinpoint problematic layers in the quantization process.
30
30
  network_editor (List[EditRule]): A list of rules and actions to edit the network for quantization.
31
31
  simulate_scheduler (bool): Simulate scheduler behavior to compute operators' order and cuts.
32
+ bypass (bool): A flag to enable MCT bypass, which skips MCT runner and returns the input model unchanged.
32
33
  """
33
34
 
34
35
  analyze_similarity: bool = False
35
36
  network_editor: List[EditRule] = field(default_factory=list)
36
37
  simulate_scheduler: bool = False
38
+ bypass: bool = False
@@ -15,7 +15,7 @@
15
15
 
16
16
 
17
17
  from typing import Callable, Any, List, Tuple, Union, Dict, TYPE_CHECKING
18
-
18
+ from enum import Enum, auto
19
19
  import numpy as np
20
20
 
21
21
  from model_compression_toolkit.core.common.quantization.quantization_fn_selection import get_weights_quantization_fn
@@ -40,6 +40,14 @@ if TYPE_CHECKING:
40
40
  ##########################################
41
41
 
42
42
 
43
+ class ActivationQuantizationMode(Enum):
44
+ """ An enum defining the output activation quantization mode of a node. """
45
+ QUANT = auto()
46
+ FLN_QUANT = auto()
47
+ PRESERVE_QUANT = auto()
48
+ NO_QUANT = auto()
49
+
50
+
43
51
  class BaseNodeQuantizationConfig(object):
44
52
  """
45
53
  Base class for node quantization configuration
@@ -100,8 +108,14 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
100
108
  self.activation_n_bits = op_cfg.activation_n_bits
101
109
  self.relu_bound_to_power_of_2 = qc.relu_bound_to_power_of_2
102
110
  self.activation_bias_correction_term = None
103
- self.enable_activation_quantization = op_cfg.enable_activation_quantization
104
- self.quantization_preserving = op_cfg.quantization_preserving
111
+ if op_cfg.enable_activation_quantization and op_cfg.quantization_preserving:
112
+ raise ValueError("An OpQuantizationConfig can't have both enable_activation_quantization and quantization_preserving enabled.")
113
+ if op_cfg.enable_activation_quantization:
114
+ self.quant_mode = ActivationQuantizationMode.QUANT
115
+ elif op_cfg.quantization_preserving:
116
+ self.quant_mode = ActivationQuantizationMode.PRESERVE_QUANT
117
+ else:
118
+ self.quant_mode = ActivationQuantizationMode.NO_QUANT
105
119
  self.signedness = op_cfg.signedness
106
120
  self.activation_channel_equalization = qc.activation_channel_equalization
107
121
  self.input_scaling = qc.input_scaling
@@ -113,6 +127,17 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
113
127
  self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
114
128
  self.concat_threshold_update = qc.concat_threshold_update
115
129
 
130
+ @property
131
+ def enable_activation_quantization(self):
132
+ return self.quant_mode == ActivationQuantizationMode.QUANT
133
+
134
+ @property
135
+ def quantization_preserving(self):
136
+ return self.quant_mode == ActivationQuantizationMode.PRESERVE_QUANT
137
+
138
+ def fln_quantization(self):
139
+ return self.quant_mode == ActivationQuantizationMode.FLN_QUANT
140
+
116
141
  def quantize_node_output(self,
117
142
  tensors: Any) -> Any:
118
143
  """
@@ -181,7 +206,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
181
206
  activation_params: Dictionary that contains activation quantization params.
182
207
 
183
208
  """
184
- assert self.enable_activation_quantization
209
+ assert self.quant_mode == ActivationQuantizationMode.QUANT
185
210
  for param_name, param_value in activation_params.items():
186
211
  self.activation_quantization_params[param_name] = param_value
187
212
 
@@ -203,7 +228,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
203
228
  self.activation_error_method == other.activation_error_method and \
204
229
  self.activation_quantization_method == other.activation_quantization_method and \
205
230
  self.activation_n_bits == other.activation_n_bits and \
206
- self.enable_activation_quantization == other.enable_activation_quantization and \
231
+ self.quant_mode == other.quant_mode and \
207
232
  self.activation_channel_equalization == other.activation_channel_equalization and \
208
233
  self.input_scaling == other.input_scaling and \
209
234
  self.min_threshold == other.min_threshold and \
@@ -219,7 +244,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
219
244
  self.activation_error_method,
220
245
  self.activation_quantization_method,
221
246
  self.activation_n_bits,
222
- self.enable_activation_quantization,
247
+ self.quant_mode,
223
248
  self.activation_channel_equalization,
224
249
  self.input_scaling,
225
250
  self.min_threshold,
@@ -25,7 +25,8 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
25
  from model_compression_toolkit.core.common.graph.base_graph import Graph
26
26
  from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
27
27
  CandidateNodeQuantizationConfig
28
- from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig
28
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import NodeActivationQuantizationConfig, \
29
+ ActivationQuantizationMode
29
30
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
30
31
  QuantizationErrorMethod
31
32
  from model_compression_toolkit.core.common.quantization.quantization_params_fn_selection import \
@@ -191,8 +192,9 @@ def set_quantization_configs_to_node(node: BaseNode,
191
192
  node.sort_node_candidates(fw_info)
192
193
 
193
194
  for candidate_qc in node.candidates_quantization_cfg:
194
- candidate_qc.activation_quantization_cfg.enable_activation_quantization = \
195
- candidate_qc.activation_quantization_cfg.enable_activation_quantization and node.get_has_activation()
195
+ if candidate_qc.activation_quantization_cfg.quant_mode == ActivationQuantizationMode.QUANT and \
196
+ not node.get_has_activation():
197
+ candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
196
198
 
197
199
 
198
200
  def create_node_activation_qc(qc: QuantizationConfig,
@@ -45,8 +45,7 @@ def get_previous_node_with_activation_quantization(linear_node: BaseNode,
45
45
  activation_quantization_config = prev_node.final_activation_quantization_cfg
46
46
 
47
47
  # Search for node with activation quantization
48
- if (activation_quantization_config.enable_activation_quantization and
49
- not activation_quantization_config.quantization_preserving):
48
+ if activation_quantization_config.enable_activation_quantization:
50
49
  return prev_node
51
50
  else:
52
51
  return get_previous_node_with_activation_quantization(prev_node, graph)
@@ -22,7 +22,8 @@ import numpy as np
22
22
  from model_compression_toolkit.core.common import Graph
23
23
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
24
24
  from model_compression_toolkit.core import common
25
- from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
25
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
26
+ ActivationQuantizationMode
26
27
  from model_compression_toolkit.logger import Logger
27
28
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
28
29
  from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
@@ -127,7 +128,7 @@ class BatchNormalizationReconstruction(common.BaseSubstitution):
127
128
  bn_node.candidates_quantization_cfg = copy.deepcopy(source_node.candidates_quantization_cfg)
128
129
 
129
130
  for qc in bn_node.candidates_quantization_cfg:
130
- qc.activation_quantization_cfg.enable_activation_quantization = False
131
+ qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
131
132
  for attr in bn_node.get_node_weights_attributes():
132
133
  if qc.weights_quantization_cfg.has_attribute_config(attr):
133
134
  # we only create a BN layer to collect statistics, so we don't need to quantize anything,
@@ -17,7 +17,8 @@ import numpy as np
17
17
  from typing import List, Tuple, Any, Callable
18
18
 
19
19
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
20
- from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
20
+ from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
21
+ ActivationQuantizationMode
21
22
  from model_compression_toolkit.logger import Logger
22
23
  from model_compression_toolkit.core.common import FrameworkInfo, Graph, BaseNode
23
24
  from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
@@ -363,7 +364,7 @@ def shift_negative_function(graph: Graph,
363
364
  mixed_precision_enable=core_config.is_mixed_precision_enabled)
364
365
 
365
366
  for candidate_qc in pad_node.candidates_quantization_cfg:
366
- candidate_qc.activation_quantization_cfg.enable_activation_quantization = False
367
+ candidate_qc.activation_quantization_cfg.quant_mode = ActivationQuantizationMode.NO_QUANT
367
368
  for attr in pad_node.get_node_weights_attributes():
368
369
  candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
369
370
 
@@ -159,6 +159,19 @@ class KerasImplementation(FrameworkImplementation):
159
159
  """
160
160
  return to_tf_tensor(tensor)
161
161
 
162
+ def is_tuple_of_tensors(self, obj: Any) -> bool:
163
+ """
164
+ Check if a given object is a tuple of tensors
165
+ :param obj: Object to check its type
166
+ :return: True if obj is a tuple of tensors, False otherwise
167
+ """
168
+ if not isinstance(obj, tuple):
169
+ return False
170
+ for item in obj:
171
+ if not isinstance(item, tf.Tensor):
172
+ return False
173
+ return True
174
+
162
175
  def model_builder(self,
163
176
  graph: Graph,
164
177
  mode: ModelBuilderMode,
@@ -454,7 +467,7 @@ class KerasImplementation(FrameworkImplementation):
454
467
  return True
455
468
 
456
469
  return any([node.is_match_type(_type) for _type in [Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense,
457
- Concatenate, tf.concat, Add, tf.add]])
470
+ Concatenate, tf.concat, Add, tf.add, tf.stack]])
458
471
 
459
472
  def get_mp_node_distance_fn(self, n: BaseNode,
460
473
  compute_distance_fn: Callable = None,
@@ -64,8 +64,7 @@ class ConfigurableActivationQuantizer(BaseKerasInferableQuantizer):
64
64
  verify_candidates_descending_order(self.node_q_cfg, kernel_attr)
65
65
 
66
66
  for qc in node_q_cfg:
67
- if qc.activation_quantization_cfg.enable_activation_quantization != \
68
- node_q_cfg[0].activation_quantization_cfg.enable_activation_quantization:
67
+ if qc.activation_quantization_cfg.quant_mode != node_q_cfg[0].activation_quantization_cfg.quant_mode:
69
68
  Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).") # pragma: no cover
70
69
 
71
70
  self.activation_quantizers = init_activation_quantizers(self.node_q_cfg)
@@ -198,7 +198,6 @@ class ScaledDotProductDecomposition(BaseSubstitution):
198
198
  :param attention_node: the node to replace
199
199
  :return: A graph after the substitution
200
200
  """
201
- print("In scale_dot_product_attention substitution@@@@@@@@")
202
201
  input_nodes = self._get_attention_input_nodes(graph, attention_node)
203
202
  q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
204
203
  transpose_k_node = self._get_transpose_k_node(attention_node.name, k_node)
@@ -63,8 +63,7 @@ class ConfigurableActivationQuantizer(BasePyTorchInferableQuantizer):
63
63
  verify_candidates_descending_order(self.node_q_cfg, kernel_attr)
64
64
 
65
65
  for qc in self.node_q_cfg:
66
- if qc.activation_quantization_cfg.enable_activation_quantization != \
67
- self.node_q_cfg[0].activation_quantization_cfg.enable_activation_quantization:
66
+ if qc.activation_quantization_cfg.quant_mode != self.node_q_cfg[0].activation_quantization_cfg.quant_mode:
68
67
  Logger.critical("Unsupported configuration: Mixing candidates with differing activation quantization states (enabled/disabled).") # pragma: no cover
69
68
 
70
69
  # Setting layer's activation
@@ -15,12 +15,12 @@
15
15
  import operator
16
16
  from copy import deepcopy
17
17
  from functools import partial
18
- from typing import List, Any, Tuple, Callable, Type, Dict, Generator
18
+ from typing import List, Any, Tuple, Callable, Generator
19
19
 
20
20
  import numpy as np
21
21
  import torch
22
22
  from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
23
- from torch import sigmoid, softmax, add, cat, argmax, concat, concatenate
23
+ from torch import sigmoid, softmax, add, cat, argmax, concat, concatenate, stack
24
24
  from torch.nn import Conv2d, ConvTranspose2d, Linear
25
25
  from torch.nn import Module, Sigmoid, Softmax
26
26
 
@@ -144,6 +144,19 @@ class PytorchImplementation(FrameworkImplementation):
144
144
  """
145
145
  return to_torch_tensor(tensor)
146
146
 
147
+ def is_tuple_of_tensors(self, obj: Any) -> bool:
148
+ """
149
+ Check if a given object is a tuple of tensors
150
+ :param obj: Object to check its type
151
+ :return: True if obj is a tuple of tensors, False otherwise
152
+ """
153
+ if not isinstance(obj, tuple):
154
+ return False
155
+ for item in obj:
156
+ if not isinstance(item, torch.Tensor):
157
+ return False
158
+ return True
159
+
147
160
  def model_reader(self,
148
161
  module: Module,
149
162
  representative_data_gen: Callable) -> Graph:
@@ -449,7 +462,7 @@ class PytorchImplementation(FrameworkImplementation):
449
462
 
450
463
  return any(node.is_match_type(_type) for _type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax,
451
464
  softmax, operator.add, add, cat, concat, concatenate,
452
- operator.concat])
465
+ operator.concat, stack])
453
466
 
454
467
  def get_mp_node_distance_fn(self, n: BaseNode,
455
468
  compute_distance_fn: Callable = None,
@@ -13,19 +13,40 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
-
17
- import logging
18
- from typing import Callable, Dict
19
-
20
- import numpy as np
21
16
  import torch
22
- from torch.fx import symbolic_trace
17
+ import logging
18
+ from typing import Callable, Dict, Union, Any
23
19
  from torch.fx.passes.shape_prop import ShapeProp
20
+ from torch.fx import Tracer, GraphModule, symbolic_trace
24
21
 
25
22
  from model_compression_toolkit.logger import Logger
26
23
  from model_compression_toolkit.core.common import Graph
27
24
  from model_compression_toolkit.core.pytorch.reader.graph_builders import edges_builder, nodes_builder
28
25
  from model_compression_toolkit.core.pytorch.utils import set_model
26
+ from sony_custom_layers.pytorch import CustomLayer
27
+
28
+
29
+ def _trace_model(root: Union[torch.nn.Module, Callable[..., Any]]) -> GraphModule:
30
+ """
31
+ Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
32
+ constructed by recording operations seen while tracing through ``root``.
33
+ This function replaces torch.fx.symbolic_trace in order to handle custom layers tracing - treating them as graph
34
+ leafs.
35
+ :param root: Module or function to be traced and converted into a Graph representation.
36
+ :return: GraphModule: a Module created from the recorded operations from ``root``.
37
+ """
38
+
39
+ class MCTTracer(Tracer):
40
+ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
41
+ if isinstance(m, CustomLayer):
42
+ return True
43
+ return super().is_leaf_module(m, module_qualified_name)
44
+
45
+ tracer = MCTTracer()
46
+ graph = tracer.trace(root)
47
+ # handling the possibility that the model (root) might be a torch.nn.Module or a function
48
+ model_name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__)
49
+ return GraphModule(tracer.root, graph, model_name)
29
50
 
30
51
 
31
52
  def generate_module_dict(model: torch.nn.Module) -> Dict:
@@ -87,7 +108,7 @@ def fx_graph_module_generation(pytorch_model: torch.nn.Module,
87
108
  set_model(pytorch_model)
88
109
 
89
110
  try:
90
- symbolic_traced = symbolic_trace(pytorch_model)
111
+ symbolic_traced = _trace_model(pytorch_model)
91
112
  except torch.fx.proxy.TraceError as e:
92
113
  Logger.critical(f'Error parsing model with torch.fx\n'
93
114
  f'fx error: {e}')
@@ -15,7 +15,7 @@
15
15
  import torch
16
16
  from torch import Tensor
17
17
  import numpy as np
18
- from typing import Union, Sequence, Optional, List, Tuple
18
+ from typing import Union, Optional, List, Tuple, Any
19
19
 
20
20
  from model_compression_toolkit.core.pytorch.constants import MAX_FLOAT16, MIN_FLOAT16
21
21
  from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
@@ -112,4 +112,4 @@ def clip_inf_values_float16(tensor: Tensor) -> Tensor:
112
112
  # Replace inf values with max float16 value
113
113
  tensor[inf_mask] = MAX_FLOAT16 * torch.sign(tensor[inf_mask])
114
114
 
115
- return tensor
115
+ return tensor
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  import copy
16
16
 
17
- from typing import Callable, Tuple, Union
17
+ from typing import Callable, Tuple, Union, Optional
18
18
  from packaging import version
19
19
 
20
20
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
@@ -158,7 +158,7 @@ if FOUND_TF:
158
158
  target_resource_utilization: ResourceUtilization = None,
159
159
  core_config: CoreConfig = CoreConfig(),
160
160
  target_platform_capabilities: Union[TargetPlatformCapabilities, str]
161
- = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
161
+ = DEFAULT_KERAS_TPC) -> Tuple[Model, Optional[UserInformation]]:
162
162
  """
163
163
  Quantize a trained Keras model using post-training quantization. The model is quantized using a
164
164
  symmetric constraint quantization thresholds (power of two).
@@ -230,6 +230,10 @@ if FOUND_TF:
230
230
  >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, target_resource_utilization=ru, core_config=config)
231
231
 
232
232
  """
233
+
234
+ if core_config.debug_config.bypass:
235
+ return in_model, None
236
+
233
237
  KerasModelValidation(model=in_model,
234
238
  fw_info=DEFAULT_KERAS_INFO).validate()
235
239
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import copy
16
- from typing import Callable, Union
16
+ from typing import Callable, Union, Optional, Tuple
17
17
 
18
18
  from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH, GPTQ_HESSIAN_NUM_SAMPLES
19
19
  from model_compression_toolkit.core import CoreConfig
@@ -22,6 +22,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
22
22
  MixedPrecisionQuantizationConfig
23
23
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
24
24
  ResourceUtilization
25
+ from model_compression_toolkit.core.common.user_info import UserInformation
25
26
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
26
27
  from model_compression_toolkit.core.runner import core_runner
27
28
  from model_compression_toolkit.gptq.common.gptq_config import (
@@ -147,7 +148,8 @@ if FOUND_TORCH:
147
148
  core_config: CoreConfig = CoreConfig(),
148
149
  gptq_config: GradientPTQConfig = None,
149
150
  gptq_representative_data_gen: Callable = None,
150
- target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC):
151
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC
152
+ ) -> Tuple[Module, Optional[UserInformation]]:
151
153
  """
152
154
  Quantize a trained Pytorch module using post-training quantization.
153
155
  By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -206,6 +208,9 @@ if FOUND_TORCH:
206
208
 
207
209
  """
208
210
 
211
+ if core_config.debug_config.bypass:
212
+ return model, None
213
+
209
214
  if core_config.is_mixed_precision_enabled: # pragma: no cover
210
215
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
211
216
  Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
@@ -14,11 +14,12 @@
14
14
  # ==============================================================================
15
15
  import copy
16
16
 
17
- from typing import Callable
17
+ from typing import Callable, Tuple, Optional
18
18
 
19
19
  from model_compression_toolkit.core import CoreConfig
20
20
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
21
21
  from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
22
+ from model_compression_toolkit.core.common.user_info import UserInformation
22
23
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
23
24
  from model_compression_toolkit.logger import Logger
24
25
  from model_compression_toolkit.constants import TENSORFLOW
@@ -52,7 +53,8 @@ if FOUND_TF:
52
53
  representative_data_gen: Callable,
53
54
  target_resource_utilization: ResourceUtilization = None,
54
55
  core_config: CoreConfig = CoreConfig(),
55
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
56
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC
57
+ ) -> Tuple[Model, Optional[UserInformation]]:
56
58
  """
57
59
  Quantize a trained Keras model using post-training quantization. The model is quantized using a
58
60
  symmetric constraint quantization thresholds (power of two).
@@ -123,6 +125,9 @@ if FOUND_TF:
123
125
 
124
126
  """
125
127
 
128
+ if core_config.debug_config.bypass:
129
+ return in_model, None
130
+
126
131
  fw_info = DEFAULT_KERAS_INFO
127
132
 
128
133
  KerasModelValidation(model=in_model,
@@ -14,8 +14,9 @@
14
14
  # ==============================================================================
15
15
  import copy
16
16
 
17
- from typing import Callable, Union
17
+ from typing import Callable, Union, Tuple, Optional
18
18
 
19
+ from model_compression_toolkit.core.common.user_info import UserInformation
19
20
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
20
21
  from model_compression_toolkit.logger import Logger
21
22
  from model_compression_toolkit.constants import PYTORCH
@@ -49,7 +50,8 @@ if FOUND_TORCH:
49
50
  representative_data_gen: Callable,
50
51
  target_resource_utilization: ResourceUtilization = None,
51
52
  core_config: CoreConfig = CoreConfig(),
52
- target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC):
53
+ target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC
54
+ ) -> Tuple[Module, Optional[UserInformation]]:
53
55
  """
54
56
  Quantize a trained Pytorch module using post-training quantization.
55
57
  By default, the module is quantized using a symmetric constraint quantization thresholds
@@ -97,6 +99,9 @@ if FOUND_TORCH:
97
99
 
98
100
  """
99
101
 
102
+ if core_config.debug_config.bypass:
103
+ return in_module, None
104
+
100
105
  fw_info = DEFAULT_PYTORCH_INFO
101
106
 
102
107
  if core_config.is_mixed_precision_enabled:
@@ -19,10 +19,8 @@ from packaging import version
19
19
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
20
20
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
21
21
  AttachTpcToFramework
22
- from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
23
22
 
24
- if FOUND_SONY_CUSTOM_LAYERS:
25
- from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
23
+ from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
26
24
 
27
25
  if version.parse(tf.__version__) >= version.parse("2.13"):
28
26
  from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
@@ -102,15 +100,9 @@ class AttachTpcToKeras(AttachTpcToFramework):
102
100
  OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
103
101
  OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
104
102
  OperatorSetNames.L2NORM: [tf.math.l2_normalize],
103
+ OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess]
105
104
  }
106
105
 
107
- if FOUND_SONY_CUSTOM_LAYERS:
108
- self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = [SSDPostProcess]
109
- else:
110
- # If Custom layers is not installed then we don't want the user to fail, but just ignore custom layers
111
- # in the initialized framework TPC
112
- self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = []
113
-
114
106
  self._opset2attr_mapping = {
115
107
  OperatorSetNames.CONV: {
116
108
  KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
@@ -32,6 +32,7 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
32
32
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
33
33
  AttachTpcToFramework
34
34
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
35
+ from sony_custom_layers.pytorch import MulticlassNMS, MulticlassNMSWithIndices
35
36
 
36
37
 
37
38
  class AttachTpcToPytorch(AttachTpcToFramework):
@@ -97,7 +98,7 @@ class AttachTpcToPytorch(AttachTpcToFramework):
97
98
  OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
98
99
  Eq('p', 2) | Eq('p', None))],
99
100
  OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
100
- OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [], # no such operator in pytorch
101
+ OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
101
102
  }
102
103
 
103
104
  pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
@@ -30,4 +30,3 @@ FOUND_TORCH = importlib.util.find_spec("torch") is not None
30
30
  FOUND_TORCHVISION = importlib.util.find_spec("torchvision") is not None
31
31
  FOUND_ONNX = importlib.util.find_spec("onnx") is not None
32
32
  FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
33
- FOUND_SONY_CUSTOM_LAYERS = importlib.util.find_spec('sony_custom_layers') is not None