mct-nightly 2.1.0.20240608.434__py3-none-any.whl → 2.1.0.20240610.442__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/RECORD +26 -18
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_node.py +1 -4
  5. model_compression_toolkit/core/common/quantization/node_quantization_config.py +10 -6
  6. model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +15 -7
  7. model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +30 -14
  8. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +8 -7
  9. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +108 -87
  10. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +15 -13
  11. model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +29 -14
  12. model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +40 -14
  13. model_compression_toolkit/core/keras/reader/node_builder.py +3 -3
  14. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +25 -23
  15. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +10 -0
  16. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +16 -0
  17. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +222 -0
  18. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +131 -0
  19. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +111 -0
  20. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +16 -0
  21. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +219 -0
  22. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +131 -0
  23. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +110 -0
  24. {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/LICENSE.md +0 -0
  25. {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/WHEEL +0 -0
  26. {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.1.0.20240608.434
3
+ Version: 2.1.0.20240610.442
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=loXKkVRKW11Ehu1o3cKIpDx-z_o1TIydRprxcqjElkA,1573
1
+ model_compression_toolkit/__init__.py,sha256=8uKLxbPGI4bXEsOnz8snYp5aOCbWS0nIiBxD9ic580Y,1573
2
2
  model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -31,7 +31,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
31
31
  model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
32
32
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
33
33
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
34
- model_compression_toolkit/core/common/graph/base_node.py,sha256=exvUkLDChl6YaoaQRHgSrettsgOsd18bfq01tPxXr-4,29722
34
+ model_compression_toolkit/core/common/graph/base_node.py,sha256=X_0zqHrKYAsmnj9tAKjVYasbFcZD8OHpjdiMj9ugQs0,29436
35
35
  model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
36
36
  model_compression_toolkit/core/common/graph/functional_node.py,sha256=71_4TrCdqR_r0mtgxmAyqI05iP5YoQQGeSmDgynuzTw,3902
37
37
  model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
@@ -101,7 +101,7 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
101
101
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
102
102
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
103
103
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
104
- model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=0XFJwHbuUjT_C20XB0Omumd6PSQqYj5fnsYHRx78AaU,26733
104
+ model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=u0JkdRqBXG0RvvYyLyvYknEVtB2-gxpqUJnUw3loLmE,26851
105
105
  model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=du0VdsxfkOSYaP1EU9gHA5qbXpfQNZL0jXrjk1wBA0U,7106
106
106
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=eyosbVdnCwed7oMQ19tqnh0VoyGZ_UAuD_UnNoXyBpo,2210
107
107
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
@@ -110,15 +110,15 @@ model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGN
110
110
  model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=O4qFJw3nBYUD4cGbO8haGXZ2-piSqoRpDKDD74iXSxw,12417
111
111
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
112
112
  model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=w367wmtJ7iWmM4_HlpX-YVUuqtYKrsiPP1oDaICIuK8,23308
113
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=FWyOcjENAK-bFPpVjgczDiGAWZi--OgJ60jZjPUPqzo,8059
113
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=t0XSwjfOxcq2Sj2PGzccntz1GGv2eqVn9oR3OI0t9wo,8533
114
114
  model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
115
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=ejc_obamUndJsv3F1FuOGMrIibS__qDUbAia1H9vwUM,9487
115
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=HfnhQ4MxGpb95gOWXD1vnroTxxjFt9VFd4jIdo-rvAQ,10623
116
116
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=noEdvGiyyW7acgQ2OFWLedCODibTGYJifC9qo8YIU5U,4558
117
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=7ITrOw5ykncpHNghlPNTaDZExFYrPmhRck4oW0GaPe0,6213
118
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=7kt0JB8PQE0SW9kg8fCwZ5mBkHNgiRrn0of4ZQYQN2A,41524
119
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=kAqVKZYu6FHWlC_PUiytsmXdTX1GzO_S5DWrTXuJBjs,4894
120
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=_ULwlPvzVL_UcYVlUPjDIeXz_99eW26l9FwGzaUu-_M,10789
121
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=VG0UqFOQk_7ALdJsUl1wwwFLjE38DxN6-NRZx161XiY,8902
117
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=E_XFTpYNUZ3JgOk_2qbUbmJH6qGqBM3TDsY4WptYup0,6478
118
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=o2XNY_0pUUyId02TUVQBtkux_i40NCcnzuobSeQLy3E,42863
119
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=zSNda0jN8cP41m6g5TOv5WvATwIhV8z6AVM1Es6rq1s,4419
120
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=4TP41wPYC0azIzFxUt-lNlKUPIIXQeE4H1SYHkON75k,11875
121
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=E83BU4wZEOY-Q-HTo04ABftv22Y6fWEdNYkGA-MZLMU,10494
122
122
  model_compression_toolkit/core/common/quantization/quantizers/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
123
123
  model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py,sha256=P0x_y18LypBxP2tV9OWizheYfILqvaMC8RwHo04sUpQ,2761
124
124
  model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py,sha256=CCFhi5LUIcHCCIzDyORvm0FDZLknrctdNwNlPphOQgI,14245
@@ -199,7 +199,7 @@ model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py,sha256=Up3-sbuA
199
199
  model_compression_toolkit/core/keras/reader/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
200
200
  model_compression_toolkit/core/keras/reader/common.py,sha256=eZWjBcvTDUX7fCWmy1OAH4lYLFTh59_UQ_nP_Gjp4yw,2594
201
201
  model_compression_toolkit/core/keras/reader/connectivity_handler.py,sha256=AgF6qXZOJMeXvc-pBnGY23BJz7wPBx2aTYxHiO8efec,11303
202
- model_compression_toolkit/core/keras/reader/node_builder.py,sha256=URmE3lM9CskS-9a3TuqfReLdHh36Dti08RL8qxzrBjc,10471
202
+ model_compression_toolkit/core/keras/reader/node_builder.py,sha256=SAPkgL8aqJjnB6eCucU2D4m50WACCzWC8wjCVtFnwp8,10424
203
203
  model_compression_toolkit/core/keras/reader/reader.py,sha256=wS9UQ2wJKnkZYe9JHwQp7ygDr6CRlzrxmIyLDv1Qz6U,8109
204
204
  model_compression_toolkit/core/keras/reader/nested_model/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
205
205
  model_compression_toolkit/core/keras/reader/nested_model/edges_merger.py,sha256=K6KAH9o8KSG6baLmhKoCrYK-i-wb6gRKiZmoijFqEYA,7906
@@ -222,7 +222,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
222
222
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
223
223
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
224
224
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
225
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=oJdTA9T-qNWY4vEckiYlf3kCQrsl6IVPliXg9S6dqWM,18259
225
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Zw4gi-wjJNV8-qGv79YBWVAHmy27f7iW0c2JGNWAKD0,18199
226
226
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
227
227
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
228
228
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -431,7 +431,7 @@ model_compression_toolkit/target_platform_capabilities/target_platform/targetpla
431
431
  model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
432
432
  model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py,sha256=-jCL-meZWFBF-Dp9wBYTX_14SKmyyUJE-BZ2IQDJIAk,3336
433
433
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
434
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py,sha256=bUeKEjL45oU6J1EXwt1MGhlWs_87zF1GGz6X3ES72ps,3796
434
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py,sha256=mjPFr6Z-PLzqQta8mW7dK31mbbBZsJo4MdpJQmxlSt4,4640
435
435
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py,sha256=F5RG4MnuAwKcNXbfVbPFLQu30-lNax-7knqu20B6udQ,1522
436
436
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/__init__.py,sha256=1mMOREEMoNHu_KTMGDp4crN61opKWX6aFn1DrDLvqcc,717
437
437
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py,sha256=S-GwMI-JiuPpbtOdd6TSOEjiUFiIs6M2RAiJNJ3O950,10883
@@ -453,6 +453,14 @@ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_
453
453
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py,sha256=dmi2lCT0dw6RnWVw73tcnqgsVSgINSWaIWfgZhEli4Q,10691
454
454
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py,sha256=6PVKQKGpJpM2B1qvmf6fID_-MACaSQZkaL_9J_fj2SQ,6595
455
455
  model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py,sha256=dFQjzFlLDwoUqKNP1at1fS1N1WJadSSasRyzHl6vaB8,5733
456
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py,sha256=gAeebYCKyIXH9-Qwze7FwvTihudzAHk_Qsg94fQbkjQ,717
457
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py,sha256=edMH4lM7Bq7FaPAFZLU5UMX-bWSWiaaAIXnQE7lZ7rI,11844
458
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py,sha256=T5YMv-RzgYlzBaagnMO7WnKgbZ7PrOvm29Nn4vUhCHI,6587
459
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py,sha256=-q6Tnn7diPCCoATmLDzJwWwviQcbMMISqgpLu2n42JY,5726
460
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py,sha256=C2kwyDE1-rtukkbNSoKRv9q8Nt2GOCaBbl0BdOr3goA,721
461
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py,sha256=HoGjDwoSx2Y4dQua5v1qzzlnSl_HfDMK6bGWuZhPOzQ,11577
462
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py,sha256=LvqUkvpJKXBb9QETcHsmp9OGDwl9KWr457deag8GVuM,6595
463
+ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py,sha256=4Y2D14rE0SnWIkBTYsVqCryB-gkHU1ZlbdkWF864mPU,5733
456
464
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
457
465
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py,sha256=7KVcuz0LfngRKOsfcvBysxGVb9fqgoAO6MVTl1CmB5c,2082
458
466
  model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py,sha256=UUvUCcTots_sehdRnDfgkaE8WPQ7dPbeuhDF4Qy2nzw,1510
@@ -483,8 +491,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
483
491
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
484
492
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
485
493
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
486
- mct_nightly-2.1.0.20240608.434.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
487
- mct_nightly-2.1.0.20240608.434.dist-info/METADATA,sha256=I7XXFZFj5zx7OCRB_ggqsafDnsyODn_1o9vsNbTXT00,19721
488
- mct_nightly-2.1.0.20240608.434.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
489
- mct_nightly-2.1.0.20240608.434.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
490
- mct_nightly-2.1.0.20240608.434.dist-info/RECORD,,
494
+ mct_nightly-2.1.0.20240610.442.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
495
+ mct_nightly-2.1.0.20240610.442.dist-info/METADATA,sha256=Juo23o8F4ndhmb8TksZ99xKWtks0DK59daxJqx_9RmI,19721
496
+ mct_nightly-2.1.0.20240610.442.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
497
+ mct_nightly-2.1.0.20240610.442.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
498
+ mct_nightly-2.1.0.20240610.442.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.1.0.20240608.000434"
30
+ __version__ = "2.1.0.20240610.000442"
@@ -240,10 +240,7 @@ class BaseNode:
240
240
  if isinstance(pos, int)):
241
241
  if pos > len(input_tensors):
242
242
  Logger.critical("The positional weight index cannot exceed the number of input tensors to the node.") # pragma: no cover
243
- # Insert only positional weights that are not subject to quantization. If the positional weight is
244
- # subject to quantization, the quantization wrapper inserts the positional weight into the node.
245
- if not self.is_weights_quantization_enabled(pos):
246
- input_tensors.insert(pos, weight)
243
+ input_tensors.insert(pos, weight)
247
244
 
248
245
  return input_tensors
249
246
 
@@ -326,13 +326,17 @@ class WeightsAttrQuantizationConfig:
326
326
 
327
327
  """
328
328
  assert self.enable_weights_quantization
329
+ assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
330
+ "Trying to calculate threshold per channel, channel axis in None."
329
331
  if self.weights_quantization_params_fn is not None:
330
- self.set_weights_quantization_param(self.weights_quantization_params_fn(tensor_data,
331
- p=self.l_p_value,
332
- n_bits=self.weights_n_bits,
333
- per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
334
- channel_axis=self.weights_channels_axis[0], # output channel axis
335
- min_threshold=min_threshold))
332
+ self.set_weights_quantization_param(
333
+ self.weights_quantization_params_fn(tensor_data,
334
+ p=self.l_p_value,
335
+ n_bits=self.weights_n_bits,
336
+ per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
337
+ channel_axis=self.weights_channels_axis[0], # output channel axis
338
+ min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
339
+ )
336
340
  else:
337
341
  self.set_weights_quantization_param({})
338
342
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from typing import Dict
16
+ from typing import Dict, Tuple
17
17
  import numpy as np
18
18
  from sklearn.cluster import KMeans
19
19
 
@@ -42,7 +42,8 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
42
42
  is_symmetric: bool = False,
43
43
  node=None,
44
44
  hessian_info_service: HessianInfoService = None,
45
- num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Dict:
45
+ num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
46
+ ) -> Tuple[Dict[str, np.ndarray], int]:
46
47
  """
47
48
  The quantizer first finds the closest max value per channel of tensor_data.
48
49
  Now, we divide tensor_data with the threshold vector per channel. In addition, we scale the result to the range
@@ -70,27 +71,34 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
70
71
  if n_bits >= LUT_VALUES_BITWIDTH:
71
72
  Logger.critical(f'Look-Up-Table (LUT) bit configuration exceeds maximum: {n_bits} bits provided, must be less than {LUT_VALUES_BITWIDTH} bits.') # pragma: no cover
72
73
  # TODO: need to set this externally
74
+ n_data_points = len(np.unique(tensor_data.flatten()))
73
75
  if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
74
- n_clusters = len(np.unique(tensor_data.flatten()))
76
+ n_clusters = n_data_points
75
77
  else:
76
78
  n_clusters = 2 ** n_bits
77
79
  kmeans = KMeans(n_clusters=n_clusters, n_init=10)
78
80
 
79
81
  threshold_selection_tensor = symmetric_selection_tensor if is_symmetric else power_of_two_selection_tensor
80
- thresholds_per_channel = threshold_selection_tensor(tensor_data, p, n_bits, per_channel,
81
- channel_axis, n_iter, min_threshold,
82
- qc.QuantizationErrorMethod.NOCLIPPING)[THRESHOLD]
82
+
83
+ _params, channel_axis = threshold_selection_tensor(tensor_data, p, n_bits, per_channel,
84
+ channel_axis, n_iter, min_threshold,
85
+ qc.QuantizationErrorMethod.NOCLIPPING)
86
+ thresholds_per_channel = _params[THRESHOLD]
83
87
 
84
88
  tensor_for_kmeans = int_quantization_with_threshold(tensor_data, thresholds_per_channel, LUT_VALUES_BITWIDTH)
85
89
  kmeans.fit(tensor_for_kmeans.reshape(-1, 1))
86
90
 
87
91
  # Add 0 to the LUT
88
92
  cc = np.round(kmeans.cluster_centers_)
93
+ if n_data_points < 2 ** n_bits and np.all(cc != 0):
94
+ # In case there are fewer data points than potential clusters, we can add the cluster 0.0
95
+ # to the original clusters array to improve quantization (i.e. no need to zero one of the clusters).
96
+ cc = np.concatenate([np.zeros([1, 1], dtype=cc.dtype), cc])
89
97
  closest2zero_idx = (np.abs(cc - 0)).argmin()
90
98
  cc[closest2zero_idx] = 0.0
91
99
 
92
100
  return {LUT_VALUES: cc,
93
- SCALE_PER_CHANNEL: thresholds_per_channel}
101
+ SCALE_PER_CHANNEL: thresholds_per_channel}, channel_axis
94
102
 
95
103
 
96
104
  def lut_kmeans_histogram(bins: np.ndarray,
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import numpy as np
16
+ from typing import Union, Tuple, Dict
16
17
 
17
18
  import model_compression_toolkit.core.common.quantization.quantization_config as qc
18
19
  from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
@@ -23,20 +24,22 @@ from model_compression_toolkit.core.common.quantization.quantizers.quantizers_he
23
24
  from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
24
25
  get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
25
26
  from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
27
+ from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
28
+ from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor
26
29
 
27
30
 
28
31
  def power_of_two_selection_tensor(tensor_data: np.ndarray,
29
32
  p: int,
30
33
  n_bits: int,
31
34
  per_channel: bool = False,
32
- channel_axis: int = 1,
35
+ channel_axis: Union[int, None] = 1,
33
36
  n_iter: int = 10,
34
37
  min_threshold: float = MIN_THRESHOLD,
35
38
  quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
36
39
  node=None,
37
40
  hessian_info_service: HessianInfoService = None,
38
41
  num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
39
- ) -> dict:
42
+ ) -> Tuple[Dict[str, np.ndarray], int]:
40
43
  """
41
44
  Compute the power of two threshold based on the provided QuantizationErrorMethod to quantize the tensor.
42
45
  Different search is applied, depends on the value of the selected QuantizationErrorMethod.
@@ -46,7 +49,7 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
46
49
  p: p-norm to use for the Lp-norm distance.
47
50
  n_bits: Number of bits to quantize the tensor.
48
51
  per_channel: Whether the quantization should be per-channel or not.
49
- channel_axis: Output channel index.
52
+ channel_axis: Output channel index. if None, search for best axis.
50
53
  n_iter: Number of iterations to search for the optimal threshold (not used for this method).
51
54
  min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
52
55
  quant_error_method: an error function to optimize the parameters' selection accordingly.
@@ -56,11 +59,24 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
56
59
 
57
60
  Returns:
58
61
  Power of two threshold to quantize the tensor in a power of 2 manner.
62
+ Selected quantization channel axis.
59
63
  """
60
64
 
61
65
  if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
62
- tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
63
- threshold = max_power_of_two(tensor_max, min_threshold)
66
+ if channel_axis is None and per_channel:
67
+ total_error_list = []
68
+ th_list = []
69
+ for _axis in range(len(tensor_data.shape)):
70
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
71
+ threshold = max_power_of_two(tensor_max, min_threshold)
72
+ q_tensor_data = quantize_tensor(tensor_data, threshold, n_bits, True)
73
+ total_error_list.append(compute_mse(tensor_data, q_tensor_data, norm=True))
74
+ th_list.append(threshold)
75
+ channel_axis = np.argmin(total_error_list)
76
+ threshold = th_list[channel_axis]
77
+ else:
78
+ tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
79
+ threshold = max_power_of_two(tensor_max, min_threshold)
64
80
  else:
65
81
  signed = True # weights are always signed
66
82
  axis = -1 if per_channel else None
@@ -69,15 +85,15 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
69
85
  n_bits=n_bits, signed=signed, node=node,
70
86
  hessian_info_service=hessian_info_service,
71
87
  num_hessian_samples=num_hessian_samples)
