mct-nightly 1.10.0.20231017.post414__py3-none-any.whl → 1.10.0.20231019.post424__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-1.10.0.20231017.post414.dist-info → mct_nightly-1.10.0.20231019.post424.dist-info}/METADATA +1 -1
- {mct_nightly-1.10.0.20231017.post414.dist-info → mct_nightly-1.10.0.20231019.post424.dist-info}/RECORD +22 -15
- model_compression_toolkit/core/common/framework_implementation.py +0 -12
- model_compression_toolkit/core/common/hessian/hessian_info_service.py +17 -1
- model_compression_toolkit/core/keras/constants.py +7 -0
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +127 -0
- model_compression_toolkit/core/keras/keras_implementation.py +3 -17
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +0 -15
- model_compression_toolkit/qat/common/qat_config.py +4 -1
- model_compression_toolkit/qat/keras/quantizer/__init__.py +2 -0
- model_compression_toolkit/qat/keras/quantizer/lsq/__init__.py +14 -0
- model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py +254 -0
- model_compression_toolkit/qat/keras/quantizer/lsq/uniform_lsq.py +250 -0
- model_compression_toolkit/qat/keras/quantizer/quant_utils.py +17 -0
- model_compression_toolkit/qat/pytorch/quantizer/__init__.py +3 -1
- model_compression_toolkit/qat/pytorch/quantizer/lsq/__init__.py +14 -0
- model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py +228 -0
- model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py +223 -0
- model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py +17 -4
- {mct_nightly-1.10.0.20231017.post414.dist-info → mct_nightly-1.10.0.20231019.post424.dist-info}/LICENSE.md +0 -0
- {mct_nightly-1.10.0.20231017.post414.dist-info → mct_nightly-1.10.0.20231019.post424.dist-info}/WHEEL +0 -0
- {mct_nightly-1.10.0.20231017.post414.dist-info → mct_nightly-1.10.0.20231019.post424.dist-info}/top_level.txt +0 -0
{mct_nightly-1.10.0.20231017.post414.dist-info → mct_nightly-1.10.0.20231019.post424.dist-info}/RECORD

@@ -10,7 +10,7 @@ model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71E
 model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
 model_compression_toolkit/core/common/data_loader.py,sha256=7YF5Mqz64Xb4rVwY3knrdIZ4JEHybXxiQqx0deR_c5k,4017
 model_compression_toolkit/core/common/defaultdict.py,sha256=n-F3dP-VTMnGy9KfCwp7D_WBlvFxe3waX4LpnOX8FH0,2281
-model_compression_toolkit/core/common/framework_implementation.py,sha256=
+model_compression_toolkit/core/common/framework_implementation.py,sha256=yFWOic8eEq1A0_IhgPAAMFPfMbm0XBlsoz8-PYYp2o4,20683
 model_compression_toolkit/core/common/framework_info.py,sha256=hwmstv7IuBRfa6IxDbeG4y-7AxKx4bwCyI_Exi2C7mo,6424
 model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
 model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
@@ -46,7 +46,7 @@ model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py,sha256
 model_compression_toolkit/core/common/graph/memory_graph/memory_element.py,sha256=gRmBEFRmyJsNKezQfiwDwQu1cmbGd2wgKCRTH6iw8mw,3961
 model_compression_toolkit/core/common/graph/memory_graph/memory_graph.py,sha256=gw4av_rzn_3oEAPpD3B7PHZDqnxHMjIESevl6ppPnkk,7175
 model_compression_toolkit/core/common/hessian/__init__.py,sha256=bxPVbkIlHFJMiOgTdWMVCqcD9JKV5kb2bVdWUTeLpj8,1021
-model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=
+model_compression_toolkit/core/common/hessian/hessian_info_service.py,sha256=FJlNhYaxI_3X0aG5wXSmdNsKYZm7paiSchi1TJpZQiA,8180
 model_compression_toolkit/core/common/hessian/hessian_info_utils.py,sha256=zHEcql5z4VGuvEx3oI7OOD5FeWgwuN4xpf2H01O8kY4,3324
 model_compression_toolkit/core/common/hessian/trace_hessian_calculator.py,sha256=08PapRDOgfHybDsNW5iBEEvNFKsWraPnBBlsUfcwdRE,3600
 model_compression_toolkit/core/common/hessian/trace_hessian_request.py,sha256=sPOQL7IUTrd7QlCFuDPgL6RhMkBPMl7Fg81-WhouvQY,2705
@@ -133,10 +133,10 @@ model_compression_toolkit/core/common/visualization/final_config_visualizer.py,s
 model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=6EjZj_KE1tICTQ0XSKIx5ivsRFpRktFywda7pW7YnNQ,5955
 model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=1Vbr5gKqH1fJl91iTNFlIjyEMh6jm88T4AIWalMrJFw,20099
 model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
-model_compression_toolkit/core/keras/constants.py,sha256=
+model_compression_toolkit/core/keras/constants.py,sha256=0vQZ3-8-IJ735KCet858OcxlDRFx3GyDks97GBF9gS4,2968
 model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
 model_compression_toolkit/core/keras/default_framework_info.py,sha256=HLLO28tsbb9tHYQ05grUr3oJTRKdS520QnyGUYgzRK0,4994
-model_compression_toolkit/core/keras/keras_implementation.py,sha256
+model_compression_toolkit/core/keras/keras_implementation.py,sha256=icZxFqA-f3afiaIfHYgTrcxS26w7BaJKzHWVx-oYwAU,28370
 model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
 model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=f6o5Fmpw0aDrO704_A-SqBrKSO1iNEOyofP9pm3g8yg,3936
 model_compression_toolkit/core/keras/kpi_data_facade.py,sha256=rArrfMtxWGR1P4nhKKxqh6fo7pauRDzkRsZIh_SXxO4,8502
@@ -154,6 +154,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/activatio
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py,sha256=5ZMQLGs5Tc11WUIuAnxiOihVpwqT2bijezPGGtaOgjA,8145
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=GR1a3mCZpNUu4WxixJXF_aSm57phAdxaRoHecNx3hxw,3168
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
+model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=Yj59BMBrITJnXJHH-7de91LJwH_1l1WhY1udSQjdoi4,5598
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=Yl72Io4-etnsOXfMKAZmC2lDzmskxZu5gey7IBcUukU,5925
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/multi_head_attention_decomposition.py,sha256=aAG5wrcnnydn1pPYqvH56LWsQXjSODbsoNbX_jtQGP4,26759
@@ -194,7 +195,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
 model_compression_toolkit/core/pytorch/constants.py,sha256=Kt_GDwe3yX9oMS1DI2eXYuUT25_lpjeCkxpstsAiXCI,2472
 model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=n0efDuwJUR-BHldZEGQu5bpb3XIR7QYt1HI5tcmz--c,4224
 model_compression_toolkit/core/pytorch/kpi_data_facade.py,sha256=J0IDOtFMVFSFyBXDzNGbwJfHu89iRBJFdid1_wFB-xQ,8482
-model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
+model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=bTjqdPu4u7kVpH5u9kK5WgjexJBrUCSUBkCHAUYPv4A,26257
 model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
 model_compression_toolkit/core/pytorch/utils.py,sha256=rBQMAbWluyIMjVfeghzq6FZv3sR_khszSRpbWvwussw,2959
 model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
@@ -357,22 +358,28 @@ model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws
 model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=nsfvvLbgERes9qqNSLi-X5EgCIhzpLBDqK_dCAKAbGk,8670
 model_compression_toolkit/qat/__init__.py,sha256=BYKgH1NwB9fqF1TszULQ5tDfLI-GqgZV5sao-lDN9EM,1091
 model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
-model_compression_toolkit/qat/common/qat_config.py,sha256
+model_compression_toolkit/qat/common/qat_config.py,sha256=kbSxFL6_u28furq5mW_75STWDmyX4clPt-seJAnX3IQ,3445
 model_compression_toolkit/qat/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/qat/keras/quantization_facade.py,sha256=bH9pw2RtFD_91pfA5LPOuZ_ot04uX8R1au_n3Vr_7-0,16173
-model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=
+model_compression_toolkit/qat/keras/quantizer/__init__.py,sha256=zmYyCa25_KLCSUCGUDRslh3RCIjcRMxc_oXa54Aui-4,996
 model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py,sha256=gPuIgQb8OafvC3SuA8jNsGoy8S8eTsDCEKuh36WDNss,2104
-model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=
+model_compression_toolkit/qat/keras/quantizer/quant_utils.py,sha256=cBULOgWUodcBO1lHevZggdTevuDYI6tQceV86U2x6DA,2543
 model_compression_toolkit/qat/keras/quantizer/quantization_builder.py,sha256=mZwghAnKagL7916CJycFHgJdD5aY6At1A_IBkmYqae4,5635
+model_compression_toolkit/qat/keras/quantizer/lsq/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
+model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py,sha256=FdtVZepTCTbPHoHRiCV1P8MluaNR8RVx7fEoAzXyA0U,12051
+model_compression_toolkit/qat/keras/quantizer/lsq/uniform_lsq.py,sha256=P_IU8q5ipUR7r-bmmH_wSzoFsAcjWRVgCU68T_EzXVc,11159
 model_compression_toolkit/qat/keras/quantizer/ste_rounding/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/qat/keras/quantizer/ste_rounding/symmetric_ste.py,sha256=I4KlaGv17k71IyjuSG9M0OlXlD5P0pfvKa6oCyRQ5FE,13517
 model_compression_toolkit/qat/keras/quantizer/ste_rounding/uniform_ste.py,sha256=EED6LfqhX_OhDRJ9e4GwbpgNC9vq7hoXyJS2VPvG2qc,10789
 model_compression_toolkit/qat/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/qat/pytorch/quantization_facade.py,sha256=LGIk_nbcQEOhSN8fe56u6QGGUtGOVhF7tlCwW2G-Ig4,12549
-model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=
+model_compression_toolkit/qat/pytorch/quantizer/__init__.py,sha256=xYa4C8pr9cG1f3mQQcBXO_u3IdJN-zl7leZxuXDs86w,1003
 model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py,sha256=FnhuFCuQoSf78FM1z1UZgXXd3k-mKSM7i9dYOuJUmeA,2213
 model_compression_toolkit/qat/pytorch/quantizer/quantization_builder.py,sha256=GOYRDXvQSGe_iUFVmvDy5BqC952hu_-rQO06n8QCyw0,5491
-model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py,sha256=
+model_compression_toolkit/qat/pytorch/quantizer/quantizer_utils.py,sha256=nO7IrDRo5b9Asf21WJacE4vf5voD3UzF_oGjBoGusD4,5335
+model_compression_toolkit/qat/pytorch/quantizer/lsq/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
+model_compression_toolkit/qat/pytorch/quantizer/lsq/symmetric_lsq.py,sha256=HihuaMi0P0OQkNZlZAE-QeYlK_4AqcDKV6N405SdgI0,10712
+model_compression_toolkit/qat/pytorch/quantizer/lsq/uniform_lsq.py,sha256=fqAI151SUo9OYN4UvQJIcdB6p1r7HVeteqPqHGxv-tI,10355
 model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/symmetric_ste.py,sha256=4xmLmg7yN2A7iKnifwkWddgJTWMUiIjFilIuorJeK1A,9657
 model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py,sha256=HshW016iVAMx7iMkUwlONN2P3K4XgDIu-2AnJnBVSGo,8778
@@ -441,8 +448,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
 model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
 model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
 model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=SbvRlIdE32PEBsINt1bhSqvrKL_zbM9V-aeSkOn-sw4,3083
-mct_nightly-1.10.0.
-mct_nightly-1.10.0.
-mct_nightly-1.10.0.
-mct_nightly-1.10.0.
-mct_nightly-1.10.0.
+mct_nightly-1.10.0.20231019.post424.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-1.10.0.20231019.post424.dist-info/METADATA,sha256=7KfKPP_oFOaieSu_3UTKRv1F3HaR5jiwMOP1DiBAiIg,16303
+mct_nightly-1.10.0.20231019.post424.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+mct_nightly-1.10.0.20231019.post424.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-1.10.0.20231019.post424.dist-info/RECORD,,
model_compression_toolkit/core/common/framework_implementation.py

@@ -67,18 +67,6 @@ class FrameworkImplementation(ABC):
         raise NotImplemented(f'{self.__class__.__name__} have to implement the '
                              f'framework\'s get_trace_hessian_calculator method.') # pragma: no cover

-    @abstractmethod
-    def sample_single_representative_dataset(self, representative_dataset: Callable):
-        """
-        Get a single sample (namely, batch size of 1) from a representative dataset.
-
-        Args:
-            representative_dataset: Callable which returns the representative dataset at any batch size.
-
-        Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
-        """
-        raise NotImplemented(f'{self.__class__.__name__} have to implement the '
-                             f'framework\'s sample_single_representative_dataset method.') # pragma: no cover

     @abstractmethod
     def to_numpy(self, tensor: Any) -> np.ndarray:
model_compression_toolkit/core/common/hessian/hessian_info_service.py

@@ -54,7 +54,7 @@ class HessianInfoService:
         self.graph = graph

         # Create a representative_data_gen with batch size of 1
-        self.representative_dataset = partial(
+        self.representative_dataset = partial(self._sample_single_representative_dataset,
                                               representative_dataset=representative_dataset)

         self.fw_impl = fw_impl
@@ -62,6 +62,22 @@ class HessianInfoService:

         self.trace_hessian_request_to_score_list = {}

+    def _sample_single_representative_dataset(self, representative_dataset: Callable):
+        """
+        Get a single sample (namely, batch size of 1) from a representative dataset.
+
+        Args:
+            representative_dataset: Callable which returns the representative dataset at any batch size.
+
+        Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
+        """
+        images = next(representative_dataset())
+        if not isinstance(images, list):
+            Logger.error(f'Images expected to be a list but is of type {type(images)}')
+
+        # Ensure each image is a single sample, if not, take the first sample
+        return [image[0:1, ...] if image.shape[0] != 1 else image for image in images]
+
     def _clear_saved_hessian_info(self):
         """Clears the saved info approximations."""
         self.trace_hessian_request_to_score_list={}
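The helper that moves out of the framework implementations and into HessianInfoService simply draws one batch from the representative dataset and trims every input to batch size 1. A minimal, self-contained sketch of that behaviour is below; the generator is hypothetical and not part of MCT.

```python
import numpy as np

def representative_dataset():
    # Hypothetical generator: one batch of two inputs, each with batch size 8.
    yield [np.random.rand(8, 224, 224, 3), np.random.rand(8, 10)]

# Mirrors _sample_single_representative_dataset: keep only the first sample of each input.
images = next(representative_dataset())
single = [img[0:1, ...] if img.shape[0] != 1 else img for img in images]
print([s.shape for s in single])  # [(1, 224, 224, 3), (1, 10)]
```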
model_compression_toolkit/core/keras/constants.py

@@ -53,6 +53,13 @@ DIMS = 'dims'
 TARGET_SHAPE = 'target_shape'
 TRANSPOSE_A = 'transpose_a'
 TRANSPOSE_B = 'transpose_b'
+DEPTH_MULTIPLIER = 'depth_multiplier'
+DEPTHWISE_INITIALIZER = 'depthwise_initializer'
+DEPTHWISE_REGULARIZER = 'depthwise_regularizer'
+DEPTHWISE_CONSTRAINT = 'depthwise_constraint'
+KERNEL_INITIALIZER = 'kernel_initializer'
+KERNEL_REGULARIZER = 'kernel_regularizer'
+KERNEL_CONSTRAINT = 'kernel_constraint'

 # functional nodes attributes
 FUNCTION = 'function'
model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py

@@ -0,0 +1,127 @@
+# Copyright 2023 Sony Semiconductors Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+import tensorflow as tf
+from packaging import version
+if version.parse(tf.__version__) >= version.parse("2.13"):
+    from keras.src.layers.core import TFOpLambda
+    from keras.src.layers import Dense, Conv2D, Softmax, Concatenate, Reshape, Permute, DepthwiseConv2D
+else:
+    from keras.layers.core import TFOpLambda
+    from keras.layers import Dense, Conv2D, Softmax, Concatenate, Reshape, Permute, DepthwiseConv2D
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
+from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.constants import REUSE, REUSE_GROUP
+from model_compression_toolkit.core.keras.reader.node_builder import REUSED_IDENTIFIER
+from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, NUM_HEADS, KEY_DIM, VALUE_DIM, \
+    QUERY_SHAPE, KEY_SHAPE, VALUE_SHAPE, OUTPUT_SHAPE, ATTENTION_AXES, ACTIVATION, GROUPS, LINEAR, FILTERS, PADDING, \
+    FUNCTION, DIMS, TARGET_SHAPE, F_STRIDED_SLICE, F_STACK, Q_KERNEL, Q_BIAS, K_KERNEL, K_BIAS, V_KERNEL, V_BIAS, \
+    OUTPUT_KERNEL, OUTPUT_BIAS, F_MATMUL, TRANSPOSE_B, KERNEL_SIZE, AXIS, F_STRIDED_SLICE_BEGIN, F_STRIDED_SLICE_END, \
+    DEPTH_MULTIPLIER, DEPTHWISE_INITIALIZER, DEPTHWISE_REGULARIZER, DEPTHWISE_CONSTRAINT, KERNEL_INITIALIZER, \
+    KERNEL_REGULARIZER, KERNEL_CONSTRAINT
+
+
+class DwconvToConv(common.BaseSubstitution):
+    """
+    A substitution class for replacing DepthwiseConv2D layers with Conv2D layers having 'groups' equal to the number of
+    input channels.
+    """
+
+    def __init__(self):
+        """
+        Initializes the DwconvToConv substitution
+        """
+        super().__init__(matcher_instance=NodeOperationMatcher(DepthwiseConv2D))
+
+    @staticmethod
+    def _get_weight_by_name(node, w_str):
+        """
+        Retrieve the weight with a given name from the node.
+
+        Args:
+            node: The node containing weights.
+            w_str: The name of the weight to retrieve.
+
+        Returns:
+            The weight with the specified name or None if not found.
+        """
+        w = [k for k in node.weights.keys() if w_str in k]
+        return node.weights[w[0]]
+
+    def substitute(self,
+                   graph: Graph,
+                   dwconv_node: BaseNode) -> Graph:
+        """
+        Replace a DepthwiseConv2D layer with a Conv2D layer, setting 'groups' parameter to the number of input channels.
+
+        Args:
+            graph: The graph on which the substitution is applied.
+            dwconv_node: The DepthwiseConv2D node to be replaced.
+
+        Returns:
+            The modified graph after applying the substitution.
+        """
+
+        # Skip in case mult depth_multiplier=1
+        if dwconv_node.framework_attr[DEPTH_MULTIPLIER] == 1:
+            return graph
+
+        # Build the new node
+        k = self._get_weight_by_name(dwconv_node, KERNEL).copy()
+        k_shape = k.shape
+        filters = k_shape[2] * k_shape[3]  # k_shape[2] * k_shape[3] = number of output channels
+
+        # Transform the DepthwiseConv2D kernel to match the Conv2D kernel, where each input channel is convolved with
+        # 'depth_multiplier' filters.
+        k = np.reshape(k, [k_shape[0], k_shape[1], 1, filters])
+        _reuse_params = {REUSE: dwconv_node.reuse, REUSE_GROUP: dwconv_node.reuse_group}
+
+        conv_fw_attr = dwconv_node.framework_attr
+        conv_fw_attr.update({FILTERS: filters,
+                             GROUPS: k_shape[2],
+                             KERNEL_INITIALIZER: dwconv_node.framework_attr[DEPTHWISE_INITIALIZER],
+                             KERNEL_REGULARIZER: dwconv_node.framework_attr[DEPTHWISE_REGULARIZER],
+                             KERNEL_CONSTRAINT: dwconv_node.framework_attr[DEPTHWISE_CONSTRAINT]})
+
+        conv_fw_attr.pop(DEPTH_MULTIPLIER)
+        conv_fw_attr.pop(DEPTHWISE_INITIALIZER)
+        conv_fw_attr.pop(DEPTHWISE_REGULARIZER)
+        conv_fw_attr.pop(DEPTHWISE_CONSTRAINT)
+
+        conv_weights = {KERNEL: k}
+        if conv_fw_attr[USE_BIAS]:
+            b = self._get_weight_by_name(dwconv_node, BIAS).copy()
+            conv_weights.update({BIAS: b})
+
+        conv_node = BaseNode(dwconv_node.name, conv_fw_attr, dwconv_node.input_shape, dwconv_node.output_shape,
+                             conv_weights, Conv2D,
+                             **_reuse_params)
+
+        graph.add_node(conv_node)
+
+        # Replace DWconv node with Conv node
+        _in_edge = list(graph.in_edges(dwconv_node))[0]
+        _out_edges = graph.out_edges(dwconv_node)
+        graph.add_edge(_in_edge[0], conv_node, **graph.get_edge_data(*_in_edge, 0))
+        graph.remove_edge(_in_edge[0], dwconv_node)
+        graph.reconnect_out_edges(current_node=dwconv_node, new_node=conv_node)
+
+        # Finally, remove the DWconv node
+        graph.remove_node(dwconv_node, new_graph_outputs=[OutTensor(conv_node, 0)])
+
+        return graph
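The kernel reshape at the heart of DwconvToConv relies on the fact that a DepthwiseConv2D with depth_multiplier > 1 computes the same function as a grouped Conv2D whose kernel is reshaped from (kh, kw, in_channels, depth_multiplier) to (kh, kw, 1, in_channels * depth_multiplier). A minimal sketch (not part of the diff) that checks this equivalence with stock Keras layers:

```python
import numpy as np
import tensorflow as tf

in_ch, dm = 3, 2
x = tf.random.normal((1, 8, 8, in_ch))

# DepthwiseConv2D with depth_multiplier=2 -- the case DwconvToConv rewrites.
dw = tf.keras.layers.DepthwiseConv2D(3, depth_multiplier=dm, use_bias=False)
_ = dw(x)                   # build the layer to materialize its weights
k = dw.get_weights()[0]     # shape (3, 3, in_ch, dm)

# Grouped Conv2D with groups=in_ch and the kernel reshaped to (3, 3, 1, in_ch*dm).
conv = tf.keras.layers.Conv2D(filters=in_ch * dm, kernel_size=3,
                              groups=in_ch, use_bias=False)
_ = conv(x)
conv.set_weights([np.reshape(k, (3, 3, 1, in_ch * dm))])

print(np.allclose(dw(x).numpy(), conv(x).numpy(), atol=1e-5))  # expected: True
```

The reshape works because both orderings place the depth_multiplier outputs of input channel i at output indices i*dm ... i*dm+dm-1, so no permutation of filters is needed.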
model_compression_toolkit/core/keras/keras_implementation.py

@@ -91,6 +91,7 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.sepa
     SeparableConvDecomposition, DEPTH_MULTIPLIER
 from model_compression_toolkit.core.keras.graph_substitutions.substitutions.shift_negative_activation import \
     keras_apply_shift_negative_correction
+from model_compression_toolkit.core.keras.graph_substitutions.substitutions.dwconv_to_conv import DwconvToConv
 from model_compression_toolkit.core.keras.keras_node_prior_info import create_node_prior_info
 from model_compression_toolkit.core.keras.reader.reader import model_reader
 from model_compression_toolkit.core.common.collectors.statistics_collector_generator import \
@@ -259,7 +260,8 @@ class KerasImplementation(FrameworkImplementation):
         """
         return [SeparableConvDecomposition(),
                 MultiHeadAttentionDecomposition(),
-                ActivationDecomposition()
+                ActivationDecomposition(),
+                DwconvToConv()]

     def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
             List[common.BaseSubstitution]:
@@ -591,19 +593,3 @@ class KerasImplementation(FrameworkImplementation):
         """

         return model(inputs)
-
-    def sample_single_representative_dataset(self, representative_dataset: Callable):
-        """
-        Get a single sample (namely, batch size of 1) from a representative dataset.
-
-        Args:
-            representative_dataset: Callable which returns the representative dataset at any batch size.
-
-        Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
-        """
-        images = next(representative_dataset())
-        if not isinstance(images, list):
-            Logger.error(f'Images expected to be a list but is of type {type(images)}')
-
-        # Ensure each image is a single sample, if not, take the first sample
-        return [tf.expand_dims(image[0], 0) if image.shape[0] != 1 else image for image in images]
model_compression_toolkit/core/pytorch/pytorch_implementation.py

@@ -540,18 +540,3 @@ class PytorchImplementation(FrameworkImplementation):
                                             fw_impl=self,
                                             num_iterations_for_approximation=num_iterations_for_approximation)

-    def sample_single_representative_dataset(self, representative_dataset: Callable):
-        """
-        Get a single sample (namely, batch size of 1) from a representative dataset.
-
-        Args:
-            representative_dataset: Callable which returns the representative dataset at any batch size.
-
-        Returns: List of inputs from representative_dataset where each sample has a batch size of 1.
-        """
-        images = next(representative_dataset())
-        if not isinstance(images, list):
-            Logger.error(f'Images expected to be a list but is of type {type(images)}')
-
-        # Ensure each image is a single sample, if not, take the first sample
-        return [torch.unsqueeze(image[0], 0) if image.shape[0] != 1 else image for image in images]
model_compression_toolkit/qat/common/qat_config.py

@@ -45,9 +45,12 @@ class TrainingMethod(Enum):

     DQA - DNN Quantization with Attention. Includes a smooth quantization introduces by DQA method

+    LSQ - Learned Step size Quantization. Includes PowerOfTwo, symmetric & uniform quantizers: https://arxiv.org/pdf/1902.08153.pdf
+
     """
     STE = "STE",
-    DQA = "DQA"
+    DQA = "DQA",
+    LSQ = "LSQ"


 class QATConfig:
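With LSQ added to TrainingMethod, LSQ-based QAT can be requested when building the QAT configuration. The sketch below is a hedged illustration: TrainingMethod is imported the same way the new symmetric_lsq.py imports it, but the QATConfig keyword names are an assumption about the MCT 1.10 API and should be checked against the installed version.

```python
# Hedged sketch: selecting the new LSQ training method for QAT.
# The keyword names (weight_training_method / activation_training_method)
# are assumed from the MCT 1.10 QATConfig and may differ in other versions.
from model_compression_toolkit.qat import QATConfig, TrainingMethod

qat_config = QATConfig(weight_training_method=TrainingMethod.LSQ,
                       activation_training_method=TrainingMethod.LSQ)
```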
model_compression_toolkit/qat/keras/quantizer/__init__.py

@@ -15,3 +15,5 @@

 import model_compression_toolkit.qat.keras.quantizer.ste_rounding.symmetric_ste
 import model_compression_toolkit.qat.keras.quantizer.ste_rounding.uniform_ste
+import model_compression_toolkit.qat.keras.quantizer.lsq.symmetric_lsq
+import model_compression_toolkit.qat.keras.quantizer.lsq.uniform_lsq
model_compression_toolkit/qat/keras/quantizer/lsq/__init__.py

@@ -0,0 +1,14 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
model_compression_toolkit/qat/keras/quantizer/lsq/symmetric_lsq.py

@@ -0,0 +1,254 @@
+# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.framework.tensor_shape import TensorShape
+from model_compression_toolkit.constants import SIGNED
+
+from model_compression_toolkit.qat import TrainingMethod
+
+from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
+from mct_quantizers import QuantizationTarget, mark_quantizer
+from model_compression_toolkit.qat.common import THRESHOLD_TENSOR
+from model_compression_toolkit import constants as C
+
+from model_compression_toolkit.qat.keras.quantizer.base_keras_qat_quantizer import BaseKerasQATTrainableQuantizer
+from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
+    TrainableQuantizerActivationConfig
+from mct_quantizers.keras.quantizers import WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer, \
+    ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer
+from model_compression_toolkit.trainable_infrastructure.common.base_trainable_quantizer import VariableGroup
+from model_compression_toolkit.qat.keras.quantizer.quant_utils import ste_round, grad_scale
+
+
+def symmetric_lsq_quantizer(x: tf.Tensor,
+                            thresholds: tf.Tensor,
+                            num_bits: int,
+                            sign: bool,
+                            min_int: int,
+                            max_int:int,
+                            scale_factor: float) -> tf.Tensor:
+    """
+    Symmetric quantizer according to LSQ algorithm: https://arxiv.org/pdf/1902.08153.pdf
+    Args:
+        x: input to quantize
+        thresholds: thresholds of quantization levels
+        num_bits: number of bits for quantization
+        sign: whether x is signed or not
+        min_int: min clipping integer value
+        max_int: max clipping integer value
+        scale_factor: grad scale of LSQ algorithm
+    Returns:
+        A quantized tensor
+    """
+    delta = thresholds / (2 ** (num_bits - int(sign)))
+    delta_scaled = grad_scale(delta, scale_factor)
+    rounded = ste_round(x / delta_scaled)
+    clipped = tf.math.minimum(tf.math.maximum(rounded, min_int), max_int)
+    quantized = delta_scaled * clipped
+    return quantized
+
+
+@mark_quantizer(quantization_target=QuantizationTarget.Weights,
+                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                identifier=TrainingMethod.LSQ)
+class LSQWeightQATQuantizer(BaseKerasQATTrainableQuantizer):
+    """
+    Trainable constrained quantizer to quantize layer's weights.
+    """
+
+    def __init__(self, quantization_config: TrainableQuantizerWeightsConfig):
+        """
+        Initialize a LSQWeightQATQuantizer object with parameters to use
+        for the quantization.
+
+        Args:
+            quantization_config: trainable quantizer config class
+        """
+        super().__init__(quantization_config)
+        self.power_of_two = quantization_config.weights_quantization_method == QuantizationMethod.POWER_OF_TWO
+        self.threshold_values = np.array(quantization_config.weights_quantization_params[C.THRESHOLD])
+        self.threshold_shape = self.threshold_values.shape
+        self.per_channel = self.quantization_config.weights_per_channel_threshold
+        self.channel_axis = self.quantization_config.weights_channels_axis
+        self.threshold_values = np.reshape(np.asarray(self.threshold_values), [-1]) if self.per_channel else float(self.threshold_values)
+        self.num_bits = self.quantization_config.weights_n_bits
+        n_pos_bits = self.num_bits - int(C.WEIGHTS_SIGNED)
+        self.min_int = -int(C.WEIGHTS_SIGNED) * (2 ** n_pos_bits)
+        self.max_int = 2 **n_pos_bits - 1
+        self.scale_factor = 1.0 / np.sqrt(self.max_int * self.threshold_values.size)
+        if self.power_of_two:
+            self.threshold_values = np.power(2.0, np.ceil(np.log2(np.maximum(self.threshold_values, C.MIN_THRESHOLD))))
+
+    def initialize_quantization(self,
+                                tensor_shape: TensorShape,
+                                name: str,
+                                layer: KerasTrainableQuantizationWrapper):
+        """
+        Add quantizer parameters to the quantizer parameters dictionary
+
+        Args:
+            tensor_shape: tensor shape of the quantized tensor.
+            name: Tensor name.
+            layer: Layer to quantize.
+        """
+        ptq_threshold_tensor = layer.add_weight(
+            name + THRESHOLD_TENSOR,
+            shape=len(self.threshold_values) if self.per_channel else (),
+            initializer=tf.keras.initializers.Constant(1.0),
+            trainable=True)
+        ptq_threshold_tensor.assign(self.threshold_values)
+
+        # save the quantizer added parameters for later calculations
+        self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS)
+
+    def __call__(self,
+                 inputs: tf.Tensor,
+                 training: bool):
+        """
+        Quantize a tensor.
+        Args:
+            inputs: Input tensor to quantize.
+            training: Whether the graph is in training mode.
+            weights: Dictionary of weights the quantizer can use to quantize the tensor.
+            **kwargs: Additional variables the quantizer may receive.
+
+        Returns:
+            The quantized tensor.
+        """
+
+        thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR)
+        q_tensor = symmetric_lsq_quantizer(inputs, thresholds, self.num_bits, C.WEIGHTS_SIGNED, self.min_int, self.max_int, self.scale_factor)
+        return q_tensor
+
+    def convert2inferable(self) -> Union[WeightsPOTInferableQuantizer, WeightsSymmetricInferableQuantizer]:
+        """
+        Convert quantizer to inferable quantizer.
+
+        Returns:
+            BaseKerasInferableQuantizer object.
+        """
+        if self.power_of_two:
+            thresholds = 2 ** np.ceil(np.log2(self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()))
+            return WeightsPOTInferableQuantizer(num_bits=self.num_bits,
+                                                threshold=list(thresholds.flatten()),
+                                                per_channel=self.per_channel,
+                                                channel_axis=self.channel_axis,
+                                                input_rank=len(self.threshold_shape))
+        else:
+            thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()
+            return WeightsSymmetricInferableQuantizer(num_bits=self.num_bits,
+                                                      threshold=list(thresholds.flatten()),
+                                                      per_channel=self.per_channel,
+                                                      channel_axis=self.channel_axis,
+                                                      input_rank=len(self.threshold_shape))
+
+
+@mark_quantizer(quantization_target=QuantizationTarget.Activation,
+                quantization_method=[QuantizationMethod.POWER_OF_TWO, QuantizationMethod.SYMMETRIC],
+                identifier=TrainingMethod.LSQ)
+class LSQActivationQATQuantizer(BaseKerasQATTrainableQuantizer):
+    """
+    Trainable constrained quantizer to quantize layer activations.
+    """
+
+    def __init__(self, quantization_config: TrainableQuantizerActivationConfig):
+        """
+        Initialize a LSQActivationQATQuantizer object with parameters to use
+        for the quantization.
+
+        Args:
+            quantization_config: trainable quantizer config class
+        """
+        super().__init__(quantization_config)
+        self.power_of_two = quantization_config.activation_quantization_method == QuantizationMethod.POWER_OF_TWO
+        self.threshold_values = float(quantization_config.activation_quantization_params[C.THRESHOLD])
+        self.threshold_shape = np.asarray(self.threshold_values).shape
+        self.sign = quantization_config.activation_quantization_params[SIGNED]
+        self.num_bits = quantization_config.activation_n_bits
+        n_pos_bits = self.num_bits - int(self.sign)
+        self.min_int = -int(self.sign) * (2 ** n_pos_bits)
+        self.max_int = (2 ** n_pos_bits) - 1
+        if self.power_of_two:
+            self.threshold_values = np.power(2.0, np.ceil(np.log2(np.maximum(self.threshold_values, C.MIN_THRESHOLD))))
+
+
+    def initialize_quantization(self,
+                                tensor_shape: TensorShape,
+                                name: str,
+                                layer: KerasTrainableQuantizationWrapper):
+        """
+        Add quantizer parameters to the quantizer parameters dictionary
+
+        Args:
+            tensor_shape: tensor shape of the quantized tensor.
+            name: Tensor name.
+            layer: Layer to quantize.
+        """
+        ptq_threshold_tensor = layer.add_weight(
+            name + THRESHOLD_TENSOR,
+            shape=(),
+            initializer=tf.keras.initializers.Constant(1.0),
+            trainable=True)
+        ptq_threshold_tensor.assign(self.threshold_values)
+
+        # save the quantizer added parameters for later calculations
+        self.add_quantizer_variable(THRESHOLD_TENSOR, ptq_threshold_tensor, VariableGroup.QPARAMS)
+
+    def __call__(self,
+                 inputs: tf.Tensor,
+                 training: bool):
+        """
+        Quantize a tensor.
+        Args:
+            inputs: Input tensor to quantize.
+            training: Whether the graph is in training mode.
+
+        Returns:
+            The quantized tensor.
+        """
+
+        thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR)
+        n_channels = inputs.shape[-1]
+        scale_factor = 1.0 / np.sqrt(self.max_int * n_channels)
+        q_tensor = symmetric_lsq_quantizer(inputs, thresholds, self.num_bits, self.sign, self.min_int, self.max_int, scale_factor)
+        return q_tensor
+
+    def convert2inferable(self) -> Union[ActivationPOTInferableQuantizer, ActivationSymmetricInferableQuantizer]:
+        """
+        Convert quantizer to inferable quantizer.
+
+        Returns:
+            BaseKerasInferableQuantizer object.
+        """
+
+        if self.power_of_two:
+            thresholds = 2 ** np.ceil(np.log2(self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()))
+            return ActivationPOTInferableQuantizer(num_bits=self.num_bits,
+                                                   # In activation quantization is per-tensor only - thus we pass
+                                                   # the threshold as a list with a len of 1
+                                                   threshold=[thresholds],
+                                                   signed=self.sign)
+        else:
+            thresholds = self.get_quantizer_variable(THRESHOLD_TENSOR).numpy()
+            return ActivationSymmetricInferableQuantizer(num_bits=self.num_bits,
+                                                         # In activation quantization is per-tensor only - thus we
+                                                         # pass the threshold as a list with a len of 1
+                                                         threshold=[thresholds],
+                                                         signed=self.sign)
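For reference, the forward pass of symmetric_lsq_quantizer above reduces to a scale-round-clip-rescale pipeline; the LSQ-specific pieces (grad_scale and the straight-through ste_round) only affect the backward pass. A plain NumPy sketch of the same forward arithmetic:

```python
import numpy as np

def symmetric_lsq_forward(x, threshold, num_bits, signed):
    # Step size between quantization levels, as in symmetric_lsq_quantizer.
    delta = threshold / (2 ** (num_bits - int(signed)))
    min_int = -int(signed) * 2 ** (num_bits - int(signed))  # e.g. -128 for signed 8-bit
    max_int = 2 ** (num_bits - int(signed)) - 1             # e.g. 127
    return delta * np.clip(np.round(x / delta), min_int, max_int)

x = np.array([-1.3, -0.26, 0.0, 0.31, 0.9])
print(symmetric_lsq_forward(x, threshold=1.0, num_bits=8, signed=True))
# approx. [-1.0, -0.2578, 0.0, 0.3125, 0.8984]; values beyond the threshold saturate
```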