mct-nightly 2.3.0.20250409.529__py3-none-any.whl → 2.3.0.20250410.526__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250409.529.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/METADATA +2 -1
- {mct_nightly-2.3.0.20250409.529.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/RECORD +21 -21
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +11 -0
- model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py +4 -2
- model_compression_toolkit/core/common/quantization/debug_config.py +2 -0
- model_compression_toolkit/core/keras/keras_implementation.py +14 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py +0 -1
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +16 -3
- model_compression_toolkit/core/pytorch/reader/reader.py +28 -7
- model_compression_toolkit/core/pytorch/utils.py +2 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +6 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -2
- model_compression_toolkit/ptq/keras/quantization_facade.py +7 -2
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +7 -2
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py +2 -10
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +2 -1
- model_compression_toolkit/verify_packages.py +0 -1
- {mct_nightly-2.3.0.20250409.529.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250409.529.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250409.529.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250409.529.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.
|
3
|
+
Version: 2.3.0.20250410.526
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: Apache Software License
|
@@ -22,6 +22,7 @@ Requires-Dist: scipy
|
|
22
22
|
Requires-Dist: protobuf
|
23
23
|
Requires-Dist: mct-quantizers-nightly
|
24
24
|
Requires-Dist: pydantic<2.0
|
25
|
+
Requires-Dist: sony-custom-layers-dev==0.4.0.dev6
|
25
26
|
Dynamic: classifier
|
26
27
|
Dynamic: description
|
27
28
|
Dynamic: description-content-type
|
{mct_nightly-2.3.0.20250409.529.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/RECORD
RENAMED
@@ -1,10 +1,10 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250410.526.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=hH7K5n9ZDkBMIcujI3umeAh3pUxoyxZu2pWu83zoGgk,1557
|
3
3
|
model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
6
6
|
model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
|
7
|
-
model_compression_toolkit/verify_packages.py,sha256=
|
7
|
+
model_compression_toolkit/verify_packages.py,sha256=l0neIRr8q_QwxmuiTI4vyCMDISDedK0EihjEQUe66tE,1319
|
8
8
|
model_compression_toolkit/core/__init__.py,sha256=8a0wUNBKwTdJGDk_Ho6WQAXjGuCqQZG1FUxxJlAV8L8,2096
|
9
9
|
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
10
10
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=C6eUTd-fcgxk0LUbt51gFZwmyDDDEB8-9Q4kr9ujYvI,11555
|
@@ -12,7 +12,7 @@ model_compression_toolkit/core/quantization_prep_runner.py,sha256=DPevqQ8brkdut8
|
|
12
12
|
model_compression_toolkit/core/runner.py,sha256=_r6cieb7Ur2BeHQK5XxTZHogjyA0utybvIVbH06CBHY,13056
|
13
13
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
14
14
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
15
|
-
model_compression_toolkit/core/common/framework_implementation.py,sha256=
|
15
|
+
model_compression_toolkit/core/common/framework_implementation.py,sha256=L88uv_sfYM_56FSmxXP--emjv01_lk7IPqOI7QBZEt0,22939
|
16
16
|
model_compression_toolkit/core/common/framework_info.py,sha256=RWeZfQOPiBroU2v4AeZoquVunNtZ4UORjOr2aRAPu8o,6279
|
17
17
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
18
18
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
@@ -51,7 +51,7 @@ model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=
|
|
51
51
|
model_compression_toolkit/core/common/hessian/__init__.py,sha256=E7LK3K_1AwMCQokanNc1JODMwUKNOKmwXQiGQ7GO10I,1033
|
52
52
|
model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=8NDC_WLe3ZnY_v3e_Vz_lseF22lrbvhFmArihpeWfuI,14291
|
53
53
|
model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=1axmN0tjJSo_7hUr2d2KMv4y1pBi19cqWSQpi4BbdsA,1458
|
54
|
-
model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=
|
54
|
+
model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py,sha256=wqKPfAJgXiV7zD2DufbOU5HcOLi-44Fv9PWdVgFMGaw,4354
|
55
55
|
model_compression_toolkit/core/common/hessian/hessian_scores_request.py,sha256=ZNdwDzW7QF2A-w1Ye4P2xn5erTQnoTXk5z_b17HDGH4,3391
|
56
56
|
model_compression_toolkit/core/common/matchers/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
|
57
57
|
model_compression_toolkit/core/common/matchers/base_graph_filter.py,sha256=mTk54z0mIbFmPOb4h0xfLtLDookcFyNh8H0pIN5js_M,3091
|
@@ -104,7 +104,7 @@ model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM
|
|
104
104
|
model_compression_toolkit/core/common/quantization/bit_width_config.py,sha256=0HA3CIZW-ZrA55ra-yJXRvAYnoR8i1SjpbnMDKcWYNQ,12819
|
105
105
|
model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=lyWPvnoX8BmulhLKR20r5gT2_Yan7P40d8EcgDhErPk,4905
|
106
106
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=yxCzWqldcHoe8GGxrH0tp99bhrc5jDT7SgZftnMUUBE,2374
|
107
|
-
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=
|
107
|
+
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=uH45Uq3Tp9FIyMynex_WY2_y-Kv8LuPw2XXZydnpW5A,1649
|
108
108
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=n2A8pO7_DMMae4o69U0I00iW6mzeRlRfKHDxlQUBBuI,7204
|
109
109
|
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=gL3XUm85FBLvtF60jmWkPxITOBw7cs66scNtC7QHW-M,29471
|
110
110
|
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=UkSVW7d1OF_Px9gAjsqqK65aYhIBFWaBO-_IH6_AFfg,4403
|
@@ -158,7 +158,7 @@ model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiO
|
|
158
158
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
159
159
|
model_compression_toolkit/core/keras/data_util.py,sha256=jm54o-SlI1DJ-sEvRuX9OyLN68tEt0VxcqrdIjR98Ag,8366
|
160
160
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=IGEHKH3IcmpRfyHuEBJTpEXu2-TDFfqQzpm8kHuj8QY,4974
|
161
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
161
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=_15BrSGTRSSp_8ayuo2x-hdKanew1xuIPSumP46IGSA,32545
|
162
162
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
163
163
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
|
164
164
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=XBCmUrHy_fNQCfSjnXCpwuEtc7cda4hXySuiIzhFGqc,5696
|
@@ -224,10 +224,10 @@ model_compression_toolkit/core/pytorch/constants.py,sha256=Sg0hkUaMe88mI2_pd3Kqh
|
|
224
224
|
model_compression_toolkit/core/pytorch/data_util.py,sha256=YYbT135HhlTt0q6XdD2JX7AS_L92f_uV2rWq2hsJOCA,6325
|
225
225
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=NLdmiig5a2EBxutJeDHjp8px4g_2EKt3zmntmK-NrT4,4309
|
226
226
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
227
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
227
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=c_QFo4e7t6b21CDakGhjVpqy5aXFxxqkdJ-s54HEOfs,31207
|
228
228
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
|
229
229
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=aIHl-dTAC4ISnWSKLD99c-1W3827vfRGyLjMBib-l3s,5618
|
230
|
-
model_compression_toolkit/core/pytorch/utils.py,sha256=
|
230
|
+
model_compression_toolkit/core/pytorch/utils.py,sha256=xNVE7YMtHupLEimIJcxmfcMGM4XKB9I1v0-K8lDeLB8,3936
|
231
231
|
model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
|
232
232
|
model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=bwppTPRs6gL96nm7qPiKrNcBj4Krr0yEsOWjRF0aXmQ,2339
|
233
233
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
@@ -257,7 +257,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_
|
|
257
257
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
|
258
258
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
|
259
259
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
|
260
|
-
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py,sha256=
|
260
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py,sha256=D1hxN3pZ5-_FLJSS30ZJUo-v8TqUWFcMjhMijFa9aSo,12407
|
261
261
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=3WCLvPyx7tVkM0rwYhYq-gntCzW9R_DcImR1ucKlPac,10772
|
262
262
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/softmax_shift.py,sha256=05lV4pIL3hJkZl4JQPV4wk_EFD0eYLG5b8cdzvZk4P8,1588
|
263
263
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transform_function_call_method.py,sha256=EC9Dvp-_UlpDWnipnf8ds65wh_Y-T8pXAFIwRScWpiY,2044
|
@@ -278,7 +278,7 @@ model_compression_toolkit/core/pytorch/quantizer/lut_fake_quant.py,sha256=uyeBtN
|
|
278
278
|
model_compression_toolkit/core/pytorch/reader/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
279
279
|
model_compression_toolkit/core/pytorch/reader/graph_builders.py,sha256=ZASzWbYYojFYIx-ynqMTkg6mCpTrJg2oWYT-xXki4Mw,19763
|
280
280
|
model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=7XNc7-l1MZPJGcOESvtAwfIMxrU6kvt3YjF5B7qOqK4,1048
|
281
|
-
model_compression_toolkit/core/pytorch/reader/reader.py,sha256=
|
281
|
+
model_compression_toolkit/core/pytorch/reader/reader.py,sha256=OKlSkGXI-5fKULPEcBnGM6dxwUlWGQEq7ZWdUIhovMU,7440
|
282
282
|
model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
283
283
|
model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py,sha256=VgU24J3jf7QComHH7jonOXSkg6mO4TOch3uFkOthZvM,3261
|
284
284
|
model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py,sha256=N-9QaEaQYUsIoya9Lc0ZDoMZ0fkiT2gFpOd4zXHKP34,3096
|
@@ -366,7 +366,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
|
|
366
366
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
|
367
367
|
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=km9tcuugOkRvprGXQZrsq_GPtA3-7Du_-rnbR_Gyups,23228
|
368
368
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=zwoeHX67nJJ5-zYLjzvMXS9TLsy9BsizARbZiDVjVSA,4473
|
369
|
-
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256
|
369
|
+
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=-goXDz-ACJ4QQH55XTA5n4eGVRXcYAWtqJ4dq6tWq1o,18927
|
370
370
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
371
371
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
|
372
372
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
@@ -382,7 +382,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phI
|
|
382
382
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
383
383
|
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=WtehnyiYXdUXf8-uNpV0mdsalF7YF7eKnL7tcFrzZoE,19549
|
384
384
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
|
385
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=
|
385
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=V_T3EbFiHO3SkN0kvppsEB9IFW8Q68_GMtUf3xjHnXU,17445
|
386
386
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
387
387
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
|
388
388
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
@@ -401,9 +401,9 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=FmUQvT0T247Xa
|
|
401
401
|
model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
|
402
402
|
model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
|
403
403
|
model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
404
|
-
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
|
404
|
+
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=QAQ7Pegk26fARDQg2ZNzcYY8aYKmb2hnUY8FiAdcuy0,11824
|
405
405
|
model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
406
|
-
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=
|
406
|
+
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=Du3CBhp7HXam-GSkv9VPcBoaIBydKjdXsnhFjsemT3E,10282
|
407
407
|
model_compression_toolkit/qat/__init__.py,sha256=AaC4KBha4jDW_tyg2SOxZaKh_idIz0gZtDK3_zxs64E,1241
|
408
408
|
model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
|
409
409
|
model_compression_toolkit/qat/common/qat_config.py,sha256=xtfVSoyELGXynHNrw86dB9FU3Inu0zwehc3wLrh7JvY,2918
|
@@ -441,8 +441,8 @@ model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=4CGpW
|
|
441
441
|
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=yg0ZrsaqaS69lmDvxRrz636CRARzx_eZbokTMVHNEXc,4555
|
442
442
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
|
443
443
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
|
444
|
-
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=
|
445
|
-
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=
|
444
|
+
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=Ehwpd_sL6zxmJFpJugOdN9uNxNX05nijvOCilNfHnFs,7162
|
445
|
+
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=RY7STxTqYG1umFJEbWFRuGXk32eGi1iYuDFKgyVFo-8,6408
|
446
446
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
|
447
447
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
|
448
448
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=UKzckLYLdBcFAptyKnVMwpPpfRkmF0SK1Kl0g0eGjQA,9710
|
@@ -527,7 +527,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
527
527
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
528
528
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
529
529
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
530
|
-
mct_nightly-2.3.0.
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
530
|
+
mct_nightly-2.3.0.20250410.526.dist-info/METADATA,sha256=lnLhgKNdIybbtKdxLN61inSjhX0CQulfk_9gDUF387o,27148
|
531
|
+
mct_nightly-2.3.0.20250410.526.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
532
|
+
mct_nightly-2.3.0.20250410.526.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
533
|
+
mct_nightly-2.3.0.20250410.526.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250410.000526"
|
@@ -93,6 +93,17 @@ class FrameworkImplementation(ABC):
|
|
93
93
|
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
94
94
|
f'framework\'s to_tensor method.') # pragma: no cover
|
95
95
|
|
96
|
+
@abstractmethod
|
97
|
+
def is_tuple_of_tensors(self, obj: Any) -> bool:
|
98
|
+
"""
|
99
|
+
Check if a given object if a tuple of tensors
|
100
|
+
:param obj: Object to check its type
|
101
|
+
:return: True if obj is a tuple of tensors, False otherwise
|
102
|
+
"""
|
103
|
+
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
|
104
|
+
f'framework\'s is_tuple_of_tensors method.') # pragma: no cover
|
105
|
+
|
106
|
+
|
96
107
|
@abstractmethod
|
97
108
|
def model_reader(self,
|
98
109
|
model: Any,
|
@@ -72,8 +72,7 @@ class HessianScoresCalculator(ABC):
|
|
72
72
|
"""
|
73
73
|
raise NotImplemented(f'{self.__class__.__name__} have to implement compute method.') # pragma: no cover
|
74
74
|
|
75
|
-
|
76
|
-
def unfold_tensors_list(tensors_to_unfold: Any) -> List[Any]:
|
75
|
+
def unfold_tensors_list(self, tensors_to_unfold: Any) -> List[Any]:
|
77
76
|
"""
|
78
77
|
Unfold (flatten) a nested tensors list.
|
79
78
|
Given a mixed list of single tensors and nested tensor lists,
|
@@ -85,6 +84,9 @@ class HessianScoresCalculator(ABC):
|
|
85
84
|
"""
|
86
85
|
unfold_tensors = []
|
87
86
|
for tensor in tensors_to_unfold:
|
87
|
+
if self.fw_impl.is_tuple_of_tensors(tensor):
|
88
|
+
tensor = list(tensor) # converts named tuple to list
|
89
|
+
|
88
90
|
if isinstance(tensor, List):
|
89
91
|
unfold_tensors += tensor
|
90
92
|
else:
|
@@ -29,8 +29,10 @@ class DebugConfig:
|
|
29
29
|
enabled) or not. Can be used to pinpoint problematic layers in the quantization process.
|
30
30
|
network_editor (List[EditRule]): A list of rules and actions to edit the network for quantization.
|
31
31
|
simulate_scheduler (bool): Simulate scheduler behavior to compute operators' order and cuts.
|
32
|
+
bypass (bool): A flag to enable MCT bypass, which skips MCT runner and returns the input model unchanged.
|
32
33
|
"""
|
33
34
|
|
34
35
|
analyze_similarity: bool = False
|
35
36
|
network_editor: List[EditRule] = field(default_factory=list)
|
36
37
|
simulate_scheduler: bool = False
|
38
|
+
bypass: bool = False
|
@@ -159,6 +159,19 @@ class KerasImplementation(FrameworkImplementation):
|
|
159
159
|
"""
|
160
160
|
return to_tf_tensor(tensor)
|
161
161
|
|
162
|
+
def is_tuple_of_tensors(self, obj: Any) -> bool:
|
163
|
+
"""
|
164
|
+
Check if a given object if a tuple of tensors
|
165
|
+
:param obj: Object to check its type
|
166
|
+
:return: True if obj is a tuple of tensors, False otherwise
|
167
|
+
"""
|
168
|
+
if not isinstance(obj, tuple):
|
169
|
+
return False
|
170
|
+
for item in obj:
|
171
|
+
if not isinstance(item, tf.Tensor):
|
172
|
+
return False
|
173
|
+
return True
|
174
|
+
|
162
175
|
def model_builder(self,
|
163
176
|
graph: Graph,
|
164
177
|
mode: ModelBuilderMode,
|
@@ -454,7 +467,7 @@ class KerasImplementation(FrameworkImplementation):
|
|
454
467
|
return True
|
455
468
|
|
456
469
|
return any([node.is_match_type(_type) for _type in [Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense,
|
457
|
-
Concatenate, tf.concat, Add, tf.add]])
|
470
|
+
Concatenate, tf.concat, Add, tf.add, tf.stack]])
|
458
471
|
|
459
472
|
def get_mp_node_distance_fn(self, n: BaseNode,
|
460
473
|
compute_distance_fn: Callable = None,
|
@@ -198,7 +198,6 @@ class ScaledDotProductDecomposition(BaseSubstitution):
|
|
198
198
|
:param attention_node: the node to replace
|
199
199
|
:return: A graph after the substitution
|
200
200
|
"""
|
201
|
-
print("In scale_dot_product_attention substitution@@@@@@@@")
|
202
201
|
input_nodes = self._get_attention_input_nodes(graph, attention_node)
|
203
202
|
q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
|
204
203
|
transpose_k_node = self._get_transpose_k_node(attention_node.name, k_node)
|
@@ -15,12 +15,12 @@
|
|
15
15
|
import operator
|
16
16
|
from copy import deepcopy
|
17
17
|
from functools import partial
|
18
|
-
from typing import List, Any, Tuple, Callable,
|
18
|
+
from typing import List, Any, Tuple, Callable, Generator
|
19
19
|
|
20
20
|
import numpy as np
|
21
21
|
import torch
|
22
22
|
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
|
23
|
-
from torch import sigmoid, softmax, add, cat, argmax, concat, concatenate
|
23
|
+
from torch import sigmoid, softmax, add, cat, argmax, concat, concatenate, stack
|
24
24
|
from torch.nn import Conv2d, ConvTranspose2d, Linear
|
25
25
|
from torch.nn import Module, Sigmoid, Softmax
|
26
26
|
|
@@ -144,6 +144,19 @@ class PytorchImplementation(FrameworkImplementation):
|
|
144
144
|
"""
|
145
145
|
return to_torch_tensor(tensor)
|
146
146
|
|
147
|
+
def is_tuple_of_tensors(self, obj: Any) -> bool:
|
148
|
+
"""
|
149
|
+
Check if a given object if a tuple of tensors
|
150
|
+
:param obj: Object to check its type
|
151
|
+
:return: True if obj is a tuple of tensors, False otherwise
|
152
|
+
"""
|
153
|
+
if not isinstance(obj, tuple):
|
154
|
+
return False
|
155
|
+
for item in obj:
|
156
|
+
if not isinstance(item, torch.Tensor):
|
157
|
+
return False
|
158
|
+
return True
|
159
|
+
|
147
160
|
def model_reader(self,
|
148
161
|
module: Module,
|
149
162
|
representative_data_gen: Callable) -> Graph:
|
@@ -449,7 +462,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
449
462
|
|
450
463
|
return any(node.is_match_type(_type) for _type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax,
|
451
464
|
softmax, operator.add, add, cat, concat, concatenate,
|
452
|
-
operator.concat])
|
465
|
+
operator.concat, stack])
|
453
466
|
|
454
467
|
def get_mp_node_distance_fn(self, n: BaseNode,
|
455
468
|
compute_distance_fn: Callable = None,
|
@@ -13,19 +13,40 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
|
17
|
-
import logging
|
18
|
-
from typing import Callable, Dict
|
19
|
-
|
20
|
-
import numpy as np
|
21
16
|
import torch
|
22
|
-
|
17
|
+
import logging
|
18
|
+
from typing import Callable, Dict, Union, Any
|
23
19
|
from torch.fx.passes.shape_prop import ShapeProp
|
20
|
+
from torch.fx import Tracer, GraphModule, symbolic_trace
|
24
21
|
|
25
22
|
from model_compression_toolkit.logger import Logger
|
26
23
|
from model_compression_toolkit.core.common import Graph
|
27
24
|
from model_compression_toolkit.core.pytorch.reader.graph_builders import edges_builder, nodes_builder
|
28
25
|
from model_compression_toolkit.core.pytorch.utils import set_model
|
26
|
+
from sony_custom_layers.pytorch import CustomLayer
|
27
|
+
|
28
|
+
|
29
|
+
def _trace_model(root: Union[torch.nn.Module, Callable[..., Any]]) -> GraphModule:
|
30
|
+
"""
|
31
|
+
Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule``
|
32
|
+
constructed by recording operations seen while tracing through ``root``.
|
33
|
+
This function replaces torch.fx.symbolic_trace in order to handle custom layers tracing - treating them as graph
|
34
|
+
leafs.
|
35
|
+
:param root: Module or function to be traced and converted into a Graph representation.
|
36
|
+
:return: GraphModule: a Module created from the recorded operations from ``root``.
|
37
|
+
"""
|
38
|
+
|
39
|
+
class MCTTracer(Tracer):
|
40
|
+
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
|
41
|
+
if isinstance(m, CustomLayer):
|
42
|
+
return True
|
43
|
+
return super().is_leaf_module(m, module_qualified_name)
|
44
|
+
|
45
|
+
tracer = MCTTracer()
|
46
|
+
graph = tracer.trace(root)
|
47
|
+
# handling the possibility that the model (root) might be a torch.nn.Module or a function
|
48
|
+
model_name = (root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__)
|
49
|
+
return GraphModule(tracer.root, graph, model_name)
|
29
50
|
|
30
51
|
|
31
52
|
def generate_module_dict(model: torch.nn.Module) -> Dict:
|
@@ -87,7 +108,7 @@ def fx_graph_module_generation(pytorch_model: torch.nn.Module,
|
|
87
108
|
set_model(pytorch_model)
|
88
109
|
|
89
110
|
try:
|
90
|
-
symbolic_traced =
|
111
|
+
symbolic_traced = _trace_model(pytorch_model)
|
91
112
|
except torch.fx.proxy.TraceError as e:
|
92
113
|
Logger.critical(f'Error parsing model with torch.fx\n'
|
93
114
|
f'fx error: {e}')
|
@@ -15,7 +15,7 @@
|
|
15
15
|
import torch
|
16
16
|
from torch import Tensor
|
17
17
|
import numpy as np
|
18
|
-
from typing import Union,
|
18
|
+
from typing import Union, Optional, List, Tuple, Any
|
19
19
|
|
20
20
|
from model_compression_toolkit.core.pytorch.constants import MAX_FLOAT16, MIN_FLOAT16
|
21
21
|
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
|
@@ -112,4 +112,4 @@ def clip_inf_values_float16(tensor: Tensor) -> Tensor:
|
|
112
112
|
# Replace inf values with max float16 value
|
113
113
|
tensor[inf_mask] = MAX_FLOAT16 * torch.sign(tensor[inf_mask])
|
114
114
|
|
115
|
-
return tensor
|
115
|
+
return tensor
|
@@ -14,7 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
import copy
|
16
16
|
|
17
|
-
from typing import Callable, Tuple, Union
|
17
|
+
from typing import Callable, Tuple, Union, Optional
|
18
18
|
from packaging import version
|
19
19
|
|
20
20
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
@@ -158,7 +158,7 @@ if FOUND_TF:
|
|
158
158
|
target_resource_utilization: ResourceUtilization = None,
|
159
159
|
core_config: CoreConfig = CoreConfig(),
|
160
160
|
target_platform_capabilities: Union[TargetPlatformCapabilities, str]
|
161
|
-
= DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
|
161
|
+
= DEFAULT_KERAS_TPC) -> Tuple[Model, Optional[UserInformation]]:
|
162
162
|
"""
|
163
163
|
Quantize a trained Keras model using post-training quantization. The model is quantized using a
|
164
164
|
symmetric constraint quantization thresholds (power of two).
|
@@ -230,6 +230,10 @@ if FOUND_TF:
|
|
230
230
|
>>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, target_resource_utilization=ru, core_config=config)
|
231
231
|
|
232
232
|
"""
|
233
|
+
|
234
|
+
if core_config.debug_config.bypass:
|
235
|
+
return in_model, None
|
236
|
+
|
233
237
|
KerasModelValidation(model=in_model,
|
234
238
|
fw_info=DEFAULT_KERAS_INFO).validate()
|
235
239
|
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import copy
|
16
|
-
from typing import Callable, Union
|
16
|
+
from typing import Callable, Union, Optional, Tuple
|
17
17
|
|
18
18
|
from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE, PYTORCH, GPTQ_HESSIAN_NUM_SAMPLES
|
19
19
|
from model_compression_toolkit.core import CoreConfig
|
@@ -22,6 +22,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
|
|
22
22
|
MixedPrecisionQuantizationConfig
|
23
23
|
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import \
|
24
24
|
ResourceUtilization
|
25
|
+
from model_compression_toolkit.core.common.user_info import UserInformation
|
25
26
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
26
27
|
from model_compression_toolkit.core.runner import core_runner
|
27
28
|
from model_compression_toolkit.gptq.common.gptq_config import (
|
@@ -147,7 +148,8 @@ if FOUND_TORCH:
|
|
147
148
|
core_config: CoreConfig = CoreConfig(),
|
148
149
|
gptq_config: GradientPTQConfig = None,
|
149
150
|
gptq_representative_data_gen: Callable = None,
|
150
|
-
target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC
|
151
|
+
target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC
|
152
|
+
) -> Tuple[Module, Optional[UserInformation]]:
|
151
153
|
"""
|
152
154
|
Quantize a trained Pytorch module using post-training quantization.
|
153
155
|
By default, the module is quantized using a symmetric constraint quantization thresholds
|
@@ -206,6 +208,9 @@ if FOUND_TORCH:
|
|
206
208
|
|
207
209
|
"""
|
208
210
|
|
211
|
+
if core_config.debug_config.bypass:
|
212
|
+
return model, None
|
213
|
+
|
209
214
|
if core_config.is_mixed_precision_enabled: # pragma: no cover
|
210
215
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
211
216
|
Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
|
@@ -14,11 +14,12 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
import copy
|
16
16
|
|
17
|
-
from typing import Callable
|
17
|
+
from typing import Callable, Tuple, Optional
|
18
18
|
|
19
19
|
from model_compression_toolkit.core import CoreConfig
|
20
20
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
21
21
|
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
22
|
+
from model_compression_toolkit.core.common.user_info import UserInformation
|
22
23
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
23
24
|
from model_compression_toolkit.logger import Logger
|
24
25
|
from model_compression_toolkit.constants import TENSORFLOW
|
@@ -52,7 +53,8 @@ if FOUND_TF:
|
|
52
53
|
representative_data_gen: Callable,
|
53
54
|
target_resource_utilization: ResourceUtilization = None,
|
54
55
|
core_config: CoreConfig = CoreConfig(),
|
55
|
-
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC
|
56
|
+
target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC
|
57
|
+
) -> Tuple[Model, Optional[UserInformation]]:
|
56
58
|
"""
|
57
59
|
Quantize a trained Keras model using post-training quantization. The model is quantized using a
|
58
60
|
symmetric constraint quantization thresholds (power of two).
|
@@ -123,6 +125,9 @@ if FOUND_TF:
|
|
123
125
|
|
124
126
|
"""
|
125
127
|
|
128
|
+
if core_config.debug_config.bypass:
|
129
|
+
return in_model, None
|
130
|
+
|
126
131
|
fw_info = DEFAULT_KERAS_INFO
|
127
132
|
|
128
133
|
KerasModelValidation(model=in_model,
|
@@ -14,8 +14,9 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
import copy
|
16
16
|
|
17
|
-
from typing import Callable, Union
|
17
|
+
from typing import Callable, Union, Tuple, Optional
|
18
18
|
|
19
|
+
from model_compression_toolkit.core.common.user_info import UserInformation
|
19
20
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
20
21
|
from model_compression_toolkit.logger import Logger
|
21
22
|
from model_compression_toolkit.constants import PYTORCH
|
@@ -49,7 +50,8 @@ if FOUND_TORCH:
|
|
49
50
|
representative_data_gen: Callable,
|
50
51
|
target_resource_utilization: ResourceUtilization = None,
|
51
52
|
core_config: CoreConfig = CoreConfig(),
|
52
|
-
target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC
|
53
|
+
target_platform_capabilities: Union[TargetPlatformCapabilities, str] = DEFAULT_PYTORCH_TPC
|
54
|
+
) -> Tuple[Module, Optional[UserInformation]]:
|
53
55
|
"""
|
54
56
|
Quantize a trained Pytorch module using post-training quantization.
|
55
57
|
By default, the module is quantized using a symmetric constraint quantization thresholds
|
@@ -97,6 +99,9 @@ if FOUND_TORCH:
|
|
97
99
|
|
98
100
|
"""
|
99
101
|
|
102
|
+
if core_config.debug_config.bypass:
|
103
|
+
return in_module, None
|
104
|
+
|
100
105
|
fw_info = DEFAULT_PYTORCH_INFO
|
101
106
|
|
102
107
|
if core_config.is_mixed_precision_enabled:
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py
CHANGED
@@ -19,10 +19,8 @@ from packaging import version
|
|
19
19
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
|
20
20
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
|
21
21
|
AttachTpcToFramework
|
22
|
-
from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
|
23
22
|
|
24
|
-
|
25
|
-
from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
|
23
|
+
from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
|
26
24
|
|
27
25
|
if version.parse(tf.__version__) >= version.parse("2.13"):
|
28
26
|
from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
|
@@ -102,15 +100,9 @@ class AttachTpcToKeras(AttachTpcToFramework):
|
|
102
100
|
OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
|
103
101
|
OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
|
104
102
|
OperatorSetNames.L2NORM: [tf.math.l2_normalize],
|
103
|
+
OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess]
|
105
104
|
}
|
106
105
|
|
107
|
-
if FOUND_SONY_CUSTOM_LAYERS:
|
108
|
-
self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = [SSDPostProcess]
|
109
|
-
else:
|
110
|
-
# If Custom layers is not installed then we don't want the user to fail, but just ignore custom layers
|
111
|
-
# in the initialized framework TPC
|
112
|
-
self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = []
|
113
|
-
|
114
106
|
self._opset2attr_mapping = {
|
115
107
|
OperatorSetNames.CONV: {
|
116
108
|
KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py
CHANGED
@@ -32,6 +32,7 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
|
|
32
32
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
|
33
33
|
AttachTpcToFramework
|
34
34
|
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
|
35
|
+
from sony_custom_layers.pytorch import MulticlassNMS, MulticlassNMSWithIndices
|
35
36
|
|
36
37
|
|
37
38
|
class AttachTpcToPytorch(AttachTpcToFramework):
|
@@ -97,7 +98,7 @@ class AttachTpcToPytorch(AttachTpcToFramework):
|
|
97
98
|
OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
|
98
99
|
Eq('p', 2) | Eq('p', None))],
|
99
100
|
OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
|
100
|
-
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [],
|
101
|
+
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
|
101
102
|
}
|
102
103
|
|
103
104
|
pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
|
@@ -30,4 +30,3 @@ FOUND_TORCH = importlib.util.find_spec("torch") is not None
|
|
30
30
|
FOUND_TORCHVISION = importlib.util.find_spec("torchvision") is not None
|
31
31
|
FOUND_ONNX = importlib.util.find_spec("onnx") is not None
|
32
32
|
FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
|
33
|
-
FOUND_SONY_CUSTOM_LAYERS = importlib.util.find_spec('sony_custom_layers') is not None
|
File without changes
|
File without changes
|
{mct_nightly-2.3.0.20250409.529.dist-info → mct_nightly-2.3.0.20250410.526.dist-info}/top_level.txt
RENAMED
File without changes
|