mct-nightly 1.11.0.20240313.post405-py3-none-any.whl → 1.11.0.20240315.post349-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 1.11.0.20240313.post405
+ Version: 1.11.0.20240315.post349
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -85,11 +85,11 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
  | Python 3.11 | | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch20.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch21.yml) |
 
 
- | | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 |
- |-------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
- | Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) |
- | Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) |
- | Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) |
+ | | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 | TensorFlow 2.15 |
+ |-------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+ | Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
+ | Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
+ | Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |
 
 
  ## Supported Features
@@ -6,7 +6,7 @@ model_compression_toolkit/core/__init__.py,sha256=P-7OYR4TFYxVV_ZpIJBogkX8bGvXci
  model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9jr_cMbET0Bd3Sg,2975
  model_compression_toolkit/core/exporter.py,sha256=Zo_C5GjIzihtJOyGp-xeCVhY_qohkVz_EGyrSZCbWRM,4115
  model_compression_toolkit/core/graph_prep_runner.py,sha256=3xp0WYqyeRdlBkf5R6uD2zWubg_JPttOwS7JRhKykBY,10043
- model_compression_toolkit/core/quantization_prep_runner.py,sha256=NkCNC6qjzrCnqw0rwPNKjjRbLcXsZRTBhZR_GnZeIM0,6154
+ model_compression_toolkit/core/quantization_prep_runner.py,sha256=hFhDkS8GwzXZ7Ho_9qbbb8DAAWs3OONOfMSD5OU_b0o,6153
  model_compression_toolkit/core/runner.py,sha256=hXnbgP8Q-62Ie4wAq4JXO-2o77uR3le4mHYgFqJOvfc,10928
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
@@ -29,13 +29,13 @@ model_compression_toolkit/core/common/collectors/mean_collector.py,sha256=mjr3U_
  model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py,sha256=5oKsJEKdVmj4C7fKdHhmrFN5k4G2BaFETpmf_xKNs7s,5207
  model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=vcf7Pk1v09SJC4fbAWf_8AgTktE6tPizJbQpSmocP2U,7930
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
- model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=tIsWFYc771o59uvq5fxAaBmOCnd_gd-_xMbQI9SupQA,5479
+ model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=Go3x1Sc0V8p9CNaJhJ6ZTjtSGGjUaKjwdLkxX_8ZFmg,38129
- model_compression_toolkit/core/common/graph/base_node.py,sha256=m_qh0bgo_vWTTW5JkA8pdGUQfMYt8SdG5XZWBhcH-aI,27999
+ model_compression_toolkit/core/common/graph/base_node.py,sha256=6-ZYY_BlGyG6I_XYlB4AtfYs--_vpoF6hGyEwHFZNfc,28483
  model_compression_toolkit/core/common/graph/edge.py,sha256=K6Wc2hBcIqig5PbbLhbjtTgYtkyZEohfgj4Wn_J5yEA,3733
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=RgwWAoMX7YV5c2gZdTBSX-ziTh3OLbebZXr3jitkxDs,3173
- model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=kQ14uXW6ecsj7IarjRLAXUzDBmakD_v6Ck7-u24_nxg,4732
+ model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
  model_compression_toolkit/core/common/graph/graph_searches.py,sha256=2oKuW6L8hP-oL0lFO9PhQFt9fEFgVJwpc1u4fHExAtE,5128
  model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py,sha256=4OojC7tM7QvP7GETEuIS7PVIfCctDfgzMyfHqa741T4,9789
  model_compression_toolkit/core/common/graph/memory_graph/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
@@ -153,7 +153,7 @@ model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_
  model_compression_toolkit/core/keras/keras_implementation.py,sha256=R7gtKur2Ubw9HVBI7ImdyNaP6OxBfNfvwqQ2Ey2wW2E,29398
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
  model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
- model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=nZdRhQuIXjoL3sq2HjffKHWfnafVM-j_oqk0Cc5Op3I,4889
+ model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=1EY7krgkOdo5zbpeKvPq0kiroU1aHskMVlDnReLdjuM,4501
  model_compression_toolkit/core/keras/tf_tensor_numpy.py,sha256=1kBs9URqZTfmRXAsCqvnekV5bKUL3MyqGbORewLIwu8,2457
  model_compression_toolkit/core/keras/back2framework/__init__.py,sha256=rhIiXg_nBgUZ-baE3M6SzCuQbcnq4iebY1jtJBvKHOM,808
  model_compression_toolkit/core/keras/back2framework/factory_model_builder.py,sha256=GSh1Piz5qpA7IlvHTMqUvPn7WBDa0IHEDZdd_TzY9XA,2226
@@ -178,7 +178,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/remove_re
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/residual_collapsing.py,sha256=6PnPIC5ax7uTzcoslW7ropIu7vVmo70AD4QYcYnQV20,3176
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py,sha256=ryes9y1ie-vjBGso2TeO4EXxVk69Ew3iSAhshPz1Ou4,5542
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/separableconv_decomposition.py,sha256=TEaHlIbXj_ZjIdT5TmAICD3WLD3u_7g0fLWQcNzTJuM,7941
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py,sha256=6vEakr0jWrccU7dfubRCiNg6TFe6whte_pbTiXMJIvc,11045
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py,sha256=JCK--hQMKzbx4MOQZBPZqK015JWZELUO5YdA30IU4bI,11149
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/softmax_shift.py,sha256=Qk5seDALj_th9dHJehY7ynZjvFjVfCv_mJ1enA5hX0c,1623
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=wH9ocMLL725-uUPU-zCxdd8NwT5nyd0ZShmI7iuTwF8,1462
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_activation_split.py,sha256=rjIheZW7LbSPv9bzMSmC8wl6UUxaTkd4J2IHinObT-Y,1814
@@ -198,7 +198,7 @@ model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py,sha256=Up3-sbuA
  model_compression_toolkit/core/keras/reader/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
  model_compression_toolkit/core/keras/reader/common.py,sha256=z0PMoP_HjndN3upYTEQS6yxSGqQsM74V3uRHt_Hx3jw,2490
  model_compression_toolkit/core/keras/reader/connectivity_handler.py,sha256=AgF6qXZOJMeXvc-pBnGY23BJz7wPBx2aTYxHiO8efec,11303
- model_compression_toolkit/core/keras/reader/node_builder.py,sha256=kVP5lK13rGU5pcyHMTXOF7oyG48DyO-SDm1wBFWGzdo,9106
+ model_compression_toolkit/core/keras/reader/node_builder.py,sha256=Zpq7aab68s16UJTpeEt-ybtIcHBwZqz5PoA9h774yAE,9657
  model_compression_toolkit/core/keras/reader/reader.py,sha256=wS9UQ2wJKnkZYe9JHwQp7ygDr6CRlzrxmIyLDv1Qz6U,8109
  model_compression_toolkit/core/keras/reader/nested_model/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
  model_compression_toolkit/core/keras/reader/nested_model/edges_merger.py,sha256=K6KAH9o8KSG6baLmhKoCrYK-i-wb6gRKiZmoijFqEYA,7906
@@ -211,7 +211,7 @@ model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
- model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=xQdg8vtkwhx1uzElrh0KtwKdWFr6b2Guzv913iA_VoI,4978
+ model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=fMUUHOv31FGWy1dUXteWtj6OlVm4QC2mf2H77n7ToLM,4584
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=1uIDT-3wLzQf1FT8fMleyu5w5EYL0n7HoFEG80XDUY8,27082
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
@@ -341,7 +341,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
  model_compression_toolkit/gptq/keras/gptq_training.py,sha256=cASZlTmnth3Vu-7GfmC03FxWSXtpSVhdPKT_twWml68,17949
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=FIGqzJbG6GkdHenvdMu-tGTjp4j9BewdF_spmWCb4Mo,4627
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=qJGTpbjc-771L4Qq8xKO4LfV889IItbX3jV_Sx7gNjA,14010
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=wRyQrJJ71JwtFoiIdBPDHE0srpUwmL7nqHbXOvjDHFc,13578
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=8NrJBftKFbMAF_jYaAbLP6GBwpCv3Ln1NKURaV75zko,4770
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -385,7 +385,7 @@ model_compression_toolkit/qat/__init__.py,sha256=kj2qsZh_Ca7PncsHKcaL5EVT2H8g4hY
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
  model_compression_toolkit/qat/common/qat_config.py,sha256=zoq0Vb74vCY7WlWD8JH_KPrHDoUHSvMc3gcO53u7L2U,3394
  model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/qat/keras/quantization_facade.py,sha256=eLDFe6QwdRxXBo2U7s-ohXRpYIc_xetNyh8hWxcu4L4,16902
+ model_compression_toolkit/qat/keras/quantization_facade.py,sha256=xH05Ro9aY9HabQo_PztaXw0-D3Cxvl-GYCmDKRjwkuI,16524
  model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
  model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=gPuIgQb8OafvC3SuA8jNsGoy8S8eTsDCEKuh36WDNss,2104
  model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
@@ -397,7 +397,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
  model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=P62u3D9Ilh9dkcxf6qiLKnn_DkahX2Eht4hKmNbPEEc,13370
+ model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=TaciVmT0tQhvfpp7ASxPo-feZWlUNLg4IVvx8Qpe5jA,12963
  model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
  model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=FnhuFCuQoSf78FM1z1UZgXXd3k-mKSM7i9dYOuJUmeA,2213
  model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=e8Yfqbc552iAiP4Zxbd2ht1A3moRFGnV_KRGDm9Gw_g,5709
@@ -472,8 +472,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
- mct_nightly-1.11.0.20240313.post405.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-1.11.0.20240313.post405.dist-info/METADATA,sha256=7ah0K0r6vuzGNTh9Dh1mE2_OF4NnjIBBAR8XPpT_t8I,17444
- mct_nightly-1.11.0.20240313.post405.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- mct_nightly-1.11.0.20240313.post405.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-1.11.0.20240313.post405.dist-info/RECORD,,
+ mct_nightly-1.11.0.20240315.post349.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-1.11.0.20240315.post349.dist-info/METADATA,sha256=Abw6jCZMqvGv6zOjZJNCIXqCQQlu8_135csHgGOuxHA,18529
+ mct_nightly-1.11.0.20240315.post349.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ mct_nightly-1.11.0.20240315.post349.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-1.11.0.20240315.post349.dist-info/RECORD,,
@@ -31,9 +31,10 @@ def filter_fusing_patterns(fusing_patterns: List[List[Any]], node: BaseNode, idx
  fusing_patterns after filtering non-relevant fusions
  """
  valid_fusing_patterns = []
- for i,fusing_pattern in enumerate(fusing_patterns):
+ for i, fusing_pattern in enumerate(fusing_patterns):
  if idx < len(fusing_pattern):
- if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or fusing_pattern[idx] == node.type:
+ if (type(fusing_pattern[idx]) == LayerFilterParams and node.is_match_filter_params(fusing_pattern[idx])) or \
+ node.is_match_type(fusing_pattern[idx]):
  valid_fusing_patterns.append(fusing_pattern)

  # Return only valid patterns for this node
@@ -44,7 +45,7 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
  """
  Check if the fusion is valid: exist in fusing_patterns
  Args:
- fusing_patterns: supported fusings
+ fusing_patterns: supported fusing patterns
  nodes: nodes which are participating in fusion
  Returns:
  whether the fusion in valid
@@ -56,8 +57,9 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List[BaseNode]) ->
  if fusion_depth != len(fusing_pattern):
  continue
  counter = 0
- for i,layer in enumerate(fusing_pattern):
- if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or layer == nodes[i].type:
+ for i, layer in enumerate(fusing_pattern):
+ if (type(layer) == LayerFilterParams and nodes[i].is_match_filter_params(layer)) or \
+ nodes[i].is_match_type(layer):
  counter += 1
  if counter == fusion_depth:
  return True
@@ -107,7 +109,7 @@ def fusion(graph: Graph, tpc: TargetPlatformCapabilities) -> Graph:
  if node in fused_nodes:
  continue
  # Start fusing search
- fusing_nodes = [] # nodes that are candidates for participating in fusing
+ fusing_nodes = []  # nodes that are candidates for participating in fusing
  patterns = copy.deepcopy(fusing_patterns)
  next_nodes = [node]
  for i in range(max_layers_fusing):
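Both hunks above swap the direct `fusing_pattern[idx] == node.type` / `layer == nodes[i].type` comparison for `node.is_match_type(...)` (added to `BaseNode` further down). A rough, self-contained sketch of the matching logic, using stand-in classes rather than MCT's real `BaseNode`/`LayerFilterParams`, so every name below is illustrative only:

```python
# Illustrative stand-ins only; MCT's real BaseNode/LayerFilterParams are richer.
class FakeNode:
    def __init__(self, node_type):
        self.type = node_type

    def is_match_type(self, other_type):
        # Match by identity, or by class name when the same op is exposed as a
        # different class object (e.g. across framework versions).
        return other_type == self.type or other_type.__name__ == self.type.__name__


class Conv2D: pass
class ReLU: pass


def is_valid_fusion_sketch(fusing_patterns, nodes):
    """Return True if the node sequence matches any supported fusing pattern."""
    return any(
        len(pattern) == len(nodes)
        and all(node.is_match_type(layer) for node, layer in zip(nodes, pattern))
        for pattern in fusing_patterns)


print(is_valid_fusion_sketch([[Conv2D, ReLU]], [FakeNode(Conv2D), FakeNode(ReLU)]))  # True
print(is_valid_fusion_sketch([[Conv2D, ReLU]], [FakeNode(ReLU), FakeNode(Conv2D)]))  # False
```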
@@ -14,7 +14,7 @@
  # ==============================================================================

  import copy
- from typing import Dict, Any, Tuple, List
+ from typing import Dict, Any, Tuple, List, Type

  import numpy as np

@@ -556,6 +556,19 @@ class BaseNode:
  return tpc.layer2qco.get(self.type)
  return tpc.tp_model.default_qco

+ def is_match_type(self, _type: Type) -> bool:
+ """
+ Check if input type matches the node type, either in instance type or in type name. Checking the
+ name string is required because of function types changes that occurred in TF 2.15.
+
+ Args:
+ _type: other node type
+ Returns:
+ Whether _type matches the self node type
+
+ """
+ return _type == self.type or _type.__name__ == self.type.__name__
+
  def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool:
  """
  Check if the node matches a LayerFilterParams according to its
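The name-based fallback in `is_match_type` exists because, after the function-type changes the docstring attributes to TF 2.15, the "same" op can be represented by two distinct class objects, so a plain `==` on the types fails. A framework-free illustration of that failure mode and the fallback (the dynamically created classes below are stand-ins, not TF types):

```python
# Two distinct class objects that share a name, standing in for the same layer
# being re-defined in a different module between TF/Keras releases.
TypeInOldModule = type("TFOpLambda", (), {})
TypeInNewModule = type("TFOpLambda", (), {})

node_type, other_type = TypeInOldModule, TypeInNewModule

print(other_type == node_type)                    # False: different class objects
print(other_type.__name__ == node_type.__name__)  # True: same class name

# The is_match_type logic above: identity first, class name as a fallback.
print(other_type == node_type or other_type.__name__ == node_type.__name__)  # True
```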
@@ -572,7 +585,7 @@ class BaseNode:
  return False

  # Check the node has the same type as the layer in LayerFilterParams
- if layer_filter_params.layer != self.type:
+ if not self.is_match_type(layer_filter_params.layer):
  return False

  # Get attributes from node to filter
@@ -35,7 +35,7 @@ class NodeOperationMatcher(node_matcher.BaseNodeMatcher):

  self.operation = operation

- def apply(self, input_node_object: Any) -> bool:
+ def apply(self, input_node_object: BaseNode) -> bool:
  """
  Check if input_node_object matches the matcher condition.

@@ -47,7 +47,7 @@ class NodeOperationMatcher(node_matcher.BaseNodeMatcher):
  return nothing.
  """

- if input_node_object.type == self.operation:
+ if input_node_object.is_match_type(self.operation):
  return True

@@ -109,7 +109,8 @@ def create_add_node(add_value: float,
  quantization_attr={},
  layer_class=TFOpLambda,
  op_call_args=[np.array(add_value, dtype=np.float32).reshape([1] * len(input_shape))],
- op_call_kwargs={})
+ op_call_kwargs={},
+ functional_op=tf.add)
  return add_node


@@ -157,7 +158,8 @@ def create_pad_node(next_node_name: str,
  layer_class=TFOpLambda,
  op_call_args=[],
  op_call_kwargs={'paddings': num_elements_to_pad,
- 'constant_values': value_to_pad})
+ 'constant_values': value_to_pad},
+ functional_op=tf.pad)

  return pad_node

@@ -36,7 +36,6 @@ if FOUND_TF:
  def keras_kpi_data(in_model: Model,
  representative_data_gen: Callable,
  core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
- fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
  target_platform_capabilities: TargetPlatformCapabilities = KERAS_DEFAULT_TPC) -> KPI:
  """
  Computes KPI data that can be used to calculate the desired target KPI for mixed-precision quantization.
@@ -46,7 +45,6 @@ if FOUND_TF:
  in_model (Model): Keras model to quantize.
  representative_data_gen (Callable): Dataset used for calibration.
  core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision of how the model should be quantized.
- fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.

  Returns:
@@ -82,7 +80,7 @@ if FOUND_TF:
  representative_data_gen,
  core_config,
  target_platform_capabilities,
- fw_info,
+ DEFAULT_KERAS_INFO,
  fw_impl)

  else:
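With `fw_info` gone from the signature, callers simply drop that argument and `DEFAULT_KERAS_INFO` is applied internally. A hedged usage sketch, assuming this nightly exposes the facade as `mct.core.keras_kpi_data`; the toy model and data generator are placeholders:

```python
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct  # assumed import name for mct-nightly

# Placeholder model and representative dataset, purely for illustration.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
    tf.keras.layers.ReLU(),
])

def representative_data_gen():
    for _ in range(2):
        yield [np.random.randn(1, 32, 32, 3).astype(np.float32)]

# No fw_info argument anymore; the default Keras framework info is used internally.
kpi_data = mct.core.keras_kpi_data(model, representative_data_gen)
print(kpi_data)
```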
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Any, Callable, Dict
+ from typing import Any, List, Dict

  import tensorflow as tf
  from tensorflow.python.util import tf_inspect
@@ -45,19 +45,32 @@ is_const = lambda x: isinstance(x, (tf.Variable, tf.Tensor, np.ndarray))
  is_tensor = lambda x: isinstance(x, KerasTensor)


- def get_kwargs2index(tf_func: Callable) -> Dict[str, int]:
+ def get_tf_function_symbols() -> List[str]:
+ """
+ Create a list of tf function symbols, as they are created in the TFOpLambda layer. The
+ symbols are serializations of the function names.
+
+ Returns:
+ A list of TF function symbols,
+ """
+ return [TFOpLambda(f).symbol for f in [tf.add, tf.multiply, tf.subtract, tf.divide,
+ tf.truediv, tf.pow, tf.matmul]]
+
+
+ def get_kwargs2index(tfoplambda_layer: TFOpLambda) -> Dict[str, int]:
  """
  Positional weights are saved according to their index in the node's call arguments, so
  need to know the function arguments' names in case the weights are in the kwargs.
  Args:
- tf_func: functional node function.
+ tfoplambda_layer: TFOpLambda layer.

  Returns:
  A dictionary with argument number and index: {arg_name: arg_index}.
  """
- if tf_func in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
- tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression]:
- return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tf_func).args)}
+ if tfoplambda_layer.function in [tf.add, tf.subtract, tf.divide, tf.truediv, tf.multiply, tf.pow,
+ tf.matmul, tf.image.crop_and_resize, tf.image.combined_non_max_suppression] or \
+ tfoplambda_layer.symbol in ['__operators__.add', 'math.add', 'math.multiply', 'linalg.matmul', 'concat']:
+ return {arg_name: i for i, arg_name in enumerate(tf_inspect.getfullargspec(tfoplambda_layer.function).args)}
  else:
  return {}

@@ -110,7 +123,7 @@ def build_node(node: KerasNode,
  # a flag to indicate that.
  inputs_as_list = __is_functional_inputs_a_list(op_call_args)

- kwarg2index = get_kwargs2index(keras_layer.function)
+ kwarg2index = get_kwargs2index(keras_layer)

  # Functional nodes do not have weights, but may have constants in their call_args and\or
  # call kwargs. Therefore, we extract these constants and save them in the node's weights as
@@ -122,10 +135,10 @@ def build_node(node: KerasNode,
  Logger.error('Functional nodes are not expected to have weights in framework')

  # read weights from call args
+ tf_function_symbols = get_tf_function_symbols()
  for i, arg in enumerate(op_call_args[0] if inputs_as_list else op_call_args):
  if is_const(arg) or (
- keras_layer.function in [tf.add, tf.multiply, tf.subtract, tf.divide, tf.truediv, tf.pow,
- tf.matmul] and
+ keras_layer.symbol in tf_function_symbols and
  isinstance(arg, (tuple, list))):
  weights.update({i: to_numpy(arg, is_single_tensor=True)})
  # remove weights and KerasTensors and weights from op_call_args
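The reader now keys constant extraction off the `TFOpLambda` layer itself (its `function` and `symbol` attributes) rather than the bare `tf` function, since the recorded function object can differ across TF releases while the serialized symbol stays stable. A small inspection sketch, assuming a Keras 2.x functional model where a `tf.*` call is wrapped in a `TFOpLambda` layer; the attribute names follow the hunk above and are not re-verified here:

```python
import tensorflow as tf

# Build a tiny functional model in which tf.add is captured as a TFOpLambda layer.
inp = tf.keras.Input(shape=(4,))
out = tf.add(inp, 1.0)
model = tf.keras.Model(inp, out)

op_layer = model.layers[-1]
print(type(op_layer).__name__)  # expected 'TFOpLambda' in Keras 2.x
print(op_layer.function)        # the wrapped function object (may vary by TF version)
print(op_layer.symbol)          # serialized symbol, e.g. 'math.add' or '__operators__.add'
```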
@@ -39,7 +39,6 @@ if FOUND_TORCH:
  def pytorch_kpi_data(in_model: Module,
  representative_data_gen: Callable,
  core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
- fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
  target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC) -> KPI:
  """
  Computes KPI data that can be used to calculate the desired target KPI for mixed-precision quantization.
@@ -49,7 +48,6 @@ if FOUND_TORCH:
  in_model (Model): PyTorch model to quantize.
  representative_data_gen (Callable): Dataset used for calibration.
  core_config (CoreConfig): CoreConfig containing parameters for quantization and mixed precision
- fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default PyTorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.

  Returns:
@@ -85,7 +83,7 @@ if FOUND_TORCH:
  representative_data_gen,
  core_config,
  target_platform_capabilities,
- fw_info,
+ DEFAULT_PYTORCH_INFO,
  fw_impl)

  else:
@@ -68,7 +68,7 @@ def quantization_preparation_runner(graph: Graph,
  fw_info,
  core_config.quantization_config) # Mark points for statistics collection

- for _data in tqdm(representative_data_gen(), "Statistics Collection:"):
+ for _data in tqdm(representative_data_gen(), "Statistics Collection"):
  mi.infer(_data)

  if tb_w is not None:
@@ -113,12 +113,10 @@ if FOUND_TF:
  regularization_factor=regularization_factor)


- def keras_gradient_post_training_quantization(in_model: Model,
- representative_data_gen: Callable,
+ def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
  gptq_config: GradientPTQConfig,
  gptq_representative_data_gen: Callable = None,
  core_config: CoreConfig = CoreConfig(),
- fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
  """
  Quantize a trained Keras model using post-training quantization. The model is quantized using a
@@ -142,7 +140,6 @@ if FOUND_TF:
  gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
  gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None defaults to representative_data_gen
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
- fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.

  Returns:
@@ -192,7 +189,7 @@ if FOUND_TF:

  """
  KerasModelValidation(model=in_model,
- fw_info=fw_info).validate()
+ fw_info=DEFAULT_KERAS_INFO).validate()

  if core_config.mixed_precision_enable:
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
@@ -200,14 +197,14 @@ if FOUND_TF:
  "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
  "API, or pass a valid mixed precision configuration.") # pragma: no cover

- tb_w = init_tensorboard_writer(fw_info)
+ tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

  fw_impl = GPTQKerasImplemantation()

  tg, bit_widths_config, hessian_info_service = core_runner(in_model=in_model,
  representative_data_gen=representative_data_gen,
  core_config=core_config,
- fw_info=fw_info,
+ fw_info=DEFAULT_KERAS_INFO,
  fw_impl=fw_impl,
  tpc=target_platform_capabilities,
  tb_w=tb_w)
@@ -217,7 +214,7 @@ if FOUND_TF:
  gptq_config,
  representative_data_gen,
  gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
- fw_info,
+ DEFAULT_KERAS_INFO,
  fw_impl,
  tb_w,
  hessian_info_service=hessian_info_service)
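As with the KPI facades, callers of the GPTQ facade just drop `fw_info`. A hedged end-to-end sketch, assuming the public entry points are `mct.gptq.get_keras_gptq_config` and `mct.gptq.keras_gradient_post_training_quantization`; the model and data generator are placeholders:

```python
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

float_model = tf.keras.applications.MobileNetV2()  # placeholder float model

def representative_data_gen():
    for _ in range(2):
        yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)

# fw_info is no longer a parameter; DEFAULT_KERAS_INFO is used internally.
quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
    float_model,
    representative_data_gen,
    gptq_config=gptq_config)
```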
@@ -89,7 +89,6 @@ if FOUND_TF:
  representative_data_gen: Callable,
  core_config: CoreConfig = CoreConfig(),
  qat_config: QATConfig = QATConfig(),
- fw_info: FrameworkInfo = DEFAULT_KERAS_INFO,
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
  """
  Prepare a trained Keras model for quantization aware training. First the model quantization is optimized
@@ -111,7 +110,6 @@ if FOUND_TF:
  representative_data_gen (Callable): Dataset used for initial calibration.
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
  qat_config (QATConfig): QAT configuration
- fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Keras info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/keras/default_framework_info.py>`_
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.

  Returns:
@@ -159,7 +157,7 @@ if FOUND_TF:
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
  quantized model:

- >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=config)
+ >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=core_config)

  Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:

@@ -174,7 +172,7 @@ if FOUND_TF:
  f"project https://github.com/sony/model_optimization")

  KerasModelValidation(model=in_model,
- fw_info=fw_info).validate()
+ fw_info=DEFAULT_KERAS_INFO).validate()

  if core_config.mixed_precision_enable:
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
@@ -182,7 +180,7 @@ if FOUND_TF:
  "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
  "or pass a valid mixed precision configuration.")

- tb_w = init_tensorboard_writer(fw_info)
+ tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

  fw_impl = KerasImplementation()

@@ -190,16 +188,16 @@ if FOUND_TF:
  tg, bit_widths_config, _ = core_runner(in_model=in_model,
  representative_data_gen=representative_data_gen,
  core_config=core_config,
- fw_info=fw_info,
+ fw_info=DEFAULT_KERAS_INFO,
  fw_impl=fw_impl,
  tpc=target_platform_capabilities,
  tb_w=tb_w)

- tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
+ tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)

  _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
  qat_model, user_info = KerasModelBuilder(graph=tg,
- fw_info=fw_info,
+ fw_info=DEFAULT_KERAS_INFO,
  wrapper=_qat_wrapper,
  get_activation_quantizer_holder_fn=partial(get_activation_quantizer_holder,
  qat_config=qat_config)).build_model()
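The QAT init facade follows the same pattern: `fw_info` disappears from the call. A hedged sketch matching the signature visible in the hunks above; `mct.qat.keras_quantization_aware_training_init_experimental` is taken from the docstring, while the `mct.core.CoreConfig` path, the model and the data generator are assumptions:

```python
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

model = tf.keras.applications.MobileNetV2()  # placeholder float model

def repr_datagen():
    for _ in range(2):
        yield [np.random.randn(1, 224, 224, 3).astype(np.float32)]

core_config = mct.core.CoreConfig()  # assumed exposure path for CoreConfig

# No fw_info argument; the default Keras framework info is applied internally.
quantized_model, quantization_info, custom_objects = \
    mct.qat.keras_quantization_aware_training_init_experimental(
        model, repr_datagen, core_config=core_config)
```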
@@ -77,7 +77,6 @@ if FOUND_TORCH:
  representative_data_gen: Callable,
  core_config: CoreConfig = CoreConfig(),
  qat_config: QATConfig = QATConfig(),
- fw_info: FrameworkInfo = DEFAULT_PYTORCH_INFO,
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
  """
  Prepare a trained Pytorch model for quantization aware training. First the model quantization is optimized
@@ -99,7 +98,6 @@ if FOUND_TORCH:
  representative_data_gen (Callable): Dataset used for initial calibration.
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
  qat_config (QATConfig): QAT configuration
- fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.). `Default Pytorch info <https://github.com/sony/model_optimization/blob/main/model_compression_toolkit/core/pytorch/default_framework_info.py>`_
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to.

  Returns:
@@ -150,7 +148,7 @@ if FOUND_TORCH:
  "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
  "or pass a valid mixed precision configuration.")

- tb_w = init_tensorboard_writer(fw_info)
+ tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
  fw_impl = PytorchImplementation()

  # Ignore trace hessian service as we do not use it here
@@ -162,12 +160,12 @@ if FOUND_TORCH:
  tpc=target_platform_capabilities,
  tb_w=tb_w)

- tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
+ tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)

  _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)

  qat_model, user_info = PyTorchModelBuilder(graph=tg,
- fw_info=fw_info,
+ fw_info=DEFAULT_PYTORCH_INFO,
  wrapper=_qat_wrapper,
  get_activation_quantizer_holder_fn=partial(
  get_activation_quantizer_holder,