mct-nightly 1.11.0.20240313.post405__py3-none-any.whl → 1.11.0.20240315.post349__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/METADATA +6 -6
- {mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/RECORD +16 -16
- model_compression_toolkit/core/common/fusion/layer_fusing.py +8 -6
- model_compression_toolkit/core/common/graph/base_node.py +15 -2
- model_compression_toolkit/core/common/graph/graph_matchers.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +4 -2
- model_compression_toolkit/core/keras/kpi_data_facade.py +1 -3
- model_compression_toolkit/core/keras/reader/node_builder.py +22 -9
- model_compression_toolkit/core/pytorch/kpi_data_facade.py +1 -3
- model_compression_toolkit/core/quantization_prep_runner.py +1 -1
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/qat/keras/quantization_facade.py +6 -8
- model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -5
- {mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/WHEEL +0 -0
- {mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/top_level.txt +0 -0
{mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mct-nightly
-Version: 1.11.0.20240313.post405
+Version: 1.11.0.20240315.post349
 Summary: A Model Compression Toolkit for neural networks
 Home-page: UNKNOWN
 License: UNKNOWN

@@ -85,11 +85,11 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
 | Python 3.11 | | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml) |
 
 
-| | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 |
-|-------------|-----------------|-----------------|-----------------|
-| Python 3.9 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) |
-| Python 3.10 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) |
-| Python 3.11 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) |
+| | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 | TensorFlow 2.15 |
+|-------------|-----------------|-----------------|-----------------|-----------------|
+| Python 3.9 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
+| Python 3.10 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
+| Python 3.11 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |
 
 
 ## Supported Features
{mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/RECORD
CHANGED

@@ -6,7 +6,7 @@ model_compression_toolkit/core/__init__.py,sha256=P-7OYR4TFYxVV_ZpIJBogkX8bGvXci
 model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9jr_cMbET0Bd3Sg,2975
 model_compression_toolkit/core/exporter.py,sha256=Zo_C5GjIzihtJOyGp-xeCVhY_qohkVz_EGyrSZCbWRM,4115
 model_compression_toolkit/core/graph_prep_runner.py,sha256=3xp0WYqyeRdlBkf5R6uD2zWubg_JPttOwS7JRhKykBY,10043
-model_compression_toolkit/core/quantization_prep_runner.py,sha256=
+model_compression_toolkit/core/quantization_prep_runner.py,sha256=hFhDkS8GwzXZ7Ho_9qbbb8DAAWs3OONOfMSD5OU_b0o,6153
 model_compression_toolkit/core/runner.py,sha256=hXnbgP8Q-62Ie4wAq4JXO-2o77uR3le4mHYgFqJOvfc,10928
 model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
 model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666

@@ -29,13 +29,13 @@ model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=mjr3U_
 model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=5oKsJEKdVmj4C7fKdHhmrFN5k4G2BaFETpmf_xKNs7s,5207
 model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=vcf7Pk1v09SJC4fbAWf_8AgTktE6tPizJbQpSmocP2U,7930
 model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
-model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=
+model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
 model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
 model_compression_toolkit/core/common/graph/base_graph.py,sha256=Go3x1Sc0V8p9CNaJhJ6ZTjtSGGjUaKjwdLkxX_8ZFmg,38129
-model_compression_toolkit/core/common/graph/base_node.py,sha256=
+model_compression_toolkit/core/common/graph/base_node.py,sha256=6-ZYY_BlGyG6I_XYlB4AtfYs--_vpoF6hGyEwHFZNfc,28483
 model_compression_toolkit/core/common/graph/edge.py,sha256=K6Wc2hBcIqig5PbbLhbjtTgYtkyZEohfgj4Wn_J5yEA,3733
 model_compression_toolkit/core/common/graph/functional_node.py,sha256=RgwWAoMX7YV5c2gZdTBSX-ziTh3OLbebZXr3jitkxDs,3173
-model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=
+model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
 model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
 model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=4OojC7tM7QvP7GETEuIS7PVIfCctDfgzMyfHqa741T4,9789
 model_compression_toolkit/core/common/graph/memory_graph/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697

@@ -153,7 +153,7 @@ model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_
 model_compression_toolkit/core/keras/keras_implementation.py,sha256=R7gtKur2Ubw9HVBI7ImdyNaP6OxBfNfvwqQ2Ey2wW2E,29398
 model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
 model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
-model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=
+model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=1EY7krgkOdo5zbpeKvPq0kiroU1aHskMVlDnReLdjuM,4501
 model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=1kBs9URqZTfmRXAsCqvnekV5bKUL3MyqGbORewLIwu8,2457
 model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nBgUZ-baE3M6SzCuQbcnq4iebY1jtJBvKHOM,808
 model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=GSh1Piz5qpA7IlvHTMqUvPn7WBDa0IHEDZdd_TzY9XA,2226

@@ -178,7 +178,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_re
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py,sha256=6PnPIC5ax7uTzcoslW7ropIu7vVmo70AD4QYcYnQV20,3176
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py,sha256=ryes9y1ie-vjBGso2TeO4EXxVk69Ew3iSAhshPz1Ou4,5542
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/separableconv_decomposition.py,sha256=TEaHlIbXj_ZjIdT5TmAICD3WLD3u_7g0fLWQcNzTJuM,7941
-model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py,sha256=
+model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py,sha256=JCK--hQMKzbx4MOQZBPZqK015JWZELUO5YdA30IU4bI,11149
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/softmax_shift.py,sha256=Qk5seDALj_th9dHJehY7ynZjvFjVfCv_mJ1enA5hX0c,1623
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=wH9ocMLL725-uUPU-zCxdd8NwT5nyd0ZShmI7iuTwF8,1462
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_activation_split.py,sha256=rjIheZW7LbSPv9bzMSmC8wl6UUxaTkd4J2IHinObT-Y,1814

@@ -198,7 +198,7 @@ model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py,sha256=Up3-sbuA
 model_compression_toolkit/core/keras/reader/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
 model_compression_toolkit/core/keras/reader/common.py,sha256=z0PMoP_HjndN3upYTEQS6yxSGqQsM74V3uRHt_Hx3jw,2490
 model_compression_toolkit/core/keras/reader/connectivity_handler.py,sha256=AgF6qXZOJMeXvc-pBnGY23BJz7wPBx2aTYxHiO8efec,11303
-model_compression_toolkit/core/keras/reader/node_builder.py,sha256=
+model_compression_toolkit/core/keras/reader/node_builder.py,sha256=Zpq7aab68s16UJTpeEt-ybtIcHBwZqz5PoA9h774yAE,9657
 model_compression_toolkit/core/keras/reader/reader.py,sha256=wS9UQ2wJKnkZYe9JHwQp7ygDr6CRlzrxmIyLDv1Qz6U,8109
 model_compression_toolkit/core/keras/reader/nested_model/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
 model_compression_toolkit/core/keras/reader/nested_model/edges_merger.py,sha256=K6KAH9o8KSG6baLmhKoCrYK-i-wb6gRKiZmoijFqEYA,7906

@@ -211,7 +211,7 @@ model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG
 model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
 model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
-model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=
+model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=fMUUHOv31FGWy1dUXteWtj6OlVm4QC2mf2H77n7ToLM,4584
 model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
 model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=1uIDT-3wLzQf1FT8fMleyu5w5EYL0n7HoFEG80XDUY8,27082
 model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250

@@ -341,7 +341,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
 model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
 model_compression_toolkit/gptq/keras/gptq_training.py,sha256=cASZlTmnth3Vu-7GfmC03FxWSXtpSVhdPKT_twWml68,17949
 model_compression_toolkit/gptq/keras/graph_info.py,sha256=FIGqzJbG6GkdHenvdMu-tGTjp4j9BewdF_spmWCb4Mo,4627
-model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
+model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=wRyQrJJ71JwtFoiIdBPDHE0srpUwmL7nqHbXOvjDHFc,13578
 model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
 model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=8NrJBftKFbMAF_jYaAbLP6GBwpCv3Ln1NKURaV75zko,4770
 model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055

@@ -385,7 +385,7 @@ model_compression_toolkit/qat/__init__.py,sha256=kj2qsZh_Ca7PncsHKcaL5EVT2H8g4hY
 model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
 model_compression_toolkit/qat/common/qat_config.py,sha256=zoq0Vb74vCY7WlWD8JH_KPrHDoUHSvMc3gcO53u7L2U,3394
 model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/qat/keras/quantization_facade.py,sha256=
+model_compression_toolkit/qat/keras/quantization_facade.py,sha256=xH05Ro9aY9HabQo_PztaXw0-D3Cxvl-GYCmDKRjwkuI,16524
 model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
 model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=gPuIgQb8OafvC3SuA8jNsGoy8S8eTsDCEKuh36WDNss,2104
 model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543

@@ -397,7 +397,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
 model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
 model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
 model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
-model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=
+model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=TaciVmT0tQhvfpp7ASxPo-feZWlUNLg4IVvx8Qpe5jA,12963
 model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
 model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=FnhuFCuQoSf78FM1z1UZgXXd3k-mKSM7i9dYOuJUmeA,2213
 model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=e8Yfqbc552iAiP4Zxbd2ht1A3moRFGnV_KRGDm9Gw_g,5709

@@ -472,8 +472,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
 model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
 model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
 model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
-mct_nightly-1.11.0.
-mct_nightly-1.11.0.
-mct_nightly-1.11.0.
-mct_nightly-1.11.0.
-mct_nightly-1.11.0.
+mct_nightly-1.11.0.20240315.post349.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-1.11.0.20240315.post349.dist-info/METADATA,sha256=Abw6jCZMqvGv6zOjZJNCIXqCQQlu8_135csHgGOuxHA,18529
+mct_nightly-1.11.0.20240315.post349.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+mct_nightly-1.11.0.20240315.post349.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-1.11.0.20240315.post349.dist-info/RECORD,,
model_compression_toolkit/core/common/fusion/layer_fusing.py
CHANGED

@@ -31,9 +31,10 @@ def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx
         fusing_patterns after filtering non-relevant fusions
     """
     valid_fusing_patterns = []
-    for i,fusing_pattern in enumerate(fusing_patterns):
+    for i, fusing_pattern in enumerate(fusing_patterns):
         if idx < len(fusing_pattern):
-            if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or
+            if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or \
+                    node.is_match_type(fusing_pattern[idx]):
                 valid_fusing_patterns.append(fusing_pattern)
 
     # Return only valid patterns for this node

@@ -44,7 +45,7 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
     """
     Check if the fusion is valid: exist in fusing_patterns
     Args:
-        fusing_patterns: supported
+        fusing_patterns: supported fusing patterns
        nodes: nodes which are participating in fusion
     Returns:
        whether the fusion in valid

@@ -56,8 +57,9 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
         if fusion_depth != len(fusing_pattern):
             continue
         counter = 0
-        for i,layer in enumerate(fusing_pattern):
-            if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or
+        for i, layer in enumerate(fusing_pattern):
+            if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \
+                    nodes[i].is_match_type(layer):
                 counter += 1
         if counter == fusion_depth:
             return True

@@ -107,7 +109,7 @@ def fusion(graph: Graph, tpc: TargetPlatformCapabilities) -> Graph:
         if node in fused_nodes:
             continue
         # Start fusing search
-        fusing_nodes = []
+        fusing_nodes = []  # nodes that are candidates for participating in fusing
         patterns = copy.deepcopy(fusing_patterns)
         next_nodes = [node]
         for i in range(max_layers_fusing):
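The change above is mostly cosmetic (spacing, explicit line continuations, a comment), but it also routes the per-element check through the new BaseNode.is_match_type. A minimal sketch of the kind of structure these functions walk over; the concrete layer types below are illustrative and not taken from this diff:

    # Hypothetical fusing_patterns value: a list of layer-type sequences, where an
    # element may also be a LayerFilterParams instance instead of a plain type.
    from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU

    fusing_patterns = [
        [Conv2D, ReLU],
        [Conv2D, BatchNormalization, ReLU],
    ]

    # is_valid_fusion() only accepts a candidate node sequence whose length equals the
    # pattern length and whose every position matches by type or by filter params.
    print(max(len(p) for p in fusing_patterns))  # deepest supported fusion here: 3 layers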
model_compression_toolkit/core/common/graph/base_node.py
CHANGED

@@ -14,7 +14,7 @@
 # ==============================================================================
 
 import copy
-from typing import Dict, Any, Tuple, List
+from typing import Dict, Any, Tuple, List, Type
 
 import numpy as np
 

@@ -556,6 +556,19 @@ class BaseNode:
             return tpc.layer2qco.get(self.type)
         return tpc.tp_model.default_qco
 
+    def is_match_type(self, _type: Type) -> bool:
+        """
+        Check if input type matches the node type, either in instance type or in type name. Checking the
+        name string is required because of function types changes that occurred in TF 2.15.
+
+        Args:
+            _type: other node type
+        Returns:
+            Whether _type matches the self node type
+
+        """
+        return _type == self.type or _type.__name__ == self.type.__name__
+
     def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool:
         """
         Check if the node matches a LayerFilterParams according to its

@@ -572,7 +585,7 @@ class BaseNode:
             return False
 
         # Check the node has the same type as the layer in LayerFilterParams
-        if layer_filter_params.layer
+        if not self.is_match_type(layer_filter_params.layer):
             return False
 
         # Get attributes from node to filter
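The new is_match_type falls back to comparing __name__ strings because, as its docstring notes, the function types behind some layers changed in TF 2.15, so the type recorded on a node and the type used in a pattern can be different objects that still denote the same operation. A standalone sketch of that fallback semantics (not MCT code):

    class NodeStub:
        # Mimics only the relevant part of BaseNode, for illustration.
        def __init__(self, node_type):
            self.type = node_type

        def is_match_type(self, _type) -> bool:
            # exact type match, or a match by type name as a fallback
            return _type == self.type or _type.__name__ == self.type.__name__


    class Add:                 # stands in for the type recorded on the node
        pass

    RecordedAdd = Add

    class Add:                 # re-declared: a different class object, same __name__
        pass

    node = NodeStub(RecordedAdd)
    print(RecordedAdd is Add)          # False: an identity check alone would fail
    print(node.is_match_type(Add))     # True: the __name__ fallback still matches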
model_compression_toolkit/core/common/graph/graph_matchers.py
CHANGED

@@ -35,7 +35,7 @@ class NodeOperationMatcher(node_matcher.BaseNodeMatcher):
 
         self.operation = operation
 
-    def apply(self, input_node_object:
+    def apply(self, input_node_object: BaseNode) -> bool:
         """
         Check if input_node_object matches the matcher condition.
 

@@ -47,7 +47,7 @@ class NodeOperationMatcher(node_matcher.BaseNodeMatcher):
         return nothing.
         """
 
-        if input_node_object.
+        if input_node_object.is_match_type(self.operation):
             return True
 
 
model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py
CHANGED

@@ -109,7 +109,8 @@ def create_add_node(add_value: float,
                              quantization_attr={},
                              layer_class=TFOpLambda,
                              op_call_args=[np.array(add_value, dtype=np.float32).reshape([1] * len(input_shape))],
-                             op_call_kwargs={})
+                             op_call_kwargs={},
+                             functional_op=tf.add)
     return add_node
 
 

@@ -157,7 +158,8 @@ def create_pad_node(next_node_name: str,
                              layer_class=TFOpLambda,
                              op_call_args=[],
                              op_call_kwargs={'paddings': num_elements_to_pad,
-                                             'constant_values': value_to_pad})
+                                             'constant_values': value_to_pad},
+                             functional_op=tf.pad)
 
     return pad_node
 
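create_add_node and create_pad_node now record the wrapped TF function explicitly through functional_op. The nodes they describe correspond to plain tf.add / tf.pad calls; a small standalone sanity check of those two ops (shapes here are arbitrary):

    import numpy as np
    import tensorflow as tf

    x = tf.ones((1, 8, 8, 3))
    # the add node broadcasts a scalar shift reshaped to the input rank, as in create_add_node
    shifted = tf.add(x, np.array(0.5, dtype=np.float32).reshape([1, 1, 1, 1]))
    # the pad node passes 'paddings' and 'constant_values' as kwargs, as in create_pad_node
    padded = tf.pad(x, paddings=[[0, 0], [1, 1], [1, 1], [0, 0]], constant_values=0.0)
    print(shifted.shape, padded.shape)   # (1, 8, 8, 3) (1, 10, 10, 3)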
model_compression_toolkit/core/keras/kpi_data_facade.py
CHANGED

@@ -36,7 +36,6 @@ if FOUND_TF:
     def keras_kpi_data(in_model: Model,
                        representative_data_gen: Callable,
                        core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
-                       fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                        target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC) -> KPI:
         """
         Computes KPI data that can be used to calculate the desired target KPI for mixed-precision quantization.

@@ -46,7 +45,6 @@ if FOUND_TF:
             in_model (Model): Keras model to quantize.
             representative_data_gen (Callable): Dataset used for calibration.
             core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized.
-            fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
 
         Returns:

@@ -82,7 +80,7 @@ if FOUND_TF:
                                     representative_data_gen,
                                     core_config,
                                     target_platform_capabilities,
-                                    fw_info,
+                                    DEFAULT_KERAS_INFO,
                                     fw_impl)
 
 else:
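With fw_info gone from the signature, callers no longer supply framework info; DEFAULT_KERAS_INFO is passed internally instead. A hedged usage sketch, assuming the facade is exposed as mct.core.keras_kpi_data as in recent MCT releases (the toy model and data generator are made up for illustration):

    import numpy as np
    import tensorflow as tf
    import model_compression_toolkit as mct

    model = tf.keras.Sequential([tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
                                 tf.keras.layers.ReLU()])

    def representative_data_gen():
        yield [np.random.randn(1, 32, 32, 3).astype(np.float32)]

    # No fw_info argument anymore; the facade falls back to DEFAULT_KERAS_INFO internally.
    kpi = mct.core.keras_kpi_data(model, representative_data_gen)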
model_compression_toolkit/core/keras/reader/node_builder.py
CHANGED

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Any,
+from typing import Any, List, Dict
 
 import tensorflow as tf
 from tensorflow.python.util import tf_inspect

@@ -45,19 +45,32 @@ is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray))
 is_tensor = lambda x: isinstance(x, KerasTensor)
 
 
-def
+def get_tf_function_symbols() -> List[str]:
+    """
+    Create a list of tf function symbols, as they are created in the TFOpLambda layer. The
+    symbols are serializations of the function names.
+
+    Returns:
+        A list of TF function symbols,
+    """
+    return [TFOpLambda(f).symbol for f in [tf.add, tf.multiply, tf.subtract, tf.divide,
+                                           tf.truediv, tf.pow, tf.matmul]]
+
+
+def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
     """
     Positional weights are saved according to their index in the node's call arguments, so
     need to know the function arguments' names in case the weights are in the kwargs.
     Args:
-
+        tfoplambda_layer: TFOpLambda layer.
 
     Returns:
         A dictionary with argument number and index: {arg_name: arg_index}.
     """
-    if
-
-
+    if tfoplambda_layer.function in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
+                                     tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression] or \
+            tfoplambda_layer.symbol in ['__operators__.add', 'math.add', 'math.multiply', 'linalg.matmul', 'concat']:
+        return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tfoplambda_layer.function).args)}
     else:
         return {}
 

@@ -110,7 +123,7 @@ def build_node(node: KerasNode,
     # a flag to indicate that.
     inputs_as_list = __is_functional_inputs_a_list(op_call_args)
 
-    kwarg2index = get_kwargs2index(keras_layer
+    kwarg2index = get_kwargs2index(keras_layer)
 
     # Functional nodes do not have weights, but may have constants in their call_args and\or
     # call kwargs. Therefore, we extract these constants and save them in the node's weights as

@@ -122,10 +135,10 @@ def build_node(node: KerasNode,
             Logger.error('Functional nodes are not expected to have weights in framework')
 
         # read weights from call args
+        tf_function_symbols = get_tf_function_symbols()
         for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
             if is_const(arg) or (
-                    keras_layer.
-                    tf.matmul] and
+                    keras_layer.symbol in tf_function_symbols and
                     isinstance(arg, (tuple, list))):
                 weights.update({i: to_numpy(arg, is_single_tensor=True)})
         # remove weights and KerasTensors and weights from op_call_args
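get_tf_function_symbols relies on TFOpLambda(f).symbol, the canonical name Keras serializes for a wrapped TF function, and build_node now compares layer symbols against that list instead of comparing function objects. A small probe of that behaviour; the TFOpLambda import path differs across TF/Keras versions, so two known layouts are tried here as an assumption:

    import tensorflow as tf

    try:
        from keras.src.layers.core import TFOpLambda      # newer Keras package layout
    except ImportError:
        from keras.layers.core import TFOpLambda          # older layout

    # .symbol is the serialized canonical name, e.g. 'math.add', 'math.multiply', 'linalg.matmul'
    print([TFOpLambda(f).symbol for f in (tf.add, tf.multiply, tf.matmul)])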
model_compression_toolkit/core/pytorch/kpi_data_facade.py
CHANGED

@@ -39,7 +39,6 @@ if FOUND_TORCH:
     def pytorch_kpi_data(in_model: Module,
                          representative_data_gen: Callable,
                          core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
-                         fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                          target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC) -> KPI:
         """
         Computes KPI data that can be used to calculate the desired target KPI for mixed-precision quantization.

@@ -49,7 +48,6 @@ if FOUND_TORCH:
             in_model (Model): PyTorch model to quantize.
             representative_data_gen (Callable): Dataset used for calibration.
             core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
-            fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
 
         Returns:

@@ -85,7 +83,7 @@ if FOUND_TORCH:
                                     representative_data_gen,
                                     core_config,
                                     target_platform_capabilities,
-                                    fw_info,
+                                    DEFAULT_PYTORCH_INFO,
                                     fw_impl)
 
 else:
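The PyTorch facade gets the same treatment, with DEFAULT_PYTORCH_INFO supplied internally. A compact sketch mirroring the Keras example above (again assuming the mct.core.pytorch_kpi_data entry point; model and generator are illustrative):

    import numpy as np
    import torch
    import model_compression_toolkit as mct

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

    def representative_data_gen():
        yield [np.random.randn(1, 3, 32, 32).astype(np.float32)]

    kpi = mct.core.pytorch_kpi_data(model, representative_data_gen)   # no fw_info argument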
model_compression_toolkit/core/quantization_prep_runner.py
CHANGED

@@ -68,7 +68,7 @@ def quantization_preparation_runner(graph: Graph,
                                      fw_info,
                                      core_config.quantization_config)  # Mark points for statistics collection
 
-    for _data in tqdm(representative_data_gen(), "Statistics Collection
+    for _data in tqdm(representative_data_gen(), "Statistics Collection"):
         mi.infer(_data)
 
     if tb_w is not None:
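The fix simply closes the tqdm(...) call; the second positional argument of tqdm is the progress-bar description (desc). A tiny standalone check of that usage:

    from tqdm import tqdm

    total = 0
    for batch in tqdm(range(3), "Statistics Collection"):   # 'Statistics Collection' becomes desc
        total += batch
    print(total)  # 3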
model_compression_toolkit/gptq/keras/quantization_facade.py
CHANGED

@@ -113,12 +113,10 @@ if FOUND_TF:
                                   regularization_factor=regularization_factor)
 
 
-    def keras_gradient_post_training_quantization(in_model: Model,
-                                                  representative_data_gen: Callable,
+    def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
                                                   gptq_config: GradientPTQConfig,
                                                   gptq_representative_data_gen: Callable = None,
                                                   core_config: CoreConfig = CoreConfig(),
-                                                  fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                                                   target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
         """
         Quantize a trained Keras model using post-training quantization. The model is quantized using a

@@ -142,7 +140,6 @@ if FOUND_TF:
             gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
             gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
-            fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
 
         Returns:

@@ -192,7 +189,7 @@ if FOUND_TF:
 
         """
         KerasModelValidation(model=in_model,
-                             fw_info=fw_info).validate()
+                             fw_info=DEFAULT_KERAS_INFO).validate()
 
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):

@@ -200,14 +197,14 @@ if FOUND_TF:
                              "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                              "API, or pass a valid mixed precision configuration.")  # pragma: no cover
 
-        tb_w = init_tensorboard_writer(fw_info)
+        tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
 
         fw_impl = GPTQKerasImplemantation()
 
         tg, bit_widths_config, hessian_info_service = core_runner(in_model=in_model,
                                                                   representative_data_gen=representative_data_gen,
                                                                   core_config=core_config,
-                                                                  fw_info=fw_info,
+                                                                  fw_info=DEFAULT_KERAS_INFO,
                                                                   fw_impl=fw_impl,
                                                                   tpc=target_platform_capabilities,
                                                                   tb_w=tb_w)

@@ -217,7 +214,7 @@ if FOUND_TF:
                              gptq_config,
                              representative_data_gen,
                              gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
-                             fw_info,
+                             DEFAULT_KERAS_INFO,
                              fw_impl,
                              tb_w,
                              hessian_info_service=hessian_info_service)
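GPTQ callers likewise stop passing fw_info. A hedged end-to-end sketch, assuming the facade and the get_keras_gptq_config helper are exposed under mct.gptq as in the MCT documentation (model, data generator and epoch count are illustrative):

    import numpy as np
    import tensorflow as tf
    import model_compression_toolkit as mct

    model = tf.keras.Sequential([tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
                                 tf.keras.layers.ReLU()])

    def representative_data_gen():
        yield [np.random.randn(1, 32, 32, 3).astype(np.float32)]

    gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)
    quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
        model, representative_data_gen, gptq_config=gptq_config)   # fw_info no longer passed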
model_compression_toolkit/qat/keras/quantization_facade.py
CHANGED

@@ -89,7 +89,6 @@ if FOUND_TF:
                                                    representative_data_gen: Callable,
                                                    core_config: CoreConfig = CoreConfig(),
                                                    qat_config: QATConfig = QATConfig(),
-                                                   fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
                                                    target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
         """
         Prepare a trained Keras model for quantization aware training. First the model quantization is optimized

@@ -111,7 +110,6 @@ if FOUND_TF:
             representative_data_gen (Callable): Dataset used for initial calibration.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             qat_config (QATConfig): QAT configuration
-            fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
 
         Returns:

@@ -159,7 +157,7 @@ if FOUND_TF:
         Pass the model, the representative dataset generator, the configuration and the target KPI to get a
         quantized model:
 
-        >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=
+        >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=core_config)
 
         Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:
 

@@ -174,7 +172,7 @@ if FOUND_TF:
                           f"project https://github.com/sony/model_optimization")
 
         KerasModelValidation(model=in_model,
-                             fw_info=fw_info).validate()
+                             fw_info=DEFAULT_KERAS_INFO).validate()
 
         if core_config.mixed_precision_enable:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):

@@ -182,7 +180,7 @@ if FOUND_TF:
                              "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
                              "or pass a valid mixed precision configuration.")
 
-        tb_w = init_tensorboard_writer(fw_info)
+        tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
 
         fw_impl = KerasImplementation()
 

@@ -190,16 +188,16 @@ if FOUND_TF:
         tg, bit_widths_config, _ = core_runner(in_model=in_model,
                                                representative_data_gen=representative_data_gen,
                                                core_config=core_config,
-                                               fw_info=fw_info,
+                                               fw_info=DEFAULT_KERAS_INFO,
                                                fw_impl=fw_impl,
                                                tpc=target_platform_capabilities,
                                                tb_w=tb_w)
 
-        tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
+        tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)
 
         _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
         qat_model, user_info = KerasModelBuilder(graph=tg,
-                                                 fw_info=fw_info,
+                                                 fw_info=DEFAULT_KERAS_INFO,
                                                  wrapper=_qat_wrapper,
                                                  get_activation_quantizer_holder_fn=partial(get_activation_quantizer_holder,
                                                                                             qat_config=qat_config)).build_model()
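The QAT init facade keeps the call shape shown in its own docstring; only fw_info disappears. A sketch built around that doctest, assuming mct.core.CoreConfig and mct.core.KPI as the configuration and KPI entry points (toy model, generator and the unconstrained KPI are illustrative):

    import numpy as np
    import tensorflow as tf
    import model_compression_toolkit as mct

    model = tf.keras.Sequential([tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
                                 tf.keras.layers.ReLU()])

    def repr_datagen():
        yield [np.random.randn(1, 32, 32, 3).astype(np.float32)]

    core_config = mct.core.CoreConfig()
    kpi = mct.core.KPI()   # assumed: default (unbounded) target KPI

    quantized_model, quantization_info, custom_objects = \
        mct.qat.keras_quantization_aware_training_init_experimental(
            model, repr_datagen, kpi, core_config=core_config)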
model_compression_toolkit/qat/pytorch/quantization_facade.py
CHANGED

@@ -77,7 +77,6 @@ if FOUND_TORCH:
                                                      representative_data_gen: Callable,
                                                      core_config: CoreConfig = CoreConfig(),
                                                      qat_config: QATConfig = QATConfig(),
-                                                     fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
                                                      target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
         """
         Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized

@@ -99,7 +98,6 @@ if FOUND_TORCH:
             representative_data_gen (Callable): Dataset used for initial calibration.
             core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
             qat_config (QATConfig): QAT configuration
-            fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Pytorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
             target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to.
 
         Returns:

@@ -150,7 +148,7 @@ if FOUND_TORCH:
                              "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
                              "or pass a valid mixed precision configuration.")
 
-        tb_w = init_tensorboard_writer(fw_info)
+        tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
         fw_impl = PytorchImplementation()
 
         # Ignore trace hessian service as we do not use it here

@@ -162,12 +160,12 @@ if FOUND_TORCH:
                                                tpc=target_platform_capabilities,
                                                tb_w=tb_w)
 
-        tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
+        tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
 
         _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
 
         qat_model, user_info = PyTorchModelBuilder(graph=tg,
-                                                   fw_info=fw_info,
+                                                   fw_info=DEFAULT_PYTORCH_INFO,
                                                    wrapper=_qat_wrapper,
                                                    get_activation_quantizer_holder_fn=partial(
                                                        get_activation_quantizer_holder,
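The PyTorch QAT facade mirrors the Keras one. A compact sketch; the function name and two-value return are assumed from the MCT 1.11 API and should be treated as an assumption rather than documented behaviour:

    import numpy as np
    import torch
    import model_compression_toolkit as mct

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

    def repr_datagen():
        yield [np.random.randn(1, 3, 32, 32).astype(np.float32)]

    qat_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(
        model, repr_datagen, core_config=mct.core.CoreConfig())   # fw_info no longer accepted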
{mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/LICENSE.md
File without changes

{mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/WHEEL
File without changes

{mct_nightly-1.11.0.20240313.post405.dist-info → mct_nightly-1.11.0.20240315.post349.dist-info}/top_level.txt
File without changes