72
- threshold = qparams_selection_tensor_search(error_function,
73
- tensor_data,
74
- n_bits,
75
- per_channel=per_channel,
76
- channel_axis=channel_axis,
77
- n_iter=n_iter,
78
- min_threshold=min_threshold,
79
- signed=signed)
80
- return {THRESHOLD: threshold}
88
+ threshold, channel_axis = qparams_selection_tensor_search(error_function,
89
+ tensor_data,
90
+ n_bits,
91
+ per_channel=per_channel,
92
+ channel_axis=channel_axis,
93
+ n_iter=n_iter,
94
+ min_threshold=min_threshold,
95
+ signed=signed)
96
+ return {THRESHOLD: threshold}, channel_axis
81
97
 
82
98
 
83
99
  def power_of_two_selection_histogram(bins: np.ndarray,
@@ -84,13 +84,14 @@ def calculate_quantization_params(graph: Graph,
84
84
  mod_attr_cfg = copy.deepcopy(attr_cfg)
85
85
  mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
86
86
 
87
- weights_params = get_weights_qparams(n.get_weights_by_keys(attr),
88
- candidate_qc.weights_quantization_cfg,
89
- mod_attr_cfg,
90
- output_channels_axis,
91
- node=n,
92
- hessian_info_service=hessian_info_service,
93
- num_hessian_samples=num_hessian_samples)
87
+ weights_params, output_channels_axis = get_weights_qparams(n.get_weights_by_keys(attr),
88
+ candidate_qc.weights_quantization_cfg,
89
+ mod_attr_cfg,
90
+ output_channels_axis,
91
+ node=n,
92
+ hessian_info_service=hessian_info_service,
93
+ num_hessian_samples=num_hessian_samples)
94
+ attr_cfg.weights_channels_axis = (output_channels_axis, attr_cfg.weights_channels_axis[1])
94
95
  attr_cfg.set_weights_quantization_param(weights_params)
95
96
 
96
97
  if n.is_activation_quantization_enabled():
@@ -27,7 +27,7 @@ from model_compression_toolkit.constants import MIN_THRESHOLD, DEFAULT_TOL, DEFA
27
27
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor, \
28
28
  reshape_tensor_for_per_channel_search, uniform_quantize_tensor, get_output_shape
29
29
  from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two, \
30
- get_tensor_max
30
+ get_tensor_max, get_tensor_min
31
31
 
32
32
 
33
33
  def qparams_selection_tensor_search(error_function: Callable,
@@ -56,41 +56,49 @@ def qparams_selection_tensor_search(error_function: Callable,
56
56
  signed: a flag whether the tensor is signed.
57
57
 
58
58
  Returns:
59
- Optimal constrained threshold to quantize the tensor.
59
+ Optimal constrained threshold to quantize the tensor, and best channel axis if input channel_axis was None,
60
+ else return the input channel axis.
60
61
 
61
62
  """
62
63
 
63
- output_shape = get_output_shape(tensor_data.shape, channel_axis)
64
-
65
- # First threshold to check is the constrained threshold based on the tensor's maximal value.
66
- tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
67
- threshold = 2 * max_power_of_two(tensor_max, min_threshold)
64
+ search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
65
+ total_error_list = []
66
+ th_list = []
67
+ for _axis in search_axes:
68
+ output_shape = get_output_shape(tensor_data.shape, _axis)
68
69
 
69
- # If the threshold is computed per-channel, we rearrange the tensor such that each sub-tensor
70
- # is flattened, and we iterate over each one of them when searching for the threshold.
71
- if per_channel:
72
- tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
70
+ # First threshold to check is the constrained threshold based on the tensor's maximal value.
71
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
72
+ threshold = 2 * max_power_of_two(tensor_max, min_threshold)
73
73
 
74
- error_list = [] # init an empty error list
75
- # On each iteration a new constrained threshold which equal to half of the previous tested threshold
76
- # is used for quantizing the tensor and computing the error. The error is appended to an error list, which
77
- # eventually used to select the threshold with the minimal error.
78
- for i in range(n_iter):
74
+ # Rearrange the tensor such that each sub-tensor is flattened, and we iterate over each
75
+ # one of them when searching for the threshold.
79
76
  if per_channel:
80
- threshold_hat = (threshold / (2 ** i)).reshape([-1, 1])
81
- qt = quantize_tensor(tensor_data_r, threshold_hat, n_bits, signed)
82
- per_channel_error = _error_function_wrapper(error_function, tensor_data_r, qt, threshold_hat)
83
-
84
- error_list.append(per_channel_error)
85
- else: # quantize per-tensor
86
- qt = quantize_tensor(tensor_data, threshold / (2 ** i), n_bits, signed)
87
- error = error_function(qt, tensor_data, threshold=threshold / (2 ** i))
88
- error_list.append(error)
89
-
90
- # Take the index of the minimal error, and use it compute the threshold which yielded it.
91
- i = np.argmin(np.stack(error_list, axis=-1), axis=-1)
92
-
93
- return np.maximum(np.reshape(threshold.flatten() / np.power(2, i), output_shape), min_threshold)
77
+ tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
78
+
79
+ error_list = [] # init an empty error list
80
+ # On each iteration a new constrained threshold which equal to half of the previous tested threshold
81
+ # is used for quantizing the tensor and computing the error. The error is appended to an error list, which
82
+ # eventually used to select the threshold with the minimal error.
83
+ for i in range(n_iter):
84
+ if per_channel:
85
+ threshold_hat = (threshold / (2 ** i)).reshape([-1, 1])
86
+ qt = quantize_tensor(tensor_data_r, threshold_hat, n_bits, signed)
87
+ per_channel_error = _error_function_wrapper(error_function, tensor_data_r, qt, threshold_hat)
88
+ error_list.append(per_channel_error)
89
+ else: # quantize per-tensor
90
+ qt = quantize_tensor(tensor_data, threshold / (2 ** i), n_bits, signed)
91
+ error = error_function(qt, tensor_data, threshold=threshold / (2 ** i))
92
+ error_list.append(error)
93
+
94
+ # Take the index of the minimal error, and use it compute the threshold which yielded it.
95
+ err_mat = np.stack(error_list, axis=-1)
96
+ i = np.argmin(err_mat, axis=-1)
97
+ th_list.append(np.maximum(np.reshape(threshold.flatten() / np.power(2, i), output_shape), min_threshold))
98
+ total_error_list.append(err_mat.min(axis=-1).mean())
99
+
100
+ best_axis_index = np.argmin(total_error_list)
101
+ return th_list[best_axis_index], search_axes[best_axis_index]
94
102
 
95
103
 
96
104
  def qparams_selection_histogram_search(error_function: Callable,
@@ -390,13 +398,12 @@ def search_dynamic_range(base_range: np.ndarray, x: np.ndarray, scalers: np.ndar
390
398
 
391
399
  def qparams_symmetric_selection_tensor_search(error_function: Callable,
392
400
  tensor_data: np.ndarray,
393
- tensor_max: np.ndarray,
394
401
  n_bits: int,
395
402
  per_channel: bool = False,
396
403
  channel_axis: int = 1,
397
404
  n_iter: int = SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
398
405
  min_threshold=MIN_THRESHOLD,
399
- signed: bool = True) -> Any:
406
+ signed: bool = True) -> Tuple[np.ndarray, int]:
400
407
  """
401
408
  Search for optimal threshold (per-channel or per-tensor) for symmetric quantization of a tensor,
402
409
  using the iterative optimizer method.
@@ -404,7 +411,6 @@ def qparams_symmetric_selection_tensor_search(error_function: Callable,
404
411
  Args:
405
412
  error_function: Function to compute the error between the original and quantized tensors.
406
413
  tensor_data: Numpy array with tensor's content.
407
- tensor_max: The max value of the tensor.
408
414
  n_bits: Number of bits to quantize the tensor.
409
415
  per_channel: Whether the tensor should be quantized per-channel or per-tensor.
410
416
  channel_axis: Index of output channels dimension.
@@ -417,46 +423,55 @@ def qparams_symmetric_selection_tensor_search(error_function: Callable,
417
423
 
418
424
  """
419
425
 
420
- output_shape = get_output_shape(tensor_data.shape, channel_axis)
426
+ search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
427
+ total_error_list = []
428
+ th_list = []
429
+ for _axis in search_axes:
430
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
431
+ output_shape = get_output_shape(tensor_data.shape, _axis)
421
432
 
422
- # If the threshold is computed per-channel, we rearrange the tensor such that each sub-tensor
423
- # is flattened, and we iterate over each one of them when searching for the threshold.
424
- if per_channel:
425
- tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
426
- max_tensor = np.maximum(min_threshold, tensor_max)
427
- res = qparams_symmetric_iterative_minimization(x0=max_tensor,
428
- x=tensor_data_r,
429
- loss_fn=error_function, # gets float_tensor, fxp_tensor, threshold
430
- n_bits=n_bits,
431
- signed=signed,
432
- n_intervals=SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS,
433
- n_iter=SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
434
- dec_freq=SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ,
435
- per_channel=True)
436
- return np.reshape(np.maximum(min_threshold, res['param']), output_shape)
437
- else:
438
- # quantize per-tensor
439
- res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max),
440
- x=tensor_data,
441
- loss_fn=error_function,
442
- n_bits=n_bits,
443
- signed=signed,
444
- n_intervals=SYMMETRIC_TENSOR_N_INTERVALS,
445
- n_iter=SYMMETRIC_TENSOR_N_ITER,
446
- dec_freq=SYMMETRIC_TENSOR_DEC_FREQ,
447
- per_channel=False)
448
-
449
- return max(min_threshold, res['param'])
433
+ if per_channel:
434
+ # Rearrange the tensor such that each sub-tensor is flattened, and we iterate
435
+ # over each one of them when searching for the threshold.
436
+ tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
437
+ max_tensor = np.maximum(min_threshold, tensor_max)
438
+ res = qparams_symmetric_iterative_minimization(x0=max_tensor,
439
+ x=tensor_data_r,
440
+ loss_fn=error_function, # gets float_tensor, fxp_tensor, threshold
441
+ n_bits=n_bits,
442
+ signed=signed,
443
+ n_intervals=SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS,
444
+ n_iter=SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
445
+ dec_freq=SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ,
446
+ per_channel=True)
447
+ th = np.reshape(np.maximum(min_threshold, res['param']), output_shape)
448
+ else:
449
+ # quantize per-tensor
450
+ res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max),
451
+ x=tensor_data,
452
+ loss_fn=error_function,
453
+ n_bits=n_bits,
454
+ signed=signed,
455
+ n_intervals=SYMMETRIC_TENSOR_N_INTERVALS,
456
+ n_iter=SYMMETRIC_TENSOR_N_ITER,
457
+ dec_freq=SYMMETRIC_TENSOR_DEC_FREQ,
458
+ per_channel=False)
459
+ th = max(min_threshold, res['param'])
460
+
461
+ total_error_list.append(res['loss'].mean())
462
+ th_list.append(th)
463
+
464
+ best_axis_index = np.argmin(total_error_list)
465
+ return th_list[best_axis_index], search_axes[best_axis_index]
450
466
 
451
467
 
452
468
  def qparams_uniform_selection_tensor_search(error_function: Callable,
453
469
  tensor_data: np.ndarray,
454
- tensor_min: np.ndarray,
455
- tensor_max: np.ndarray,
456
470
  n_bits: int,
457
471
  per_channel: bool = False,
458
472
  channel_axis: int = 1,
459
- n_iter: int = UNIFORM_TENSOR_PER_CHANNEL_N_ITER) -> Any:
473
+ n_iter: int = UNIFORM_TENSOR_PER_CHANNEL_N_ITER,
474
+ ) -> Tuple[Tuple[np.ndarray, np.ndarray], int]:
460
475
  """
461
476
  Search for optimal quantization range (per-channel or per-tensor) for uniform quantization of a tensor,
462
477
  using the iterative optimizer method and built-in scale factors
@@ -465,8 +480,6 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
465
480
  Args:
466
481
  error_function: Function to compute the error between the original and quantized tensors.
467
482
  tensor_data: Numpy array with tensor's content.
468
- tensor_min: The min value of the tensor.
469
- tensor_max: The max value of the tensor.
470
483
  n_bits: Number of bits to quantize the tensor.
471
484
  per_channel: Whether the tensor should be quantized per-channel or per-tensor.
472
485
  channel_axis: Index of output channels dimension.
@@ -477,17 +490,22 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
477
490
 
478
491
  """
479
492
 
480
- output_shape = get_output_shape(tensor_data.shape, channel_axis)
493
+ search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
494
+ total_error_list = []
495
+ th_list = []
496
+ for _axis in search_axes:
497
+ tensor_min = get_tensor_min(tensor_data, per_channel, _axis)
498
+ tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits, is_uniform_quantization=True)
499
+ output_shape = get_output_shape(tensor_data.shape, _axis)
481
500
 
482
- alpha = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
483
- beta = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
484
- scalers = np.asarray(list(itertools.product(alpha, beta)))
501
+ alpha = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
502
+ beta = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
503
+ scalers = np.asarray(list(itertools.product(alpha, beta)))
485
504
 
486
- # If the threshold is computed per-channel, we rearrange the tensor such that each sub-tensor
487
- # is flattened, and we iterate over each one of them when searching for the threshold.
488
- if per_channel:
505
+ # Rearrange the tensor such that each sub-tensor is flattened, and we iterate over
506
+ # each one of them when searching for the threshold.
489
507
  if per_channel:
490
- tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
508
+ tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
491
509
  tensor_min_max = np.column_stack([tensor_min.flatten(), tensor_max.flatten()])
492
510
  res = iterative_uniform_dynamic_range_search(x0=tensor_min_max,
493
511
  x=tensor_data_r,
@@ -496,18 +514,21 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
496
514
  n_bits=n_bits,
497
515
  n_iter=UNIFORM_TENSOR_PER_CHANNEL_N_ITER,
498
516
  per_channel=True)
499
- return np.reshape(res['param'][:, 0], output_shape), np.reshape(res['param'][:, 1], output_shape)
500
- else:
501
- # quantize per-tensor
502
- pass
503
- res = iterative_uniform_dynamic_range_search(x0=np.array([tensor_min, tensor_max]),
504
- x=tensor_data,
505
- scalers=scalers,
506
- loss_fn=error_function,
507
- n_bits=n_bits,
508
- n_iter=UNIFORM_TENSOR_N_ITER,
509
- per_channel=False)
510
- return res['param']
517
+ th_list.append((np.reshape(res['param'][:, 0], output_shape), np.reshape(res['param'][:, 1], output_shape)))
518
+ else:
519
+ # quantize per-tensor
520
+ res = iterative_uniform_dynamic_range_search(x0=np.array([tensor_min, tensor_max]),
521
+ x=tensor_data,
522
+ scalers=scalers,
523
+ loss_fn=error_function,
524
+ n_bits=n_bits,
525
+ n_iter=UNIFORM_TENSOR_N_ITER,
526
+ per_channel=False)
527
+ th_list.append(tuple(np.split(res['param'], 2)))
528
+ total_error_list.append(res['loss'].mean())
529
+
530
+ best_axis_index = np.argmin(total_error_list)
531
+ return th_list[best_axis_index], search_axes[best_axis_index]
511
532
 
512
533
 
513
534
  def qparams_symmetric_selection_histogram_search(error_function: Callable,