mct-nightly 2.1.0.20240608.434__py3-none-any.whl → 2.1.0.20240610.442__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/RECORD +26 -18
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +1 -4
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +10 -6
- model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py +15 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py +30 -14
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +8 -7
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py +108 -87
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +15 -13
- model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py +29 -14
- model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py +40 -14
- model_compression_toolkit/core/keras/reader/node_builder.py +3 -3
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +25 -23
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +10 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py +16 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +222 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +131 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +111 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py +16 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +219 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +131 -0
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +110 -0
- {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/top_level.txt +0 -0
{mct_nightly-2.1.0.20240608.434.dist-info → mct_nightly-2.1.0.20240610.442.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=8uKLxbPGI4bXEsOnz8snYp5aOCbWS0nIiBxD9ic580Y,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -31,7 +31,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
|
|
31
31
|
model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
|
32
32
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
33
33
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
|
34
|
-
model_compression_toolkit/core/common/graph/base_node.py,sha256=
|
34
|
+
model_compression_toolkit/core/common/graph/base_node.py,sha256=X_0zqHrKYAsmnj9tAKjVYasbFcZD8OHpjdiMj9ugQs0,29436
|
35
35
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
36
36
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=71_4TrCdqR_r0mtgxmAyqI05iP5YoQQGeSmDgynuzTw,3902
|
37
37
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
@@ -101,7 +101,7 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
101
101
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
|
102
102
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
|
103
103
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
|
104
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
104
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=u0JkdRqBXG0RvvYyLyvYknEVtB2-gxpqUJnUw3loLmE,26851
|
105
105
|
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=du0VdsxfkOSYaP1EU9gHA5qbXpfQNZL0jXrjk1wBA0U,7106
|
106
106
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=eyosbVdnCwed7oMQ19tqnh0VoyGZ_UAuD_UnNoXyBpo,2210
|
107
107
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
|
@@ -110,15 +110,15 @@ model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGN
|
|
110
110
|
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=O4qFJw3nBYUD4cGbO8haGXZ2-piSqoRpDKDD74iXSxw,12417
|
111
111
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
|
112
112
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=w367wmtJ7iWmM4_HlpX-YVUuqtYKrsiPP1oDaICIuK8,23308
|
113
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=
|
113
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=t0XSwjfOxcq2Sj2PGzccntz1GGv2eqVn9oR3OI0t9wo,8533
|
114
114
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
|
115
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=
|
115
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=HfnhQ4MxGpb95gOWXD1vnroTxxjFt9VFd4jIdo-rvAQ,10623
|
116
116
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=noEdvGiyyW7acgQ2OFWLedCODibTGYJifC9qo8YIU5U,4558
|
117
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=
|
118
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=
|
119
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=
|
120
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=
|
121
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=
|
117
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=E_XFTpYNUZ3JgOk_2qbUbmJH6qGqBM3TDsY4WptYup0,6478
|
118
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=o2XNY_0pUUyId02TUVQBtkux_i40NCcnzuobSeQLy3E,42863
|
119
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=zSNda0jN8cP41m6g5TOv5WvATwIhV8z6AVM1Es6rq1s,4419
|
120
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=4TP41wPYC0azIzFxUt-lNlKUPIIXQeE4H1SYHkON75k,11875
|
121
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/uniform_selection.py,sha256=E83BU4wZEOY-Q-HTo04ABftv22Y6fWEdNYkGA-MZLMU,10494
|
122
122
|
model_compression_toolkit/core/common/quantization/quantizers/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
123
123
|
model_compression_toolkit/core/common/quantization/quantizers/lut_kmeans_quantizer.py,sha256=P0x_y18LypBxP2tV9OWizheYfILqvaMC8RwHo04sUpQ,2761
|
124
124
|
model_compression_toolkit/core/common/quantization/quantizers/quantizers_helpers.py,sha256=CCFhi5LUIcHCCIzDyORvm0FDZLknrctdNwNlPphOQgI,14245
|
@@ -199,7 +199,7 @@ model_compression_toolkit/core/keras/quantizer/lut_fake_quant.py,sha256=Up3-sbuA
|
|
199
199
|
model_compression_toolkit/core/keras/reader/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
200
200
|
model_compression_toolkit/core/keras/reader/common.py,sha256=eZWjBcvTDUX7fCWmy1OAH4lYLFTh59_UQ_nP_Gjp4yw,2594
|
201
201
|
model_compression_toolkit/core/keras/reader/connectivity_handler.py,sha256=AgF6qXZOJMeXvc-pBnGY23BJz7wPBx2aTYxHiO8efec,11303
|
202
|
-
model_compression_toolkit/core/keras/reader/node_builder.py,sha256=
|
202
|
+
model_compression_toolkit/core/keras/reader/node_builder.py,sha256=SAPkgL8aqJjnB6eCucU2D4m50WACCzWC8wjCVtFnwp8,10424
|
203
203
|
model_compression_toolkit/core/keras/reader/reader.py,sha256=wS9UQ2wJKnkZYe9JHwQp7ygDr6CRlzrxmIyLDv1Qz6U,8109
|
204
204
|
model_compression_toolkit/core/keras/reader/nested_model/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
205
205
|
model_compression_toolkit/core/keras/reader/nested_model/edges_merger.py,sha256=K6KAH9o8KSG6baLmhKoCrYK-i-wb6gRKiZmoijFqEYA,7906
|
@@ -222,7 +222,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
|
|
222
222
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
223
223
|
model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
|
224
224
|
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
|
225
|
-
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=
|
225
|
+
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Zw4gi-wjJNV8-qGv79YBWVAHmy27f7iW0c2JGNWAKD0,18199
|
226
226
|
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
|
227
227
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
228
228
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
|
@@ -431,7 +431,7 @@ model_compression_toolkit/target_platform_capabilities/target_platform/targetpla
|
|
431
431
|
model_compression_toolkit/target_platform_capabilities/tpc_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
432
432
|
model_compression_toolkit/target_platform_capabilities/tpc_models/get_target_platform_capabilities.py,sha256=-jCL-meZWFBF-Dp9wBYTX_14SKmyyUJE-BZ2IQDJIAk,3336
|
433
433
|
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
434
|
-
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py,sha256=
|
434
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py,sha256=mjPFr6Z-PLzqQta8mW7dK31mbbBZsJo4MdpJQmxlSt4,4640
|
435
435
|
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py,sha256=F5RG4MnuAwKcNXbfVbPFLQu30-lNax-7knqu20B6udQ,1522
|
436
436
|
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/__init__.py,sha256=1mMOREEMoNHu_KTMGDp4crN61opKWX6aFn1DrDLvqcc,717
|
437
437
|
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py,sha256=S-GwMI-JiuPpbtOdd6TSOEjiUFiIs6M2RAiJNJ3O950,10883
|
@@ -453,6 +453,14 @@ model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_
|
|
453
453
|
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py,sha256=dmi2lCT0dw6RnWVw73tcnqgsVSgINSWaIWfgZhEli4Q,10691
|
454
454
|
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py,sha256=6PVKQKGpJpM2B1qvmf6fID_-MACaSQZkaL_9J_fj2SQ,6595
|
455
455
|
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py,sha256=dFQjzFlLDwoUqKNP1at1fS1N1WJadSSasRyzHl6vaB8,5733
|
456
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/__init__.py,sha256=gAeebYCKyIXH9-Qwze7FwvTihudzAHk_Qsg94fQbkjQ,717
|
457
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py,sha256=edMH4lM7Bq7FaPAFZLU5UMX-bWSWiaaAIXnQE7lZ7rI,11844
|
458
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py,sha256=T5YMv-RzgYlzBaagnMO7WnKgbZ7PrOvm29Nn4vUhCHI,6587
|
459
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py,sha256=-q6Tnn7diPCCoATmLDzJwWwviQcbMMISqgpLu2n42JY,5726
|
460
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/__init__.py,sha256=C2kwyDE1-rtukkbNSoKRv9q8Nt2GOCaBbl0BdOr3goA,721
|
461
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py,sha256=HoGjDwoSx2Y4dQua5v1qzzlnSl_HfDMK6bGWuZhPOzQ,11577
|
462
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py,sha256=LvqUkvpJKXBb9QETcHsmp9OGDwl9KWr457deag8GVuM,6595
|
463
|
+
model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py,sha256=4Y2D14rE0SnWIkBTYsVqCryB-gkHU1ZlbdkWF864mPU,5733
|
456
464
|
model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
457
465
|
model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py,sha256=7KVcuz0LfngRKOsfcvBysxGVb9fqgoAO6MVTl1CmB5c,2082
|
458
466
|
model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py,sha256=UUvUCcTots_sehdRnDfgkaE8WPQ7dPbeuhDF4Qy2nzw,1510
|
@@ -483,8 +491,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
483
491
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
484
492
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
485
493
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
|
486
|
-
mct_nightly-2.1.0.
|
487
|
-
mct_nightly-2.1.0.
|
488
|
-
mct_nightly-2.1.0.
|
489
|
-
mct_nightly-2.1.0.
|
490
|
-
mct_nightly-2.1.0.
|
494
|
+
mct_nightly-2.1.0.20240610.442.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
495
|
+
mct_nightly-2.1.0.20240610.442.dist-info/METADATA,sha256=Juo23o8F4ndhmb8TksZ99xKWtks0DK59daxJqx_9RmI,19721
|
496
|
+
mct_nightly-2.1.0.20240610.442.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
497
|
+
mct_nightly-2.1.0.20240610.442.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
498
|
+
mct_nightly-2.1.0.20240610.442.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.1.0.
|
30
|
+
__version__ = "2.1.0.20240610.000442"
|
@@ -240,10 +240,7 @@ class BaseNode:
|
|
240
240
|
if isinstance(pos, int)):
|
241
241
|
if pos > len(input_tensors):
|
242
242
|
Logger.critical("The positional weight index cannot exceed the number of input tensors to the node.") # pragma: no cover
|
243
|
-
|
244
|
-
# subject to quantization, the quantization wrapper inserts the positional weight into the node.
|
245
|
-
if not self.is_weights_quantization_enabled(pos):
|
246
|
-
input_tensors.insert(pos, weight)
|
243
|
+
input_tensors.insert(pos, weight)
|
247
244
|
|
248
245
|
return input_tensors
|
249
246
|
|
@@ -326,13 +326,17 @@ class WeightsAttrQuantizationConfig:
|
|
326
326
|
|
327
327
|
"""
|
328
328
|
assert self.enable_weights_quantization
|
329
|
+
assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
|
330
|
+
"Trying to calculate threshold per channel, channel axis in None."
|
329
331
|
if self.weights_quantization_params_fn is not None:
|
330
|
-
self.set_weights_quantization_param(
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
332
|
+
self.set_weights_quantization_param(
|
333
|
+
self.weights_quantization_params_fn(tensor_data,
|
334
|
+
p=self.l_p_value,
|
335
|
+
n_bits=self.weights_n_bits,
|
336
|
+
per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
|
337
|
+
channel_axis=self.weights_channels_axis[0], # output channel axis
|
338
|
+
min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
|
339
|
+
)
|
336
340
|
else:
|
337
341
|
self.set_weights_quantization_param({})
|
338
342
|
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from typing import Dict
|
16
|
+
from typing import Dict, Tuple
|
17
17
|
import numpy as np
|
18
18
|
from sklearn.cluster import KMeans
|
19
19
|
|
@@ -42,7 +42,8 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
|
|
42
42
|
is_symmetric: bool = False,
|
43
43
|
node=None,
|
44
44
|
hessian_info_service: HessianInfoService = None,
|
45
|
-
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES
|
45
|
+
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
|
46
|
+
) -> Tuple[Dict[str, np.ndarray], int]:
|
46
47
|
"""
|
47
48
|
The quantizer first finds the closest max value per channel of tensor_data.
|
48
49
|
Now, we divide tensor_data with the threshold vector per channel. In addition, we scale the result to the range
|
@@ -70,27 +71,34 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
|
|
70
71
|
if n_bits >= LUT_VALUES_BITWIDTH:
|
71
72
|
Logger.critical(f'Look-Up-Table (LUT) bit configuration exceeds maximum: {n_bits} bits provided, must be less than {LUT_VALUES_BITWIDTH} bits.') # pragma: no cover
|
72
73
|
# TODO: need to set this externally
|
74
|
+
n_data_points = len(np.unique(tensor_data.flatten()))
|
73
75
|
if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
|
74
|
-
n_clusters =
|
76
|
+
n_clusters = n_data_points
|
75
77
|
else:
|
76
78
|
n_clusters = 2 ** n_bits
|
77
79
|
kmeans = KMeans(n_clusters=n_clusters, n_init=10)
|
78
80
|
|
79
81
|
threshold_selection_tensor = symmetric_selection_tensor if is_symmetric else power_of_two_selection_tensor
|
80
|
-
|
81
|
-
|
82
|
-
|
82
|
+
|
83
|
+
_params, channel_axis = threshold_selection_tensor(tensor_data, p, n_bits, per_channel,
|
84
|
+
channel_axis, n_iter, min_threshold,
|
85
|
+
qc.QuantizationErrorMethod.NOCLIPPING)
|
86
|
+
thresholds_per_channel = _params[THRESHOLD]
|
83
87
|
|
84
88
|
tensor_for_kmeans = int_quantization_with_threshold(tensor_data, thresholds_per_channel, LUT_VALUES_BITWIDTH)
|
85
89
|
kmeans.fit(tensor_for_kmeans.reshape(-1, 1))
|
86
90
|
|
87
91
|
# Add 0 to the LUT
|
88
92
|
cc = np.round(kmeans.cluster_centers_)
|
93
|
+
if n_data_points < 2 ** n_bits and np.all(cc != 0):
|
94
|
+
# In case there are fewer data points than potential clusters, we can add the cluster 0.0
|
95
|
+
# to the original clusters array to improve quantization (i.e. no need to zero one of the clusters).
|
96
|
+
cc = np.concatenate([np.zeros([1, 1], dtype=cc.dtype), cc])
|
89
97
|
closest2zero_idx = (np.abs(cc - 0)).argmin()
|
90
98
|
cc[closest2zero_idx] = 0.0
|
91
99
|
|
92
100
|
return {LUT_VALUES: cc,
|
93
|
-
SCALE_PER_CHANNEL: thresholds_per_channel}
|
101
|
+
SCALE_PER_CHANNEL: thresholds_per_channel}, channel_axis
|
94
102
|
|
95
103
|
|
96
104
|
def lut_kmeans_histogram(bins: np.ndarray,
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import numpy as np
|
16
|
+
from typing import Union, Tuple, Dict
|
16
17
|
|
17
18
|
import model_compression_toolkit.core.common.quantization.quantization_config as qc
|
18
19
|
from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
|
@@ -23,20 +24,22 @@ from model_compression_toolkit.core.common.quantization.quantizers.quantizers_he
|
|
23
24
|
from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
|
24
25
|
get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
|
25
26
|
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
|
27
|
+
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
|
28
|
+
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor
|
26
29
|
|
27
30
|
|
28
31
|
def power_of_two_selection_tensor(tensor_data: np.ndarray,
|
29
32
|
p: int,
|
30
33
|
n_bits: int,
|
31
34
|
per_channel: bool = False,
|
32
|
-
channel_axis: int = 1,
|
35
|
+
channel_axis: Union[int, None] = 1,
|
33
36
|
n_iter: int = 10,
|
34
37
|
min_threshold: float = MIN_THRESHOLD,
|
35
38
|
quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
|
36
39
|
node=None,
|
37
40
|
hessian_info_service: HessianInfoService = None,
|
38
41
|
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
|
39
|
-
) ->
|
42
|
+
) -> Tuple[Dict[str, np.ndarray], int]:
|
40
43
|
"""
|
41
44
|
Compute the power of two threshold based on the provided QuantizationErrorMethod to quantize the tensor.
|
42
45
|
Different search is applied, depends on the value of the selected QuantizationErrorMethod.
|
@@ -46,7 +49,7 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
|
|
46
49
|
p: p-norm to use for the Lp-norm distance.
|
47
50
|
n_bits: Number of bits to quantize the tensor.
|
48
51
|
per_channel: Whether the quantization should be per-channel or not.
|
49
|
-
channel_axis: Output channel index.
|
52
|
+
channel_axis: Output channel index. if None, search for best axis.
|
50
53
|
n_iter: Number of iterations to search for the optimal threshold (not used for this method).
|
51
54
|
min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
|
52
55
|
quant_error_method: an error function to optimize the parameters' selection accordingly.
|
@@ -56,11 +59,24 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
|
|
56
59
|
|
57
60
|
Returns:
|
58
61
|
Power of two threshold to quantize the tensor in a power of 2 manner.
|
62
|
+
Selected quantization channel axis.
|
59
63
|
"""
|
60
64
|
|
61
65
|
if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
|
62
|
-
|
63
|
-
|
66
|
+
if channel_axis is None and per_channel:
|
67
|
+
total_error_list = []
|
68
|
+
th_list = []
|
69
|
+
for _axis in range(len(tensor_data.shape)):
|
70
|
+
tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
|
71
|
+
threshold = max_power_of_two(tensor_max, min_threshold)
|
72
|
+
q_tensor_data = quantize_tensor(tensor_data, threshold, n_bits, True)
|
73
|
+
total_error_list.append(compute_mse(tensor_data, q_tensor_data, norm=True))
|
74
|
+
th_list.append(threshold)
|
75
|
+
channel_axis = np.argmin(total_error_list)
|
76
|
+
threshold = th_list[channel_axis]
|
77
|
+
else:
|
78
|
+
tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
|
79
|
+
threshold = max_power_of_two(tensor_max, min_threshold)
|
64
80
|
else:
|
65
81
|
signed = True # weights are always signed
|
66
82
|
axis = -1 if per_channel else None
|
@@ -69,15 +85,15 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
|
|
69
85
|
n_bits=n_bits, signed=signed, node=node,
|
70
86
|
hessian_info_service=hessian_info_service,
|
71
87
|
num_hessian_samples=num_hessian_samples)
|
72
|
-
threshold = qparams_selection_tensor_search(error_function,
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
return {THRESHOLD: threshold}
|
88
|
+
threshold, channel_axis = qparams_selection_tensor_search(error_function,
|
89
|
+
tensor_data,
|
90
|
+
n_bits,
|
91
|
+
per_channel=per_channel,
|
92
|
+
channel_axis=channel_axis,
|
93
|
+
n_iter=n_iter,
|
94
|
+
min_threshold=min_threshold,
|
95
|
+
signed=signed)
|
96
|
+
return {THRESHOLD: threshold}, channel_axis
|
81
97
|
|
82
98
|
|
83
99
|
def power_of_two_selection_histogram(bins: np.ndarray,
|
@@ -84,13 +84,14 @@ def calculate_quantization_params(graph: Graph,
|
|
84
84
|
mod_attr_cfg = copy.deepcopy(attr_cfg)
|
85
85
|
mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE
|
86
86
|
|
87
|
-
weights_params = get_weights_qparams(n.get_weights_by_keys(attr),
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
87
|
+
weights_params, output_channels_axis = get_weights_qparams(n.get_weights_by_keys(attr),
|
88
|
+
candidate_qc.weights_quantization_cfg,
|
89
|
+
mod_attr_cfg,
|
90
|
+
output_channels_axis,
|
91
|
+
node=n,
|
92
|
+
hessian_info_service=hessian_info_service,
|
93
|
+
num_hessian_samples=num_hessian_samples)
|
94
|
+
attr_cfg.weights_channels_axis = (output_channels_axis, attr_cfg.weights_channels_axis[1])
|
94
95
|
attr_cfg.set_weights_quantization_param(weights_params)
|
95
96
|
|
96
97
|
if n.is_activation_quantization_enabled():
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py
CHANGED
@@ -27,7 +27,7 @@ from model_compression_toolkit.constants import MIN_THRESHOLD, DEFAULT_TOL, DEFA
|
|
27
27
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor, \
|
28
28
|
reshape_tensor_for_per_channel_search, uniform_quantize_tensor, get_output_shape
|
29
29
|
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import max_power_of_two, \
|
30
|
-
get_tensor_max
|
30
|
+
get_tensor_max, get_tensor_min
|
31
31
|
|
32
32
|
|
33
33
|
def qparams_selection_tensor_search(error_function: Callable,
|
@@ -56,41 +56,49 @@ def qparams_selection_tensor_search(error_function: Callable,
|
|
56
56
|
signed: a flag whether the tensor is signed.
|
57
57
|
|
58
58
|
Returns:
|
59
|
-
Optimal constrained threshold to quantize the tensor
|
59
|
+
Optimal constrained threshold to quantize the tensor, and best channel axis if input channel_axis was None,
|
60
|
+
else return the input channel axis.
|
60
61
|
|
61
62
|
"""
|
62
63
|
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
64
|
+
search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
|
65
|
+
total_error_list = []
|
66
|
+
th_list = []
|
67
|
+
for _axis in search_axes:
|
68
|
+
output_shape = get_output_shape(tensor_data.shape, _axis)
|
68
69
|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, channel_axis)
|
70
|
+
# First threshold to check is the constrained threshold based on the tensor's maximal value.
|
71
|
+
tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
|
72
|
+
threshold = 2 * max_power_of_two(tensor_max, min_threshold)
|
73
73
|
|
74
|
-
|
75
|
-
|
76
|
-
# is used for quantizing the tensor and computing the error. The error is appended to an error list, which
|
77
|
-
# eventually used to select the threshold with the minimal error.
|
78
|
-
for i in range(n_iter):
|
74
|
+
# Rearrange the tensor such that each sub-tensor is flattened, and we iterate over each
|
75
|
+
# one of them when searching for the threshold.
|
79
76
|
if per_channel:
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
77
|
+
tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
|
78
|
+
|
79
|
+
error_list = [] # init an empty error list
|
80
|
+
# On each iteration a new constrained threshold which equal to half of the previous tested threshold
|
81
|
+
# is used for quantizing the tensor and computing the error. The error is appended to an error list, which
|
82
|
+
# eventually used to select the threshold with the minimal error.
|
83
|
+
for i in range(n_iter):
|
84
|
+
if per_channel:
|
85
|
+
threshold_hat = (threshold / (2 ** i)).reshape([-1, 1])
|
86
|
+
qt = quantize_tensor(tensor_data_r, threshold_hat, n_bits, signed)
|
87
|
+
per_channel_error = _error_function_wrapper(error_function, tensor_data_r, qt, threshold_hat)
|
88
|
+
error_list.append(per_channel_error)
|
89
|
+
else: # quantize per-tensor
|
90
|
+
qt = quantize_tensor(tensor_data, threshold / (2 ** i), n_bits, signed)
|
91
|
+
error = error_function(qt, tensor_data, threshold=threshold / (2 ** i))
|
92
|
+
error_list.append(error)
|
93
|
+
|
94
|
+
# Take the index of the minimal error, and use it compute the threshold which yielded it.
|
95
|
+
err_mat = np.stack(error_list, axis=-1)
|
96
|
+
i = np.argmin(err_mat, axis=-1)
|
97
|
+
th_list.append(np.maximum(np.reshape(threshold.flatten() / np.power(2, i), output_shape), min_threshold))
|
98
|
+
total_error_list.append(err_mat.min(axis=-1).mean())
|
99
|
+
|
100
|
+
best_axis_index = np.argmin(total_error_list)
|
101
|
+
return th_list[best_axis_index], search_axes[best_axis_index]
|
94
102
|
|
95
103
|
|
96
104
|
def qparams_selection_histogram_search(error_function: Callable,
|
@@ -390,13 +398,12 @@ def search_dynamic_range(base_range: np.ndarray, x: np.ndarray, scalers: np.ndar
|
|
390
398
|
|
391
399
|
def qparams_symmetric_selection_tensor_search(error_function: Callable,
|
392
400
|
tensor_data: np.ndarray,
|
393
|
-
tensor_max: np.ndarray,
|
394
401
|
n_bits: int,
|
395
402
|
per_channel: bool = False,
|
396
403
|
channel_axis: int = 1,
|
397
404
|
n_iter: int = SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
|
398
405
|
min_threshold=MIN_THRESHOLD,
|
399
|
-
signed: bool = True) ->
|
406
|
+
signed: bool = True) -> Tuple[np.ndarray, int]:
|
400
407
|
"""
|
401
408
|
Search for optimal threshold (per-channel or per-tensor) for symmetric quantization of a tensor,
|
402
409
|
using the iterative optimizer method.
|
@@ -404,7 +411,6 @@ def qparams_symmetric_selection_tensor_search(error_function: Callable,
|
|
404
411
|
Args:
|
405
412
|
error_function: Function to compute the error between the original and quantized tensors.
|
406
413
|
tensor_data: Numpy array with tensor's content.
|
407
|
-
tensor_max: The max value of the tensor.
|
408
414
|
n_bits: Number of bits to quantize the tensor.
|
409
415
|
per_channel: Whether the tensor should be quantized per-channel or per-tensor.
|
410
416
|
channel_axis: Index of output channels dimension.
|
@@ -417,46 +423,55 @@ def qparams_symmetric_selection_tensor_search(error_function: Callable,
|
|
417
423
|
|
418
424
|
"""
|
419
425
|
|
420
|
-
|
426
|
+
search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
|
427
|
+
total_error_list = []
|
428
|
+
th_list = []
|
429
|
+
for _axis in search_axes:
|
430
|
+
tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
|
431
|
+
output_shape = get_output_shape(tensor_data.shape, _axis)
|
421
432
|
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
433
|
+
if per_channel:
|
434
|
+
# Rearrange the tensor such that each sub-tensor is flattened, and we iterate
|
435
|
+
# over each one of them when searching for the threshold.
|
436
|
+
tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
|
437
|
+
max_tensor = np.maximum(min_threshold, tensor_max)
|
438
|
+
res = qparams_symmetric_iterative_minimization(x0=max_tensor,
|
439
|
+
x=tensor_data_r,
|
440
|
+
loss_fn=error_function, # gets float_tensor, fxp_tensor, threshold
|
441
|
+
n_bits=n_bits,
|
442
|
+
signed=signed,
|
443
|
+
n_intervals=SYMMETRIC_TENSOR_PER_CHANNEL_N_INTERVALS,
|
444
|
+
n_iter=SYMMETRIC_TENSOR_PER_CHANNEL_N_ITER,
|
445
|
+
dec_freq=SYMMETRIC_TENSOR_PER_CHANNEL_DEC_FREQ,
|
446
|
+
per_channel=True)
|
447
|
+
th = np.reshape(np.maximum(min_threshold, res['param']), output_shape)
|
448
|
+
else:
|
449
|
+
# quantize per-tensor
|
450
|
+
res = qparams_symmetric_iterative_minimization(x0=get_init_threshold(min_threshold, tensor_max),
|
451
|
+
x=tensor_data,
|
452
|
+
loss_fn=error_function,
|
453
|
+
n_bits=n_bits,
|
454
|
+
signed=signed,
|
455
|
+
n_intervals=SYMMETRIC_TENSOR_N_INTERVALS,
|
456
|
+
n_iter=SYMMETRIC_TENSOR_N_ITER,
|
457
|
+
dec_freq=SYMMETRIC_TENSOR_DEC_FREQ,
|
458
|
+
per_channel=False)
|
459
|
+
th = max(min_threshold, res['param'])
|
460
|
+
|
461
|
+
total_error_list.append(res['loss'].mean())
|
462
|
+
th_list.append(th)
|
463
|
+
|
464
|
+
best_axis_index = np.argmin(total_error_list)
|
465
|
+
return th_list[best_axis_index], search_axes[best_axis_index]
|
450
466
|
|
451
467
|
|
452
468
|
def qparams_uniform_selection_tensor_search(error_function: Callable,
|
453
469
|
tensor_data: np.ndarray,
|
454
|
-
tensor_min: np.ndarray,
|
455
|
-
tensor_max: np.ndarray,
|
456
470
|
n_bits: int,
|
457
471
|
per_channel: bool = False,
|
458
472
|
channel_axis: int = 1,
|
459
|
-
n_iter: int = UNIFORM_TENSOR_PER_CHANNEL_N_ITER
|
473
|
+
n_iter: int = UNIFORM_TENSOR_PER_CHANNEL_N_ITER,
|
474
|
+
) -> Tuple[Tuple[np.ndarray, np.ndarray], int]:
|
460
475
|
"""
|
461
476
|
Search for optimal quantization range (per-channel or per-tensor) for uniform quantization of a tensor,
|
462
477
|
using the iterative optimizer method and built-in scale factors
|
@@ -465,8 +480,6 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
|
|
465
480
|
Args:
|
466
481
|
error_function: Function to compute the error between the original and quantized tensors.
|
467
482
|
tensor_data: Numpy array with tensor's content.
|
468
|
-
tensor_min: The min value of the tensor.
|
469
|
-
tensor_max: The max value of the tensor.
|
470
483
|
n_bits: Number of bits to quantize the tensor.
|
471
484
|
per_channel: Whether the tensor should be quantized per-channel or per-tensor.
|
472
485
|
channel_axis: Index of output channels dimension.
|
@@ -477,17 +490,22 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
|
|
477
490
|
|
478
491
|
"""
|
479
492
|
|
480
|
-
|
493
|
+
search_axes = range(len(tensor_data.shape)) if channel_axis is None and per_channel else [channel_axis]
|
494
|
+
total_error_list = []
|
495
|
+
th_list = []
|
496
|
+
for _axis in search_axes:
|
497
|
+
tensor_min = get_tensor_min(tensor_data, per_channel, _axis)
|
498
|
+
tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits, is_uniform_quantization=True)
|
499
|
+
output_shape = get_output_shape(tensor_data.shape, _axis)
|
481
500
|
|
482
|
-
|
483
|
-
|
484
|
-
|
501
|
+
alpha = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
|
502
|
+
beta = np.linspace(BOTTOM_FACTOR, UPPER_FACTOR, UNIFORM_TENSOR_N_SAMPLES)
|
503
|
+
scalers = np.asarray(list(itertools.product(alpha, beta)))
|
485
504
|
|
486
|
-
|
487
|
-
|
488
|
-
if per_channel:
|
505
|
+
# Rearrange the tensor such that each sub-tensor is flattened, and we iterate over
|
506
|
+
# each one of them when searching for the threshold.
|
489
507
|
if per_channel:
|
490
|
-
tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data,
|
508
|
+
tensor_data_r = reshape_tensor_for_per_channel_search(tensor_data, _axis)
|
491
509
|
tensor_min_max = np.column_stack([tensor_min.flatten(), tensor_max.flatten()])
|
492
510
|
res = iterative_uniform_dynamic_range_search(x0=tensor_min_max,
|
493
511
|
x=tensor_data_r,
|
@@ -496,18 +514,21 @@ def qparams_uniform_selection_tensor_search(error_function: Callable,
|
|
496
514
|
n_bits=n_bits,
|
497
515
|
n_iter=UNIFORM_TENSOR_PER_CHANNEL_N_ITER,
|
498
516
|
per_channel=True)
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
517
|
+
th_list.append((np.reshape(res['param'][:, 0], output_shape), np.reshape(res['param'][:, 1], output_shape)))
|
518
|
+
else:
|
519
|
+
# quantize per-tensor
|
520
|
+
res = iterative_uniform_dynamic_range_search(x0=np.array([tensor_min, tensor_max]),
|
521
|
+
x=tensor_data,
|
522
|
+
scalers=scalers,
|
523
|
+
loss_fn=error_function,
|
524
|
+
n_bits=n_bits,
|
525
|
+
n_iter=UNIFORM_TENSOR_N_ITER,
|
526
|
+
per_channel=False)
|
527
|
+
th_list.append(tuple(np.split(res['param'], 2)))
|
528
|
+
total_error_list.append(res['loss'].mean())
|
529
|
+
|
530
|
+
best_axis_index = np.argmin(total_error_list)
|
531
|
+
return th_list[best_axis_index], search_axes[best_axis_index]
|
511
532
|
|
512
533
|
|
513
534
|
def qparams_symmetric_selection_histogram_search(error_function: Callable,
|