mct-nightly 1.11.0.20240317.91316__py3-none-any.whl → 1.11.0.20240319.407__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {mct_nightly-1.11.0.20240317.91316.dist-info → mct_nightly-1.11.0.20240319.407.dist-info}/METADATA +1 -1
  2. {mct_nightly-1.11.0.20240317.91316.dist-info → mct_nightly-1.11.0.20240319.407.dist-info}/RECORD +32 -32
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/__init__.py +2 -0
  5. model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py +3 -1
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py +13 -9
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -1
  8. model_compression_toolkit/core/common/pruning/pruning_config.py +8 -2
  9. model_compression_toolkit/core/common/pruning/pruning_info.py +3 -10
  10. model_compression_toolkit/core/common/quantization/core_config.py +8 -3
  11. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +1 -1
  12. model_compression_toolkit/core/graph_prep_runner.py +3 -2
  13. model_compression_toolkit/core/pytorch/kpi_data_facade.py +1 -1
  14. model_compression_toolkit/core/runner.py +10 -2
  15. model_compression_toolkit/data_generation/__init__.py +3 -0
  16. model_compression_toolkit/data_generation/common/enums.py +56 -27
  17. model_compression_toolkit/data_generation/keras/keras_data_generation.py +29 -0
  18. model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +29 -2
  19. model_compression_toolkit/exporter/model_exporter/fw_agonstic/quantization_format.py +10 -0
  20. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +5 -9
  21. model_compression_toolkit/gptq/common/gptq_config.py +6 -3
  22. model_compression_toolkit/gptq/keras/quantization_facade.py +9 -7
  23. model_compression_toolkit/gptq/pytorch/quantization_facade.py +3 -0
  24. model_compression_toolkit/pruning/keras/pruning_facade.py +5 -4
  25. model_compression_toolkit/ptq/keras/quantization_facade.py +3 -0
  26. model_compression_toolkit/ptq/pytorch/quantization_facade.py +3 -0
  27. model_compression_toolkit/qat/keras/quantization_facade.py +4 -1
  28. model_compression_toolkit/qat/pytorch/quantization_facade.py +3 -0
  29. model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +3 -3
  30. {mct_nightly-1.11.0.20240317.91316.dist-info → mct_nightly-1.11.0.20240319.407.dist-info}/LICENSE.md +0 -0
  31. {mct_nightly-1.11.0.20240317.91316.dist-info → mct_nightly-1.11.0.20240319.407.dist-info}/WHEEL +0 -0
  32. {mct_nightly-1.11.0.20240317.91316.dist-info → mct_nightly-1.11.0.20240319.407.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 1.11.0.20240317.91316
3
+ Version: 1.11.0.20240319.407
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,13 +1,13 @@
1
- model_compression_toolkit/__init__.py,sha256=iJs169mTUA7Yi3p0vgbTNRoyzRYnpn7W-AhLFWeJwUE,1574
1
+ model_compression_toolkit/__init__.py,sha256=xX48bBE90QORHpae8fqX4wNK9NTEi5qT_JQoOG7-tPc,1574
2
2
  model_compression_toolkit/constants.py,sha256=_OW_bUeQmf08Bb4oVZ0KfUt-rcCeNOmdBv3aP7NF5fM,3631
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=b9DVktZ-LymFcRxv2aL_sdiE6S2sSrFGWltx6dgEuUY,4863
5
- model_compression_toolkit/core/__init__.py,sha256=P-7OYR4TFYxVV_ZpIJBogkX8bGvXcijlF65Ez3ivjhc,1838
5
+ model_compression_toolkit/core/__init__.py,sha256=DRw7VF7jsqHxGtoxf8F0YcXPRRnQIw6sn6Q925MmWC8,1944
6
6
  model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9jr_cMbET0Bd3Sg,2975
7
7
  model_compression_toolkit/core/exporter.py,sha256=Zo_C5GjIzihtJOyGp-xeCVhY_qohkVz_EGyrSZCbWRM,4115
8
- model_compression_toolkit/core/graph_prep_runner.py,sha256=3xp0WYqyeRdlBkf5R6uD2zWubg_JPttOwS7JRhKykBY,10043
8
+ model_compression_toolkit/core/graph_prep_runner.py,sha256=Ftqm59hT5TGWmSNkY9bFZkVfCacpGyZfCe-6yZR5WY0,10100
9
9
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=hFhDkS8GwzXZ7Ho_9qbbb8DAAWs3OONOfMSD5OU_b0o,6153
10
- model_compression_toolkit/core/runner.py,sha256=hXnbgP8Q-62Ie4wAq4JXO-2o77uR3le4mHYgFqJOvfc,10928
10
+ model_compression_toolkit/core/runner.py,sha256=FJ_TG-OZDtDBM_BNHTdcAX5NKAWPog-0Gh3uDgnXUxU,11383
11
11
  model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
12
12
  model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
13
13
  model_compression_toolkit/core/common/data_loader.py,sha256=jCoVIb4yeOWyCrCNRB1W-mgLSyqNVGEepFXrIqufVc4,4119
@@ -62,8 +62,8 @@ model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py,sha256
62
62
  model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,sha256=LLDguK7afsbN742ucLpmJr5TUfTyFpK1vbf2bpVr1v0,882
63
63
  model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
64
64
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=H8qYkJsk88OszUJo-Zde7vTmWiypLTg9KbbzIZ-hhvM,2812
65
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=4skskrRuKOoMZIX9XB9Os3WmQiFq8rEe05RmL6xjrxo,4553
66
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=sXyhmyO0FtGO7vHOWqUign88Kh7MCzN8Ohk6wIXq0GQ,6992
65
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=DP5tcxPtiVbSWAeoFbEp7iTwpxDBU1g7V5w7ehDG6jI,4573
66
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=mN-QeabIu_Mz1IzPeQjqgqprCTdwGm4ThYX0gZAek-E,7103
67
67
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=f5j7R7A_bSVqeBY4WuDN8n0YWlR8jhK_n9eKInQ8anY,36763
68
68
  model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=vt829yxXlfbQHPDUHJebda7jfzpGf1N3b6L4XJ4zbSI,28534
69
69
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
@@ -71,7 +71,7 @@ model_compression_toolkit/core/common/mixed_precision/solution_refinement_proced
71
71
  model_compression_toolkit/core/common/mixed_precision/kpi_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
72
72
  model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi.py,sha256=gcwwuzLKpa2WvsyAr6MXb4cXhOxCM0dvVHKLL-FWGoA,4297
73
73
  model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_aggregation_methods.py,sha256=X0PbF3UHVy3JRRIgcogKpTNm26AJOJ7blajAWsDf7R4,3920
74
- model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py,sha256=vrgUYeL6MPVO_tBNIGf6tuOXsjl60JcIW8y_CPPk6Tk,7464
74
+ model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_data.py,sha256=7wzhz7tPMVq-JmcHvTGiOdDQPXNw3i7HXAvbymJOXY8,7618
75
75
  model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_functions_mapping.py,sha256=cjLf_g4n1INlT1TE1z-I41hDXUTTy8krUSvhRB57rv0,1602
76
76
  model_compression_toolkit/core/common/mixed_precision/kpi_tools/kpi_methods.py,sha256=yg8Pg9kMsjQzi03tcoQlp0iqnnqKvxjohOQNzPdRPzs,20840
77
77
  model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
@@ -86,9 +86,9 @@ model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py,sha256=G
86
86
  model_compression_toolkit/core/common/pruning/memory_calculator.py,sha256=RnMmgNDHekKFOj-b-ad5rhjuKUvbVawy1A31nxuCRTg,19217
87
87
  model_compression_toolkit/core/common/pruning/prune_graph.py,sha256=ddbZLuWvlNoj5so_5NRbIuG5qDFxD9ApG2gPirbov8o,3317
88
88
  model_compression_toolkit/core/common/pruning/pruner.py,sha256=vXxzBXQ-oAEnw6PAD1SUiNXX7Xix4JJ0LAmV04sjFz0,7313
89
- model_compression_toolkit/core/common/pruning/pruning_config.py,sha256=IfF824hNttyw2i4Tuf3g8CUfelJR3eZuOLzf2aEZNAM,3442
89
+ model_compression_toolkit/core/common/pruning/pruning_config.py,sha256=PO4C1C1_hhAX_B05kqpC-TTx1S1O6Dj9DrtZrxpi1aE,3670
90
90
  model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py,sha256=H2gnCv-lyRLXapDy71QHA3JkLYTQT1ni23nGTYErsZo,6734
91
- model_compression_toolkit/core/common/pruning/pruning_info.py,sha256=DK_ofX-73tfpdmkHNLYlO6_SBifZDsRWmGHsCbrUFN8,4083
91
+ model_compression_toolkit/core/common/pruning/pruning_info.py,sha256=gSCh_qXmLATChb5Nh16wvR1ffI9SERstEroc_hFrVQo,3781
92
92
  model_compression_toolkit/core/common/pruning/pruning_section.py,sha256=I4vxh5iPKWs8yji-q4TVmaa6bcfLwT3ZhB2x8F8cJkU,5721
93
93
  model_compression_toolkit/core/common/pruning/importance_metrics/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
94
94
  model_compression_toolkit/core/common/pruning/importance_metrics/base_importance_metric.py,sha256=qMAtLWs5fjbSco8nhbig5TkuacdhnDW7cy3avMHRGX4,1988
@@ -99,7 +99,7 @@ model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py,sha256=hk
99
99
  model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py,sha256=gmzD32xsfJH8vkkqaspS7vYa6VWayk1GJe-NfoAEugQ,5901
100
100
  model_compression_toolkit/core/common/quantization/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
101
101
  model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py,sha256=D7lgCc0drQZ3yXNctTBg-FnqHX7e32zp0-ocGYGJbEE,4553
102
- model_compression_toolkit/core/common/quantization/core_config.py,sha256=IkD4Jl9PWdPucfUMq0TtyUl5DBJvha7Dd2xSW7_7dz8,2015
102
+ model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
103
103
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
104
104
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
105
105
  model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=HWBBF--cbzsiMx3BG2kQ3JHkfalVnGO3N-rAXMwNqp4,26707
@@ -108,7 +108,7 @@ model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,
108
108
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=mrgVzZszWjxnjT8zm77UVLWKTOwd2thGBo6WNqAS4X8,3867
109
109
  model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=xnM9O9LshYw3dprqfsnK9mw7ipOEAkI85o20auyfswg,2626
110
110
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
111
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=onzV581FPw19WN00EUjrXkvPbA6msHk-VSJkBVduV-s,11490
111
+ model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=ntfdEK39SAuegHtGa_v9H-_IC9WagRvwokRL3wEnGso,11491
112
112
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
113
113
  model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=eH3nSXPFn94ATF3dZn2HxNAGVJUWotirN6o8wwDfkLg,18165
114
114
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=h8Zmpq3KdcsdUUy7K1fvWOVSki0mxT8wtKZXGmgFl74,7405
@@ -211,7 +211,7 @@ model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG
211
211
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
212
212
  model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
213
213
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
214
- model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=fMUUHOv31FGWy1dUXteWtj6OlVm4QC2mf2H77n7ToLM,4584
214
+ model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=eKFq0gO2DrlS_wN4plMGZTabQKmb0pylIVGZ44HqSnw,4527
215
215
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
216
216
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=1uIDT-3wLzQf1FT8fMleyu5w5EYL0n7HoFEG80XDUY8,27082
217
217
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
@@ -263,19 +263,19 @@ model_compression_toolkit/core/pytorch/reader/node_holders.py,sha256=TaolORuwBZE
263
263
  model_compression_toolkit/core/pytorch/reader/reader.py,sha256=Co3-AHZCEOw5w-jtgf9oAKsgtjQoG0MeeSeBVnQ0xOA,5801
264
264
  model_compression_toolkit/core/pytorch/statistics_correction/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
265
265
  model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment_correction.py,sha256=VgU24J3jf7QComHH7jonOXSkg6mO4TOch3uFkOthZvM,3261
266
- model_compression_toolkit/data_generation/__init__.py,sha256=zp3nQ7NhDncuGdHBwCXkRJh6JnGoTYhZZlAOrDE8omc,1138
266
+ model_compression_toolkit/data_generation/__init__.py,sha256=R_RnB8Evj4uq0WKiPWvBWfeePrbake7Z03ugJgK7jLo,1466
267
267
  model_compression_toolkit/data_generation/common/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
268
268
  model_compression_toolkit/data_generation/common/constants.py,sha256=21e3ZX9WVYojexG2acTgklrBk8ZO9DjJnKpP4KHZC44,1018
269
269
  model_compression_toolkit/data_generation/common/data_generation.py,sha256=PnKkWCBf4yla0E4LhvOqT8htWiGW4F98bygExQnpwqI,6397
270
270
  model_compression_toolkit/data_generation/common/data_generation_config.py,sha256=ynyNaT2x2d23bYSrO2sRItM2ZsjGD0K0fM71FlibiJQ,4564
271
- model_compression_toolkit/data_generation/common/enums.py,sha256=UJhndTsE7q7Bm6CgCYQKWOuuD-9lj6j_QQ28KWZK8uU,3522
271
+ model_compression_toolkit/data_generation/common/enums.py,sha256=OGnvtEGFbP5l4V3-1l32zzVQwTb1vGJhTVF0kOkYZK4,3584
272
272
  model_compression_toolkit/data_generation/common/image_pipeline.py,sha256=WwyeoIvgmcxKnuOX-_Hl_0APET4M26f5x-prhUB3qvU,2149
273
273
  model_compression_toolkit/data_generation/common/model_info_exctractors.py,sha256=9zYlyuc7K1s2neHWF3wqL5EVOVaoz_QkCYTktSXrSXI,6047
274
274
  model_compression_toolkit/data_generation/common/optimization_utils.py,sha256=8wCU-bCLabDIUayG3eyapdD8pTE6x0RYG5o3rfha7XE,19572
275
275
  model_compression_toolkit/data_generation/keras/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
276
276
  model_compression_toolkit/data_generation/keras/constants.py,sha256=uy3eU24ykygIrjIvwOMj3j5euBeN2PwWiEFPOkJJ7ss,1088
277
277
  model_compression_toolkit/data_generation/keras/image_pipeline.py,sha256=_Qezq67huKmmNsxdFBBrTY-VaGR-paFzDH80dDuRnug,7623
278
- model_compression_toolkit/data_generation/keras/keras_data_generation.py,sha256=MYFdMPqGxy9tRaTIstJMkcYOk0tMXirke5fxdIJvBjU,19720
278
+ model_compression_toolkit/data_generation/keras/keras_data_generation.py,sha256=6UQXrpwghxhHgMOGP8he84y4ZUQj2v0UGFRNMIPhBI8,21587
279
279
  model_compression_toolkit/data_generation/keras/model_info_exctractors.py,sha256=b3BaOGiMAlCCzPICww722l2H_RucoHgpGUK6xYe8xTA,8552
280
280
  model_compression_toolkit/data_generation/keras/optimization_utils.py,sha256=uQAJpJPpnLDTTLDQGyTS0ZYp2T38TTZLOOElcJPBKHA,21146
281
281
  model_compression_toolkit/data_generation/keras/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
@@ -289,7 +289,7 @@ model_compression_toolkit/data_generation/pytorch/constants.py,sha256=QWyreMImcf
289
289
  model_compression_toolkit/data_generation/pytorch/image_pipeline.py,sha256=6g7OpOuO3cU4TIuelaRjBKpCPgiMbe1a3iy9bZtdZUo,6617
290
290
  model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py,sha256=wxtaQad4aP8D0SgA8qEPORZM3qBD22G6zO1gjwTNIVU,9632
291
291
  model_compression_toolkit/data_generation/pytorch/optimization_utils.py,sha256=AjYsO-lm06JOUMoKkS6VbyF4O_l_ffWXrgamqJm1ofE,19085
292
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=BCJ6PVncBBm6sa4IWCYvC-U0-XPs7LV-deao0lq_D20,19192
292
+ model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=Fj0pZEdbQUyydbPRrUWAK9X3hJRtzeFkQ1kCsxHjW84,21012
293
293
  model_compression_toolkit/data_generation/pytorch/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
294
294
  model_compression_toolkit/data_generation/pytorch/optimization_functions/batchnorm_alignment_functions.py,sha256=dMc4zz9XfYfAT4Cxns57VgvGZWPAMfaGlWLFyCyl8TA,1968
295
295
  model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py,sha256=i3ePEI8xDE3xZEtmzT5lCkLn9wpObUi_OgqnVDf7nj8,2597
@@ -300,14 +300,14 @@ model_compression_toolkit/exporter/__init__.py,sha256=Eg3c4EAjW3g6h13A-Utgf9ncHr
300
300
  model_compression_toolkit/exporter/model_exporter/__init__.py,sha256=9HIBmj8ROdCA-yvkpA8EcN6RHJe_2vEpLLW_gxOJtak,698
301
301
  model_compression_toolkit/exporter/model_exporter/fw_agonstic/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
302
302
  model_compression_toolkit/exporter/model_exporter/fw_agonstic/exporter.py,sha256=eSC6gEMc9KY5EwVRam9pJCBpCm0ksUeobKV_JAOap9M,2017
303
- model_compression_toolkit/exporter/model_exporter/fw_agonstic/quantization_format.py,sha256=oRlw_LG0UNi8Vl77E1U6WDBsEp1eNaphUbG9rSZkSl4,800
303
+ model_compression_toolkit/exporter/model_exporter/fw_agonstic/quantization_format.py,sha256=otuyY3N2h6NmZKjptRvHEnwJRkPVJ2Ty20J1Mwbkjqc,1165
304
304
  model_compression_toolkit/exporter/model_exporter/keras/__init__.py,sha256=uZ2RigbY9O2PJ0Il8wPpS_s7frgg9WUGd_SHeKGyl1A,699
305
305
  model_compression_toolkit/exporter/model_exporter/keras/base_keras_exporter.py,sha256=-wr2n0yRlmFixXBeZuxg6Rzlvz-ZFUX-PJgSXhgMrEo,1593
306
306
  model_compression_toolkit/exporter/model_exporter/keras/export_serialization_format.py,sha256=v_-rOsWDFI-3k8CoJIr-XzT7ny8WXpAMteWRWtTzaeg,963
307
307
  model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_keras_exporter.py,sha256=E_1IqFYAGUMOrt3U_JK1k--8D0WzWPbjZH_IRLGw_wY,11478
308
308
  model_compression_toolkit/exporter/model_exporter/keras/fakely_quant_tflite_exporter.py,sha256=sqzqQ8US24WgDbg_FoP1NQBgqCbSVwrVTWrxcyY0nPA,3514
309
309
  model_compression_toolkit/exporter/model_exporter/keras/int8_tflite_exporter.py,sha256=nGtpDTeH5Tdp7sjyuXsy_9TPpijDYp4nkz366DUUJ0Q,8048
310
- model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py,sha256=O-GApieS7_zLkpygnN0YvDK-HkCChwA4bSExbI5jvQ8,5998
310
+ model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py,sha256=EByj03xnyJTIwopuKjKne9Nwdr0VqoNdQOCEFwSeTNw,5792
311
311
  model_compression_toolkit/exporter/model_exporter/keras/mctq_keras_exporter.py,sha256=qXXkv3X_wb7t622EOHwXIxfGLGaDqh0T0y4UxREi4Bo,1976
312
312
  model_compression_toolkit/exporter/model_exporter/pytorch/__init__.py,sha256=uZ2RigbY9O2PJ0Il8wPpS_s7frgg9WUGd_SHeKGyl1A,699
313
313
  model_compression_toolkit/exporter/model_exporter/pytorch/base_pytorch_exporter.py,sha256=UPVkEUQCMZ4Lld6CRnEOPEmlfe5vcQZG0Q3FwRBodD4,4021
@@ -331,7 +331,7 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
331
331
  model_compression_toolkit/gptq/__init__.py,sha256=YKg-tMj9D4Yd0xW9VRD5EN1J5JrmlRbNEF2fOSgodqA,1228
332
332
  model_compression_toolkit/gptq/runner.py,sha256=MIg-oBtR1nbHkexySdCJD_XfjRoHSknLotmGBMuD5qM,5924
333
333
  model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
334
- model_compression_toolkit/gptq/common/gptq_config.py,sha256=U33sLIPB0pI4h_zhr4X_S9K0cEJWTbWFxkj8z9IGlxg,5268
334
+ model_compression_toolkit/gptq/common/gptq_config.py,sha256=6xP99B-lK1bwGv3AdqxnW1V51z2VdzQcjvoSgJOmygA,5288
335
335
  model_compression_toolkit/gptq/common/gptq_constants.py,sha256=QSm6laLkIV0LYmU0BLtmKp3Fi3SqDfbncFQWOGA1cGU,611
336
336
  model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
337
337
  model_compression_toolkit/gptq/common/gptq_graph.py,sha256=8qmty-2MzV6USRoHgShCA13HqxDI3PDGJaFKCQPFo5E,3026
@@ -341,7 +341,7 @@ model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCS
341
341
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
342
342
  model_compression_toolkit/gptq/keras/gptq_training.py,sha256=cASZlTmnth3Vu-7GfmC03FxWSXtpSVhdPKT_twWml68,17949
343
343
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=FIGqzJbG6GkdHenvdMu-tGTjp4j9BewdF_spmWCb4Mo,4627
344
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=wRyQrJJ71JwtFoiIdBPDHE0srpUwmL7nqHbXOvjDHFc,13578
344
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=NyVMvDsgxMsaAtYxYwaqeQX3VD5GmfftXKHV5wUBLZg,13762
345
345
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
346
346
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=8NrJBftKFbMAF_jYaAbLP6GBwpCv3Ln1NKURaV75zko,4770
347
347
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -358,7 +358,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa9
358
358
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
359
359
  model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=9zQC42RfAj4ak-XOzF8xEXS3IkHKhKlOClIfaUA0bGI,15396
360
360
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=-0GDC2cr-XXS7cTFTnDflJivGN7VaPnzVPsxCE-vZNU,3955
361
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=MMp97lTatr0moe0r4cycqNw-1qVo_ixvissH6n_wjnE,12091
361
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=ER5VPSkZZjqYj7PJ-3B5RX33YjHz3tJ4Er9SF6M-93c,12369
362
362
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
363
363
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=Zb-P0yRyZHHBlDvUBdRwxDpdduEJyJp6OT9pfKFF5ks,4171
364
364
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -372,20 +372,20 @@ model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/__init__.py,sha256
372
372
  model_compression_toolkit/gptq/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=6uxq_w62jn8DDOt9T7VtA6jZ8jTAPcbTufKFOYpVUm4,8768
373
373
  model_compression_toolkit/pruning/__init__.py,sha256=lQMZS8G0pvR1LVi53nnJHNXgLNTan_MWMdwsVxhjrow,1106
374
374
  model_compression_toolkit/pruning/keras/__init__.py,sha256=3Lkr37Exk9u8811hw8hVqkGcbTQGcLjd3LLuLC3fa_E,698
375
- model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=B2mkCh3_AKc1O3IBOdo03PuIyjAoK3IBmgBdmIfUkDI,8296
375
+ model_compression_toolkit/pruning/keras/pruning_facade.py,sha256=2vMmI9QaH9nReyDqZKiWOZPQC3HUQ2ZCahHIMFyveMQ,8396
376
376
  model_compression_toolkit/pruning/pytorch/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
377
377
  model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=ZLmMhwAEnbXNRwMwgoGEGNmHpZx_KWYu7yi5K3aICWI,9184
378
378
  model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
379
379
  model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
380
380
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
381
- model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=6kZ54SY_Slw2DGzALm7X2TzZEej9-FEoaVkjxOdFxp8,8598
381
+ model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=ergUI8RDA2h4_SHU05x2pYJatt-U-fZUrShdHJDLo_o,8844
382
382
  model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
383
- model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=5QH-khLMPFLdCUPQCxOCYY5v4p_M67TZcfCZGnsWqVs,7191
383
+ model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=WKzokgg_gGcEHipVH26shneiAiTdSa7d_UUQKoS8ALY,7438
384
384
  model_compression_toolkit/qat/__init__.py,sha256=kj2qsZh_Ca7PncsHKcaL5EVT2H8g4hYtvaQ3KFxOkwE,1143
385
385
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
386
386
  model_compression_toolkit/qat/common/qat_config.py,sha256=zoq0Vb74vCY7WlWD8JH_KPrHDoUHSvMc3gcO53u7L2U,3394
387
387
  model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
388
- model_compression_toolkit/qat/keras/quantization_facade.py,sha256=xH05Ro9aY9HabQo_PztaXw0-D3Cxvl-GYCmDKRjwkuI,16524
388
+ model_compression_toolkit/qat/keras/quantization_facade.py,sha256=9qWdNIIx2hKmjGCpSGGEAv7HXg91Y9ZuyKE-avHn46c,16784
389
389
  model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
390
390
  model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=gPuIgQb8OafvC3SuA8jNsGoy8S8eTsDCEKuh36WDNss,2104
391
391
  model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
@@ -397,7 +397,7 @@ model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cc
397
397
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
398
398
  model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
399
399
  model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
400
- model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=TaciVmT0tQhvfpp7ASxPo-feZWlUNLg4IVvx8Qpe5jA,12963
400
+ model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=cj1AieM-v1HZcIeNfNDX7AQzQOwUw4ZuGWw2pfuY6Ig,13230
401
401
  model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
402
402
  model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=FnhuFCuQoSf78FM1z1UZgXXd3k-mKSM7i9dYOuJUmeA,2213
403
403
  model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=e8Yfqbc552iAiP4Zxbd2ht1A3moRFGnV_KRGDm9Gw_g,5709
@@ -422,7 +422,7 @@ model_compression_toolkit/target_platform_capabilities/target_platform/targetpla
422
422
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attribute_filter.py,sha256=-riVk2KPy94nYuviaZzZPc6j5vObhD9-6fGryuSLZ9c,8759
423
423
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/current_tpc.py,sha256=GcLSXZLxtcE9SxSKdlvo10ba9mqVk_MBiwrvvjSH8H0,2046
424
424
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/layer_filter_params.py,sha256=Cl6-mACpje2jM8RJkibbqE3hvTkFR3r26-lW021mIiA,4019
425
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py,sha256=HpJ_zzYHpSMbJ5K-IDhmP-8mwCYconaK17NSIJ3R6iI,6743
425
+ model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py,sha256=Va9-f7M2OK3kOai5AwT-wI2zuezC9is9lwq5OOMhT_4,6733
426
426
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py,sha256=m6p9pO_xqGcp-0jAVRaOJww67oSQ6gChCD45_W833Gw,9819
427
427
  model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities_component.py,sha256=FvrYI0Qy7DCmDp2gyUYyCZq5pY84JgLtJqSIiVTJ8Ss,1030
428
428
  model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -472,8 +472,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
472
472
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
473
473
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
474
474
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
475
- mct_nightly-1.11.0.20240317.91316.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
476
- mct_nightly-1.11.0.20240317.91316.dist-info/METADATA,sha256=PKhfKhLQsDpsv2WN0bqt3yB2heGoLsRU7Eu6wf4JFrs,18527
477
- mct_nightly-1.11.0.20240317.91316.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
478
- mct_nightly-1.11.0.20240317.91316.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
479
- mct_nightly-1.11.0.20240317.91316.dist-info/RECORD,,
475
+ mct_nightly-1.11.0.20240319.407.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
476
+ mct_nightly-1.11.0.20240319.407.dist-info/METADATA,sha256=NODMa-MyTyAJxebf2FD53plT_HC0BdlshyXRbOi9vio,18525
477
+ mct_nightly-1.11.0.20240319.407.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
478
+ mct_nightly-1.11.0.20240319.407.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
479
+ mct_nightly-1.11.0.20240319.407.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "1.11.0.20240317.091316"
30
+ __version__ = "1.11.0.20240319.000407"
@@ -25,3 +25,5 @@ from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import
25
25
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
26
26
  from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data
27
27
  from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data
28
+ from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
29
+
@@ -47,13 +47,15 @@ def compute_kpi_data(in_model: Any,
47
47
 
48
48
  """
49
49
 
50
+ # We assume that the kpi_data API is used to compute the model KPI for mixed precision scenario,
51
+ # so we run graph preparation under the assumption of enabled mixed precision.
50
52
  transformed_graph = graph_preparation_runner(in_model,
51
53
  representative_data_gen,
52
54
  core_config.quantization_config,
53
55
  fw_info,
54
56
  fw_impl,
55
57
  tpc,
56
- mixed_precision_enable=core_config.mixed_precision_enable)
58
+ mixed_precision_enable=True)
57
59
 
58
60
  # Compute parameters sum
59
61
  weights_params = compute_nodes_weights_params(graph=transformed_graph, fw_info=fw_info)
@@ -16,13 +16,11 @@
16
16
  from typing import List, Callable
17
17
 
18
18
  from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting
19
- from model_compression_toolkit.core.common.mixed_precision.kpi_tools.kpi import KPI
20
19
 
21
20
 
22
21
  class MixedPrecisionQuantizationConfig:
23
22
 
24
23
  def __init__(self,
25
- target_kpi: KPI = None,
26
24
  compute_distance_fn: Callable = None,
27
25
  distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG,
28
26
  num_of_images: int = 32,
@@ -36,7 +34,6 @@ class MixedPrecisionQuantizationConfig:
36
34
  Class with mixed precision parameters to quantize the input model.
37
35
 
38
36
  Args:
39
- target_kpi (KPI): KPI to constraint the search of the mixed-precision configuration for the model.
40
37
  compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
41
38
  distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
42
39
  num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
@@ -49,7 +46,6 @@ class MixedPrecisionQuantizationConfig:
49
46
 
50
47
  """
51
48
 
52
- self.target_kpi = target_kpi
53
49
  self.compute_distance_fn = compute_distance_fn
54
50
  self.distance_weighting_method = distance_weighting_method
55
51
  self.num_of_images = num_of_images
@@ -67,13 +63,21 @@ class MixedPrecisionQuantizationConfig:
67
63
 
68
64
  self.metric_normalization_threshold = metric_normalization_threshold
69
65
 
70
- def set_target_kpi(self, target_kpi: KPI):
66
+ self._mixed_precision_enable = False
67
+
68
+ def set_mixed_precision_enable(self):
69
+ """
70
+ Set a flag in mixed precision config indicating that mixed precision is enabled.
71
71
  """
72
- Setting target KPI in mixed precision config.
73
72
 
74
- Args:
75
- target_kpi: A target KPI to set.
73
+ self._mixed_precision_enable = True
76
74
 
75
+ @property
76
+ def mixed_precision_enable(self):
77
77
  """
78
+ A property that indicates whether mixed precision quantization is enabled.
78
79
 
79
- self.target_kpi = target_kpi
80
+ Returns: True if mixed precision quantization is enabled
81
+
82
+ """
83
+ return self._mixed_precision_enable
@@ -47,6 +47,7 @@ search_methods = {
47
47
  def search_bit_width(graph_to_search_cfg: Graph,
48
48
  fw_info: FrameworkInfo,
49
49
  fw_impl: FrameworkImplementation,
50
+ target_kpi: KPI,
50
51
  mp_config: MixedPrecisionQuantizationConfig,
51
52
  representative_data_gen: Callable,
52
53
  search_method: BitWidthSearchMethod = BitWidthSearchMethod.INTEGER_PROGRAMMING,
@@ -63,6 +64,7 @@ def search_bit_width(graph_to_search_cfg: Graph,
63
64
  graph_to_search_cfg: Graph to search a MP configuration for.
64
65
  fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
65
66
  fw_impl: FrameworkImplementation object with specific framework methods implementation.
67
+ target_kpi: Target KPI to bound our feasible solution space s.t the configuration does not violate it.
66
68
  mp_config: Mixed-precision quantization configuration.
67
69
  representative_data_gen: Dataset to use for retrieving images for the models inputs.
68
70
  search_method: BitWidthSearchMethod to define which searching method to use.
@@ -74,7 +76,6 @@ def search_bit_width(graph_to_search_cfg: Graph,
74
76
  bit-width index on the node).
75
77
 
76
78
  """
77
- target_kpi = mp_config.target_kpi
78
79
 
79
80
  # target_kpi have to be passed. If it was not passed, the facade is not supposed to get here by now.
80
81
  if target_kpi is None:
@@ -20,14 +20,20 @@ from model_compression_toolkit.constants import PRUNING_NUM_SCORE_APPROXIMATIONS
20
20
 
21
21
  class ImportanceMetric(Enum):
22
22
  """
23
- Enum for specifying the metric used to determine the importance of channels when pruning.
23
+ Enum for specifying the metric used to determine the importance of channels when pruning:
24
+
25
+ LFH - Label-Free Hessian uses hessian info for measuring each channel's sensitivity.
26
+
24
27
  """
25
28
  LFH = 0 # Score based on the Hessian matrix w.r.t. layers weights, to determine channel importance without labels.
26
29
 
27
30
 
28
31
  class ChannelsFilteringStrategy(Enum):
29
32
  """
30
- Enum for specifying the strategy used for filtering (pruning) channels.
33
+ Enum for specifying the strategy used for filtering (pruning) channels:
34
+
35
+ GREEDY - Prune the least important channel groups up to allowed resources in the KPI (for now, only weights_memory is considered).
36
+
31
37
  """
32
38
  GREEDY = 0 # Greedy strategy for pruning channels based on importance metrics.
33
39
 
@@ -26,23 +26,16 @@ class PruningInfo:
26
26
  and importance scores for each layer. This class acts as a container for accessing
27
27
  pruning-related metadata.
28
28
 
29
- Attributes:
30
- pruning_masks (Dict[BaseNode, np.ndarray]): Stores the pruning masks for each layer.
31
- A pruning mask is an array where each element indicates whether the corresponding
32
- channel or neuron has been pruned (0) or kept (1).
33
- importance_scores (Dict[BaseNode, np.ndarray]): Stores the importance scores for each layer.
34
- Importance scores quantify the significance of each channel in the layer.
35
29
  """
36
30
 
37
31
  def __init__(self,
38
32
  pruning_masks: Dict[BaseNode, np.ndarray],
39
33
  importance_scores: Dict[BaseNode, np.ndarray]):
40
34
  """
41
- Initializes the PruningInfo with pruning masks and importance scores.
42
-
43
35
  Args:
44
- pruning_masks (Dict[BaseNode, np.ndarray]): Pruning masks for each layer.
45
- importance_scores (Dict[BaseNode, np.ndarray]): Importance scores for each layer.
36
+ pruning_masks (Dict[BaseNode, np.ndarray]): Stores the pruning masks for each layer. A pruning mask is an array where each element indicates whether the corresponding channel or neuron has been pruned (0) or kept (1).
37
+ importance_scores (Dict[BaseNode, np.ndarray]): Stores the importance scores for each layer. Importance scores quantify the significance of each channel in the layer.
38
+
46
39
  """
47
40
  self._pruning_masks = pruning_masks
48
41
  self._importance_scores = importance_scores
@@ -30,14 +30,19 @@ class CoreConfig:
30
30
 
31
31
  Args:
32
32
  quantization_config (QuantizationConfig): Config for quantization.
33
- mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization (optional, default=None).
33
+ mixed_precision_config (MixedPrecisionQuantizationConfig): Config for mixed precision quantization.
34
+ If None, a default MixedPrecisionQuantizationConfig is used.
34
35
  debug_config (DebugConfig): Config for debugging and editing the network quantization process.
35
36
  """
36
37
  self.quantization_config = quantization_config
37
- self.mixed_precision_config = mixed_precision_config
38
38
  self.debug_config = debug_config
39
39
 
40
+ if mixed_precision_config is None:
41
+ self.mixed_precision_config = MixedPrecisionQuantizationConfig()
42
+ else:
43
+ self.mixed_precision_config = mixed_precision_config
44
+
40
45
  @property
41
46
  def mixed_precision_enable(self):
42
- return self.mixed_precision_config is not None
47
+ return self.mixed_precision_config is not None and self.mixed_precision_config.mixed_precision_enable
43
48
 
@@ -71,7 +71,7 @@ def set_quantization_configs_to_node(node: BaseNode,
71
71
  quant_config: Quantization configuration to generate the node's configurations from.
72
72
  fw_info: Information needed for quantization about the specific framework.
73
73
  tpc: TargetPlatformCapabilities to get default OpQuantizationConfig.
74
- mixed_precision_enable: is mixed precision enabled
74
+ mixed_precision_enable: is mixed precision enabled.
75
75
  """
76
76
  node_qc_options = node.get_qco(tpc)
77
77
 
@@ -57,7 +57,8 @@ def graph_preparation_runner(in_model: Any,
57
57
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
58
58
  tpc: TargetPlatformCapabilities object that models the inference target platform and
59
59
  the attached framework operator's information.
60
- tb_w: TensorboardWriter object for logging
60
+ tb_w: TensorboardWriter object for logging.
61
+ mixed_precision_enable: is mixed precision enabled.
61
62
 
62
63
  Returns:
63
64
  An internal graph representation of the input model.
@@ -103,7 +104,7 @@ def get_finalized_graph(initial_graph: Graph,
103
104
  kernel channels indices, groups of layers by how they should be quantized, etc.)
104
105
  tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
105
106
  fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
106
- mixed_precision_enable: is mixed precision enabled.
107
+ mixed_precision_enable: is mixed precision enabled.
107
108
 
108
109
  Returns: Graph object that represents the model, after applying all required modifications to it.
109
110
  """
@@ -38,7 +38,7 @@ if FOUND_TORCH:
38
38
 
39
39
  def pytorch_kpi_data(in_model: Module,
40
40
  representative_data_gen: Callable,
41
- core_config: CoreConfig = CoreConfig(mixed_precision_config=MixedPrecisionQuantizationConfig()),
41
+ core_config: CoreConfig = CoreConfig(),
42
42
  target_platform_capabilities: TargetPlatformCapabilities = PYTORCH_DEFAULT_TPC) -> KPI:
43
43
  """
44
44
  Computes KPI data that can be used to calculate the desired target KPI for mixed-precision quantization.
@@ -47,6 +47,7 @@ def core_runner(in_model: Any,
47
47
  fw_info: FrameworkInfo,
48
48
  fw_impl: FrameworkImplementation,
49
49
  tpc: TargetPlatformCapabilities,
50
+ target_kpi: KPI = None,
50
51
  tb_w: TensorboardWriter = None):
51
52
  """
52
53
  Quantize a trained model using post-training quantization.
@@ -66,6 +67,7 @@ def core_runner(in_model: Any,
66
67
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
67
68
  tpc: TargetPlatformCapabilities object that models the inference target platform and
68
69
  the attached framework operator's information.
70
+ target_kpi: KPI to constraint the search of the mixed-precision configuration for the model.
69
71
  tb_w: TensorboardWriter object for logging
70
72
 
71
73
  Returns:
@@ -81,6 +83,13 @@ def core_runner(in_model: Any,
81
83
  Logger.warning('representative_data_gen generates a batch size of 1 which can be slow for optimization:'
82
84
  ' consider increasing the batch size')
83
85
 
86
+ # Checking whether to run mixed precision quantization
87
+ if target_kpi is not None:
88
+ if core_config.mixed_precision_config is None:
89
+ Logger.critical("Provided an initialized target_kpi, that means that mixed precision quantization is "
90
+ "enabled, but the provided MixedPrecisionQuantizationConfig is None.")
91
+ core_config.mixed_precision_config.set_mixed_precision_enable()
92
+
84
93
  graph = graph_preparation_runner(in_model,
85
94
  representative_data_gen,
86
95
  core_config.quantization_config,
@@ -105,13 +114,12 @@ def core_runner(in_model: Any,
105
114
  # Finalize bit widths
106
115
  ######################################
107
116
  if core_config.mixed_precision_enable:
108
- if core_config.mixed_precision_config.target_kpi is None:
109
- Logger.critical(f"Trying to run Mixed Precision quantization without providing a valid target KPI.")
110
117
  if core_config.mixed_precision_config.configuration_overwrite is None:
111
118
 
112
119
  bit_widths_config = search_bit_width(tg,
113
120
  fw_info,
114
121
  fw_impl,
122
+ target_kpi,
115
123
  core_config.mixed_precision_config,
116
124
  representative_data_gen,
117
125
  hessian_info_service=hessian_info_service)
@@ -12,7 +12,10 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
15
16
  from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF
17
+ from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
18
+ from model_compression_toolkit.data_generation.common.enums import ImageGranularity, DataInitType, SchedulerType, BNLayerWeightingType, OutputLossType, BatchNormAlignemntLossType, ImagePipelineType, ImageNormalizationType
16
19
 
17
20
  if FOUND_TF:
18
21
  from model_compression_toolkit.data_generation.keras.keras_data_generation import (
@@ -29,10 +29,14 @@ class EnumBaseClass(Enum):
29
29
 
30
30
  class ImageGranularity(EnumBaseClass):
31
31
  """
32
- An enum for choosing the image dependence granularity when generating images.
33
- 0. ImageWise
34
- 1. BatchWise
35
- 2. AllImages
32
+ An enum for choosing the image dependence granularity when generating images:
33
+
34
+ ImageWise
35
+
36
+ BatchWise
37
+
38
+ AllImages
39
+
36
40
  """
37
41
 
38
42
  ImageWise = 0
@@ -42,9 +46,12 @@ class ImageGranularity(EnumBaseClass):
42
46
 
43
47
  class DataInitType(EnumBaseClass):
44
48
  """
45
- An enum for choosing the image dependence granularity when generating images.
46
- 0. Gaussian
47
- 1. Diverse
49
+ An enum for choosing the image dependence granularity when generating images:
50
+
51
+ Gaussian
52
+
53
+ Diverse
54
+
48
55
  """
49
56
  Gaussian = 0
50
57
  Diverse = 1
@@ -52,9 +59,14 @@ class DataInitType(EnumBaseClass):
52
59
 
53
60
  class ImagePipelineType(EnumBaseClass):
54
61
  """
55
- An enum for choosing the image pipeline type for image manipulation.
56
- RANDOM_CROP_FLIP: Crop and flip the images.
57
- IDENTITY: Do not apply any manipulation (identity transformation).
62
+ An enum for choosing the image pipeline type for image manipulation:
63
+
64
+ RANDOM_CROP - Crop the images.
65
+
66
+ RANDOM_CROP_FLIP - Crop and flip the images.
67
+
68
+ IDENTITY - Do not apply any manipulation (identity transformation).
69
+
58
70
  """
59
71
  RANDOM_CROP = 'random_crop'
60
72
  RANDOM_CROP_FLIP = 'random_crop_flip'
@@ -63,10 +75,14 @@ class ImagePipelineType(EnumBaseClass):
63
75
 
64
76
  class ImageNormalizationType(EnumBaseClass):
65
77
  """
66
- An enum for choosing the image normalization type.
67
- TORCHVISION: Normalize the images using torchvision normalization.
68
- KERAS_APPLICATIONS: Normalize the images using keras_applications imagenet normalization.
69
- NO_NORMALIZATION: Do not apply any normalization.
78
+ An enum for choosing the image normalization type:
79
+
80
+ TORCHVISION - Normalize the images using torchvision normalization.
81
+
82
+ KERAS_APPLICATIONS - Normalize the images using keras_applications imagenet normalization.
83
+
84
+ NO_NORMALIZATION - Do not apply any normalization.
85
+
70
86
  """
71
87
  TORCHVISION = 'torchvision'
72
88
  KERAS_APPLICATIONS = 'keras_applications'
@@ -75,10 +91,14 @@ class ImageNormalizationType(EnumBaseClass):
75
91
 
76
92
  class BNLayerWeightingType(EnumBaseClass):
77
93
  """
78
- An enum for choosing the layer weighting type.
79
- AVERAGE: Use the same weight per layer.
80
- FIRST_LAYER_MULTIPLIER: Use a multiplier for the first layer, all other layers with the same weight.
81
- GRAD: Use gradient-based layer weighting.
94
+ An enum for choosing the layer weighting type:
95
+
96
+ AVERAGE - Use the same weight per layer.
97
+
98
+ FIRST_LAYER_MULTIPLIER - Use a multiplier for the first layer, all other layers with the same weight.
99
+
100
+ GRAD - Use gradient-based layer weighting.
101
+
82
102
  """
83
103
  AVERAGE = 'average'
84
104
  FIRST_LAYER_MULTIPLIER = 'first_layer_multiplier'
@@ -87,18 +107,24 @@ class BNLayerWeightingType(EnumBaseClass):
87
107
 
88
108
  class BatchNormAlignemntLossType(EnumBaseClass):
89
109
  """
90
- An enum for choosing the BatchNorm alignment loss type.
91
- L2_SQUARE: Use L2 square loss for BatchNorm alignment.
110
+ An enum for choosing the BatchNorm alignment loss type:
111
+
112
+ L2_SQUARE - Use L2 square loss for BatchNorm alignment.
113
+
92
114
  """
93
115
  L2_SQUARE = 'l2_square'
94
116
 
95
117
 
96
118
  class OutputLossType(EnumBaseClass):
97
119
  """
98
- An enum for choosing the output loss type.
99
- NONE: No output loss is applied.
100
- MIN_MAX_DIFF: Use min-max difference as the output loss.
101
- REGULARIZED_MIN_MAX_DIFF: Use regularized min-max difference as the output loss.
120
+ An enum for choosing the output loss type:
121
+
122
+ NONE - No output loss is applied.
123
+
124
+ MIN_MAX_DIFF - Use min-max difference as the output loss.
125
+
126
+ REGULARIZED_MIN_MAX_DIFF - Use regularized min-max difference as the output loss.
127
+
102
128
  """
103
129
  NONE = 'none'
104
130
  MIN_MAX_DIFF = 'min_max_diff'
@@ -107,9 +133,12 @@ class OutputLossType(EnumBaseClass):
107
133
 
108
134
  class SchedulerType(EnumBaseClass):
109
135
  """
110
- An enum for choosing the scheduler type for the optimizer.
111
- REDUCE_ON_PLATEAU: Use the ReduceOnPlateau scheduler.
112
- STEP: Use the Step scheduler.
136
+ An enum for choosing the scheduler type for the optimizer:
137
+
138
+ REDUCE_ON_PLATEAU - Use the ReduceOnPlateau scheduler.
139
+
140
+ STEP - Use the Step scheduler.
141
+
113
142
  """
114
143
  REDUCE_ON_PLATEAU = 'reduce_on_plateau'
115
144
  STEP = 'step'
@@ -131,7 +131,36 @@ if FOUND_TF:
131
131
 
132
132
  Returns:
133
133
  List[tf.Tensor]: Finalized list containing generated images.
134
+
135
+ Examples:
136
+
137
+ In this example, we'll walk through generating images using a simple Keras model and a data generation configuration. The process involves creating a model, setting up a data generation configuration, and finally generating images with specified parameters.
138
+
139
+ Start by importing the Model Compression Toolkit (MCT), TensorFlow, and some layers from `tensorflow.keras`:
140
+
141
+ >>> import model_compression_toolkit as mct
142
+ >>> from tensorflow.keras.models import Sequential
143
+ >>> from tensorflow.keras.layers import Conv2D, BatchNormalization, Flatten, Dense, Reshape
144
+
145
+ Next, define a simple Keras model:
146
+
147
+ >>> model = Sequential([Conv2D(2, 3, input_shape=(8,8,3)), BatchNormalization(), Flatten(), Dense(10)])
148
+
149
+ Configure the data generation process using `get_keras_data_generation_config`. This function allows customization of the data generation process. For simplicity, this example sets the number of iterations (`n_iter`) to 1 and the batch size (`data_gen_batch_size`) to 2.
150
+
151
+ >>> config = mct.data_generation.get_keras_data_generation_config(n_iter=1, data_gen_batch_size=2)
152
+
153
+ Finally, use the `keras_data_generation_experimental` function to generate images based on the model and data generation configuration.
154
+ Notice that this function is experimental and may change in future versions of MCT.
155
+ The `n_images` parameter specifies the number of images to generate, and `output_image_size` sets the size of the generated images.
156
+
157
+ >>> generated_images = mct.data_generation.keras_data_generation_experimental(model=model, n_images=4, output_image_size=(8, 8), data_generation_config=config)
158
+
159
+ The generated images can then be used for various purposes, such as data-free quantization.
160
+
161
+
134
162
  """
163
+
135
164
  Logger.warning(f"keras_data_generation_experimental is experimental "
136
165
  f"and is subject to future changes."
137
166
  f"If you encounter an issue, please open an issue in our GitHub "
@@ -129,7 +129,7 @@ if FOUND_TORCH:
129
129
  def pytorch_data_generation_experimental(
130
130
  model: Module,
131
131
  n_images: int,
132
- output_image_size: Tuple,
132
+ output_image_size: int,
133
133
  data_generation_config: DataGenerationConfig) -> List[Tensor]:
134
134
  """
135
135
  Function to perform data generation using the provided model and data generation configuration.
@@ -137,11 +137,38 @@ if FOUND_TORCH:
137
137
  Args:
138
138
  model (Module): PyTorch model to generate data for.
139
139
  n_images (int): Number of images to generate.
140
- output_image_size (Tuple): Size of the output images.
140
+ output_image_size (int): The hight and width size of the output images.
141
141
  data_generation_config (DataGenerationConfig): Configuration for data generation.
142
142
 
143
143
  Returns:
144
144
  List[Tensor]: Finalized list containing generated images.
145
+
146
+ Examples:
147
+
148
+ In this example, we'll walk through generating images using a simple PyTorch model and a data generation configuration. The process involves creating a model, setting up a data generation configuration, and finally generating images with specified parameters.
149
+
150
+ Start by importing the Model Compression Toolkit (MCT), PyTorch, and some modules from `torch.nn`:
151
+
152
+ >>> import model_compression_toolkit as mct
153
+ >>> import torch.nn as nn
154
+ >>> from torch.nn import Conv2d, BatchNorm2d, Flatten, Linear
155
+
156
+ Next, define a simple PyTorch model:
157
+
158
+ >>> model = nn.Sequential(nn.Conv2d(3, 2, 3), nn.BatchNorm2d(2), nn.Flatten(), nn.Linear(2*6*6, 10))
159
+
160
+ Configure the data generation process using `get_pytorch_data_generation_config`. This function allows customization of the data generation process. For simplicity, this example sets the number of iterations (`n_iter`) to 1 and the batch size (`data_gen_batch_size`) to 2.
161
+
162
+ >>> config = mct.data_generation.get_pytorch_data_generation_config(n_iter=1, data_gen_batch_size=2)
163
+
164
+ Finally, use the `pytorch_data_generation_experimental` function to generate images based on the model and data generation configuration.
165
+ Notice that this function is experimental and may change in future versions of MCT.
166
+ The `n_images` parameter specifies the number of images to generate, and `output_image_size` sets the size of the generated images.
167
+
168
+ >>> generated_images = mct.data_generation.pytorch_data_generation_experimental(model=model, n_images=4, output_image_size=8, data_generation_config=config)
169
+
170
+ The generated images can then be used for various purposes, such as data-free quantization.
171
+
145
172
  """
146
173
 
147
174
  Logger.warning(f"pytorch_data_generation_experimental is experimental "
@@ -16,6 +16,16 @@ from enum import Enum
16
16
 
17
17
 
18
18
  class QuantizationFormat(Enum):
19
+ """
20
+ Specify which quantization format to use for exporting a quantized model.
21
+
22
+ FAKELY_QUANT - Weights and activations are quantized but represented using float data type.
23
+
24
+ INT8 - Weights and activations are represented using 8-bit integer data type.
25
+
26
+ MCTQ - Weights and activations are quantized using mct_quantizers custom quantizers.
27
+
28
+ """
19
29
  FAKELY_QUANT = 0
20
30
  INT8 = 1
21
31
  MCTQ = 2
@@ -42,21 +42,17 @@ if FOUND_TF:
42
42
  serialization_format: KerasExportSerializationFormat = KerasExportSerializationFormat.KERAS,
43
43
  quantization_format : QuantizationFormat = QuantizationFormat.MCTQ) -> Dict[str, type]:
44
44
  """
45
- Export a Keras quantized model to a h5 or tflite model.
45
+ Export a Keras quantized model to a .keras or .tflite format model (according to serialization_format).
46
46
  The model will be saved to the path in save_model_path.
47
- keras_export_model supports the combination of QuantizationFormat.FAKELY_QUANT (where weights
48
- and activations are float fakely-quantized values) and KerasExportSerializationFormat.KERAS_H5 (where the model
49
- will be saved to h5 model) or the combination of KerasExportSerializationFormat.TFLITE (where the model will be
50
- saved to tflite model) with QuantizationFormat.FAKELY_QUANT or QuantizationFormat.INT8 (where weights and
51
- activations are represented using 8bits integers).
47
+ Models that are exported to .keras format can use quantization_format of QuantizationFormat.MCTQ or QuantizationFormat.FAKELY_QUANT.
48
+ Models that are exported to .tflite format can use quantization_format of QuantizationFormat.INT8 or QuantizationFormat.FAKELY_QUANT.
52
49
 
53
50
  Args:
54
51
  model: Model to export.
55
52
  save_model_path: Path to save the model.
56
53
  is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
57
- serialization_format: Format to export the model according to (by default
58
- KerasExportSerializationFormat.KERAS_H5).
59
- quantization_format: Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
54
+ serialization_format: Format to export the model according to (KerasExportSerializationFormat.KERAS, by default).
55
+ quantization_format: Format of how quantizers are exported (MCTQ quantizers, by default).
60
56
 
61
57
  Returns:
62
58
  Custom objects dictionary needed to load the model.
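A hedged usage sketch of the two valid combinations described above; quantized_model stands in for the output of a PTQ/GPTQ run, and the save paths are placeholders:

>>> import model_compression_toolkit as mct
>>> # Default combination: .keras serialization with MCTQ quantizers.
>>> custom_objects = mct.exporter.keras_export_model(model=quantized_model, save_model_path='qmodel.keras')
>>> # .tflite serialization with INT8 quantizers:
>>> mct.exporter.keras_export_model(model=quantized_model,
...                                 save_model_path='qmodel.tflite',
...                                 serialization_format=mct.exporter.KerasExportSerializationFormat.TFLITE,
...                                 quantization_format=mct.exporter.QuantizationFormat.INT8)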
@@ -19,9 +19,12 @@ from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
19
19
 
20
20
  class RoundingType(Enum):
21
21
  """
22
- An enum for choosing the GPTQ rounding methods
23
- 0. STRAIGHT-THROUGH ESTIMATOR
24
- 1. SoftQuantizer
22
+ An enum for choosing the GPTQ rounding methods:
23
+
24
+ STE - Straight-Through Estimator rounding.
25
+
26
+ SoftQuantizer - Learnable soft-rounding quantizer.
27
+
25
28
  """
26
29
  STE = 0
27
30
  SoftQuantizer = 1
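A short sketch of choosing a rounding method. It assumes the GradientPTQConfig returned by get_keras_gptq_config defaults to SoftQuantizer and stores the choice in a rounding_type attribute that may be overridden; verify against the constructor before relying on this:

>>> import model_compression_toolkit as mct
>>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)  # assumed default: RoundingType.SoftQuantizer
>>> gptq_config.rounding_type = mct.gptq.RoundingType.STE     # hypothetical override to straight-through estimator rounding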
@@ -116,6 +116,7 @@ if FOUND_TF:
116
116
  def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
117
117
  gptq_config: GradientPTQConfig,
118
118
  gptq_representative_data_gen: Callable = None,
119
+ target_kpi: KPI = None,
119
120
  core_config: CoreConfig = CoreConfig(),
120
121
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, UserInformation]:
121
122
  """
@@ -139,6 +140,7 @@ if FOUND_TF:
139
140
  representative_data_gen (Callable): Dataset used for calibration.
140
141
  gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
141
142
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None, defaults to representative_data_gen.
143
+ target_kpi (KPI): KPI object that constrains the mixed-precision configuration search.
142
144
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
143
145
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
144
146
 
@@ -166,6 +168,12 @@ if FOUND_TF:
166
168
 
167
169
  >>> config = mct.core.CoreConfig()
168
170
 
171
+ If mixed precision is desired, create an MCT core config with a mixed-precision configuration to quantize a model
172
+ with different bitwidths for different layers.
173
+ The candidate bitwidths for quantization should be defined in the target platform model:
174
+
175
+ >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
176
+
169
177
  For mixed-precision set a target KPI object:
170
178
  Create a KPI object to limit our returned model's size. Note that this value affects only coefficients
171
179
  that should be quantized (for example, the kernel of Conv2D in Keras will be affected by this value,
@@ -173,19 +181,13 @@ if FOUND_TF:
173
181
 
174
182
  >>> kpi = mct.core.KPI(model.count_params() * 0.75) # About 0.75 of the model size when quantized with 8 bits.
175
183
 
176
- If mixed precision is desired, create an MCT core config with a mixed-precision configuration, to quantize a model
177
- with different bitwidths for different layers.
178
- The candidates bitwidth for quantization should be defined in the target platform model:
179
-
180
- >>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1, target_kpi=kpi))
181
-
182
184
  Create GPTQ config:
183
185
 
184
186
  >>> gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=1)
185
187
 
186
188
  Pass the model with the representative dataset generator to get a quantized model:
187
189
 
188
- >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, core_config=config)
190
+ >>> quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(model, repr_datagen, gptq_config, target_kpi=kpi, core_config=config)
189
191
 
190
192
  """
191
193
  KerasModelValidation(model=in_model,
@@ -94,6 +94,7 @@ if FOUND_TORCH:
94
94
 
95
95
  def pytorch_gradient_post_training_quantization(model: Module,
96
96
  representative_data_gen: Callable,
97
+ target_kpi: KPI = None,
97
98
  core_config: CoreConfig = CoreConfig(),
98
99
  gptq_config: GradientPTQConfig = None,
99
100
  gptq_representative_data_gen: Callable = None,
@@ -117,6 +118,7 @@ if FOUND_TORCH:
117
118
  Args:
118
119
  model (Module): Pytorch model to quantize.
119
120
  representative_data_gen (Callable): Dataset used for calibration.
121
+ target_kpi (KPI): KPI object that constrains the mixed-precision configuration search.
120
122
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
121
123
  gptq_config (GradientPTQConfig): Configuration for using gptq (e.g. optimizer).
122
124
gptq_representative_data_gen (Callable): Dataset used for GPTQ training. If None, defaults to representative_data_gen.
@@ -174,6 +176,7 @@ if FOUND_TORCH:
174
176
  fw_info=DEFAULT_PYTORCH_INFO,
175
177
  fw_impl=fw_impl,
176
178
  tpc=target_platform_capabilities,
179
+ target_kpi=target_kpi,
177
180
  tb_w=tb_w)
178
181
 
179
182
  # ---------------------- #
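A hedged end-to-end sketch of the PyTorch GPTQ facade with the newly threaded target_kpi. The random representative data and the unconstrained KPI are placeholders for real calibration images and a real memory budget:

>>> import model_compression_toolkit as mct
>>> import numpy as np
>>> from torchvision.models import mobilenet_v2
>>> model = mobilenet_v2(weights="DEFAULT")
>>> def repr_datagen(): yield [np.random.random((1, 3, 224, 224))]
>>> gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=1)
>>> kpi = mct.core.KPI()  # unconstrained placeholder; pass weights_memory (in bytes) to bound model size
>>> quantized_model, info = mct.gptq.pytorch_gradient_post_training_quantization(model, repr_datagen, target_kpi=kpi, gptq_config=gptq_config)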
@@ -40,8 +40,7 @@ if FOUND_TF:
40
40
  target_kpi: KPI,
41
41
  representative_data_gen: Callable,
42
42
  pruning_config: PruningConfig = PruningConfig(),
43
- target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> \
44
- Tuple[Model, PruningInfo]:
43
+ target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC) -> Tuple[Model, PruningInfo]:
45
44
  """
46
45
  Perform structured pruning on a Keras model to meet a specified target KPI.
47
46
  This function prunes the provided model according to the target KPI by grouping and pruning
@@ -59,12 +58,14 @@ if FOUND_TF:
59
58
  target_kpi (KPI): The target Key Performance Indicators to be achieved through pruning.
60
59
  representative_data_gen (Callable): A function to generate representative data for pruning analysis.
61
60
  pruning_config (PruningConfig): Configuration settings for the pruning process. Defaults to standard config.
62
- target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities.
63
- Defaults to DEFAULT_KERAS_TPC.
61
+ target_platform_capabilities (TargetPlatformCapabilities): Platform-specific constraints and capabilities. Defaults to DEFAULT_KERAS_TPC.
64
62
 
65
63
  Returns:
66
64
  Tuple[Model, PruningInfo]: A tuple containing the pruned Keras model and associated pruning information.
67
65
 
66
+ Note:
67
+ The pruned model should be fine-tuned or retrained to recover or improve its performance post-pruning.
68
+
68
69
  Examples:
69
70
 
70
71
  Import MCT:
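The docstring example continues past this hunk; for completeness, here is a hedged sketch of the whole pruning flow. The facade name keras_pruning_experimental and the KPI arithmetic are assumptions drawn from the MCT pruning documentation:

>>> import model_compression_toolkit as mct
>>> import numpy as np
>>> from tensorflow.keras.applications import ResNet50
>>> model = ResNet50()
>>> def repr_datagen(): yield [np.random.random((1, 224, 224, 3))]
>>> # Target roughly half the dense model's weight memory (float32 = 4 bytes per parameter).
>>> kpi = mct.core.KPI(weights_memory=model.count_params() * 4 * 0.5)
>>> pruned_model, pruning_info = mct.pruning.keras_pruning_experimental(model=model, target_kpi=kpi, representative_data_gen=repr_datagen)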
@@ -42,6 +42,7 @@ if FOUND_TF:
42
42
 
43
43
  def keras_post_training_quantization(in_model: Model,
44
44
  representative_data_gen: Callable,
45
+ target_kpi: KPI = None,
45
46
  core_config: CoreConfig = CoreConfig(),
46
47
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
47
48
  """
@@ -60,6 +61,7 @@ if FOUND_TF:
60
61
  Args:
61
62
  in_model (Model): Keras model to quantize.
62
63
  representative_data_gen (Callable): Dataset used for calibration.
64
+ target_kpi (KPI): KPI object that constrains the mixed-precision configuration search.
63
65
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
64
66
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
65
67
 
@@ -135,6 +137,7 @@ if FOUND_TF:
135
137
  fw_info=fw_info,
136
138
  fw_impl=fw_impl,
137
139
  tpc=target_platform_capabilities,
140
+ target_kpi=target_kpi,
138
141
  tb_w=tb_w)
139
142
 
140
143
  tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
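A hedged sketch of the Keras PTQ facade with the new target_kpi parameter. Since target_kpi steers the mixed-precision search, a mixed-precision configuration is set as well; the random dataset is a stand-in for real calibration images:

>>> import model_compression_toolkit as mct
>>> import numpy as np
>>> from tensorflow.keras.applications import MobileNetV2
>>> model = MobileNetV2()
>>> def repr_datagen(): yield [np.random.random((1, 224, 224, 3))]
>>> config = mct.core.CoreConfig(mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=1))
>>> kpi = mct.core.KPI(model.count_params() * 0.75)  # about 0.75 of the 8-bit model size, as in the GPTQ docstring
>>> quantized_model, info = mct.ptq.keras_post_training_quantization(model, repr_datagen, target_kpi=kpi, core_config=config)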
@@ -41,6 +41,7 @@ if FOUND_TORCH:
41
41
 
42
42
  def pytorch_post_training_quantization(in_module: Module,
43
43
  representative_data_gen: Callable,
44
+ target_kpi: KPI = None,
44
45
  core_config: CoreConfig = CoreConfig(),
45
46
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
46
47
  """
@@ -59,6 +60,7 @@ if FOUND_TORCH:
59
60
  Args:
60
61
  in_module (Module): Pytorch module to quantize.
61
62
  representative_data_gen (Callable): Dataset used for calibration.
63
+ target_kpi (KPI): KPI object that constrains the mixed-precision configuration search.
62
64
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
63
65
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the PyTorch model according to.
64
66
 
@@ -107,6 +109,7 @@ if FOUND_TORCH:
107
109
  fw_info=DEFAULT_PYTORCH_INFO,
108
110
  fw_impl=fw_impl,
109
111
  tpc=target_platform_capabilities,
112
+ target_kpi=target_kpi,
110
113
  tb_w=tb_w)
111
114
 
112
115
  tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
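The PyTorch facade mirrors the Keras one. Assuming model, kpi and config are built as in the PyTorch GPTQ sketch above, with repr_datagen yielding NCHW batches of shape (1, 3, 224, 224), the call is:

>>> quantized_model, info = mct.ptq.pytorch_post_training_quantization(model, repr_datagen, target_kpi=kpi, core_config=config)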
@@ -87,6 +87,7 @@ if FOUND_TF:
87
87
 
88
88
  def keras_quantization_aware_training_init_experimental(in_model: Model,
89
89
  representative_data_gen: Callable,
90
+ target_kpi: KPI = None,
90
91
  core_config: CoreConfig = CoreConfig(),
91
92
  qat_config: QATConfig = QATConfig(),
92
93
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_KERAS_TPC):
@@ -108,6 +109,7 @@ if FOUND_TF:
108
109
  Args:
109
110
  in_model (Model): Keras model to quantize.
110
111
  representative_data_gen (Callable): Dataset used for initial calibration.
112
+ target_kpi (KPI): KPI object that constrains the mixed-precision configuration search.
111
113
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
112
114
  qat_config (QATConfig): QAT configuration
113
115
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Keras model according to.
@@ -157,7 +159,7 @@ if FOUND_TF:
157
159
  Pass the model, the representative dataset generator, the configuration and the target KPI to get a
158
160
  quantized model:
159
161
 
160
- >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=core_config)
162
+ >>> quantized_model, quantization_info, custom_objects = mct.qat.keras_quantization_aware_training_init_experimental(model, repr_datagen, kpi, core_config=config)
161
163
 
162
164
  Use the quantized model for fine-tuning. For loading the model from file, use the custom_objects dictionary:
163
165
 
@@ -191,6 +193,7 @@ if FOUND_TF:
191
193
  fw_info=DEFAULT_KERAS_INFO,
192
194
  fw_impl=fw_impl,
193
195
  tpc=target_platform_capabilities,
196
+ target_kpi=target_kpi,
194
197
  tb_w=tb_w)
195
198
 
196
199
  tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)
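After the init call, the returned model is a regular Keras model that can be fine-tuned and round-tripped through disk with the returned custom_objects dictionary. x_train/y_train below are placeholder training data, and the compile settings are illustrative:

>>> import tensorflow as tf
>>> quantized_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
>>> quantized_model.fit(x_train, y_train, epochs=1)  # fine-tune the fake-quantized model
>>> quantized_model.save('qat_model.keras')
>>> loaded = tf.keras.models.load_model('qat_model.keras', custom_objects=custom_objects)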
@@ -75,6 +75,7 @@ if FOUND_TORCH:
75
75
 
76
76
  def pytorch_quantization_aware_training_init_experimental(in_model: Module,
77
77
  representative_data_gen: Callable,
78
+ target_kpi: KPI = None,
78
79
  core_config: CoreConfig = CoreConfig(),
79
80
  qat_config: QATConfig = QATConfig(),
80
81
  target_platform_capabilities: TargetPlatformCapabilities = DEFAULT_PYTORCH_TPC):
@@ -96,6 +97,7 @@ if FOUND_TORCH:
96
97
  Args:
97
98
  in_model (Model): Pytorch model to quantize.
98
99
  representative_data_gen (Callable): Dataset used for initial calibration.
100
+ target_kpi (KPI): KPI object that constrains the mixed-precision configuration search.
99
101
  core_config (CoreConfig): Configuration object containing parameters of how the model should be quantized, including mixed precision parameters.
100
102
  qat_config (QATConfig): QAT configuration
101
103
  target_platform_capabilities (TargetPlatformCapabilities): TargetPlatformCapabilities to optimize the Pytorch model according to.
@@ -158,6 +160,7 @@ if FOUND_TORCH:
158
160
  fw_info=DEFAULT_PYTORCH_INFO,
159
161
  fw_impl=fw_impl,
160
162
  tpc=target_platform_capabilities,
163
+ target_kpi=target_kpi,
161
164
  tb_w=tb_w)
162
165
 
163
166
  tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
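A minimal hedged call for the PyTorch QAT init with the new target_kpi. Unlike the Keras variant it is assumed to return no custom-objects dictionary, and model, repr_datagen, kpi and config follow the PyTorch sketches above:

>>> qat_ready_model, quantization_info = mct.qat.pytorch_quantization_aware_training_init_experimental(model, repr_datagen, target_kpi=kpi, core_config=config)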
@@ -35,9 +35,9 @@ class OperationsSetToLayers(TargetPlatformCapabilitiesComponent):
35
35
  Args:
36
36
  op_set_name (str): Name of OperatorsSet to associate with layers.
37
37
  layers (List[Any]): List of layers/FilterLayerParams to associate with OperatorsSet.
38
- attr_mapping (dict): A mapping between a general attribute name to a DefaultDict that maps a layer
39
- + type to the layer's framework name of this attribute (the dictionary type is not specified to
40
- + handle circular dependency).
38
+ attr_mapping (dict): A mapping from a general attribute name to a DefaultDict that maps a layer type
39
+ to that attribute's framework-specific name (the dictionary type is left unspecified to avoid a circular
40
+ dependency).
41
41
  """
42
42
  self.layers = layers
43
43
  self.attr_mapping = attr_mapping