mct-nightly 2.0.0.20240407.442__py3-none-any.whl → 2.0.0.20240409.404__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.0.0.20240407.442.dist-info → mct_nightly-2.0.0.20240409.404.dist-info}/METADATA +1 -1
- {mct_nightly-2.0.0.20240407.442.dist-info → mct_nightly-2.0.0.20240409.404.dist-info}/RECORD +17 -17
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +22 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +15 -14
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py +5 -4
- model_compression_toolkit/core/common/similarity_analyzer.py +4 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +31 -4
- model_compression_toolkit/gptq/keras/gptq_training.py +6 -2
- model_compression_toolkit/gptq/keras/quantization_facade.py +10 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +6 -2
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +9 -1
- model_compression_toolkit/ptq/keras/quantization_facade.py +22 -5
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +28 -9
- {mct_nightly-2.0.0.20240407.442.dist-info → mct_nightly-2.0.0.20240409.404.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.0.0.20240407.442.dist-info → mct_nightly-2.0.0.20240409.404.dist-info}/WHEEL +0 -0
- {mct_nightly-2.0.0.20240407.442.dist-info → mct_nightly-2.0.0.20240409.404.dist-info}/top_level.txt +0 -0
{mct_nightly-2.0.0.20240407.442.dist-info → mct_nightly-2.0.0.20240409.404.dist-info}/RECORD
RENAMED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
|
1
|
+
model_compression_toolkit/__init__.py,sha256=ALvOQYWLrTHNtxDnpNxy7lyftsvgDpzcoW-wTFtMedY,1573
|
|
2
2
|
model_compression_toolkit/constants.py,sha256=KW_HUEPmQEYqCvWGyORqkYxpvO7w5LViB5J5D-pm_6o,3648
|
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
|
5
5
|
model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
|
|
6
|
-
model_compression_toolkit/core/analyzer.py,sha256=
|
|
6
|
+
model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
|
|
7
7
|
model_compression_toolkit/core/graph_prep_runner.py,sha256=Ftqm59hT5TGWmSNkY9bFZkVfCacpGyZfCe-6yZR5WY0,10100
|
|
8
8
|
model_compression_toolkit/core/quantization_prep_runner.py,sha256=hFhDkS8GwzXZ7Ho_9qbbb8DAAWs3OONOfMSD5OU_b0o,6153
|
|
9
9
|
model_compression_toolkit/core/runner.py,sha256=NKSC6ujfQPy6dKtJVwxyK2zNDd64eyR5csYy9lBrCPA,11836
|
|
@@ -16,7 +16,7 @@ model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3u
|
|
|
16
16
|
model_compression_toolkit/core/common/model_collector.py,sha256=ofcepKtxc3j2Ouz6BpAKXTzPgjABnpRP47ndmJCXAkk,8352
|
|
17
17
|
model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
|
|
18
18
|
model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
|
|
19
|
-
model_compression_toolkit/core/common/similarity_analyzer.py,sha256=
|
|
19
|
+
model_compression_toolkit/core/common/similarity_analyzer.py,sha256=98l9ttnXHf6VYxBW4852h2CPJKg3A6nLOovpHn-tnKs,8560
|
|
20
20
|
model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
|
|
21
21
|
model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
22
22
|
model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=V1oShKzbSkdcTvREn8VnQQBzvm-tTHkWMXqMkYozF2s,2023
|
|
@@ -100,11 +100,11 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
|
100
100
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
|
|
101
101
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
|
|
102
102
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
|
|
103
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
|
103
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=h_cgpvT50gYgO8T363-Zw_b2jfqo3uoa7TqnSuig7I4,26947
|
|
104
104
|
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=hQMKm55EXS1oV-Upt6IQtsYhpuhMvYeWRJhh6lhv_Ko,6699
|
|
105
105
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=T1nVWdRJfBQ_iuMQYQSIkjfkR-2n3lAOKGAz_rUZZN0,2190
|
|
106
106
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
|
|
107
|
-
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=
|
|
107
|
+
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
|
|
108
108
|
model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
|
|
109
109
|
model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=9BEv2l0z2trDEsr40VB8tO3ToBA_b2sd_jH9uqZ5Wo8,11503
|
|
110
110
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
|
|
@@ -142,7 +142,7 @@ model_compression_toolkit/core/common/substitutions/virtual_activation_weights_c
|
|
|
142
142
|
model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=h85L2VlDOqbLd-N98wA3SdYWiblBgSsPceNuLanJd70,4737
|
|
143
143
|
model_compression_toolkit/core/common/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
144
144
|
model_compression_toolkit/core/common/visualization/final_config_visualizer.py,sha256=6I10jKLesB-RQKaXA75Xgz2wPvylQUrnPtCcQZIynGo,6371
|
|
145
|
-
model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=
|
|
145
|
+
model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
|
|
146
146
|
model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=4E4ZXZmqusGIJ4XQNH8FFt07htAHgT3gy5E7wPIaVBI,21951
|
|
147
147
|
model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
|
148
148
|
model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
|
|
@@ -336,9 +336,9 @@ model_compression_toolkit/gptq/common/gptq_training.py,sha256=rLA1xlOO-6gWfmc2dL
|
|
|
336
336
|
model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
337
337
|
model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
|
|
338
338
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
|
339
|
-
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=
|
|
339
|
+
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=OhYfH6zxRHrRhCde0lbcV9Hu2oeDD9RXh-O8vOPgLbs,18875
|
|
340
340
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=5IvgGlJlgOmQYmldjdCBv7tuzAoY0HazatG5Pedrg0Q,4639
|
|
341
|
-
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
|
|
341
|
+
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=zAkzWpWP9_aobWgMo_BlUm7-4fR5dHvoGx0sDqs2rZg,14299
|
|
342
342
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
|
343
343
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
|
|
344
344
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
|
@@ -353,9 +353,9 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
|
|
|
353
353
|
model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
354
354
|
model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
|
|
355
355
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
|
356
|
-
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=
|
|
356
|
+
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=LN4vOwcMuSSFTSnHDACV9hX_Yd2YIXJRl7WkdODuA0k,16245
|
|
357
357
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=yXJzDd24zfGs2_vfMovxD1WSh1RxXoPxN4GztOf3P5c,3967
|
|
358
|
-
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256
|
|
358
|
+
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=-4USg-tep6EQSArcTxBowhMeAuExrBTNLOWgHFpsIy4,12699
|
|
359
359
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
|
360
360
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
|
|
361
361
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
|
@@ -375,9 +375,9 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=cSuvHHCqgr7k9
|
|
|
375
375
|
model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
|
|
376
376
|
model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
|
|
377
377
|
model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
378
|
-
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=
|
|
378
|
+
model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=T1_UqXmOc4I2a6IHkQAlFhGtcAYjsXSApMIdRlvgDvg,10154
|
|
379
379
|
model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
380
|
-
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=
|
|
380
|
+
model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=eof9bo-Mv_lLY7fFpiVeT5pIde-MuTWkIAqRKH4j9MI,8646
|
|
381
381
|
model_compression_toolkit/qat/__init__.py,sha256=kj2qsZh_Ca7PncsHKcaL5EVT2H8g4hYtvaQ3KFxOkwE,1143
|
|
382
382
|
model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
|
|
383
383
|
model_compression_toolkit/qat/common/qat_config.py,sha256=zoq0Vb74vCY7WlWD8JH_KPrHDoUHSvMc3gcO53u7L2U,3394
|
|
@@ -469,8 +469,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
469
469
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
470
470
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
471
471
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
|
|
472
|
-
mct_nightly-2.0.0.
|
|
473
|
-
mct_nightly-2.0.0.
|
|
474
|
-
mct_nightly-2.0.0.
|
|
475
|
-
mct_nightly-2.0.0.
|
|
476
|
-
mct_nightly-2.0.0.
|
|
472
|
+
mct_nightly-2.0.0.20240409.404.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
473
|
+
mct_nightly-2.0.0.20240409.404.dist-info/METADATA,sha256=uDkh4Eu7g8uMdBVYp8H_rPGwkuhe_aWWK86DgPSBj94,18795
|
|
474
|
+
mct_nightly-2.0.0.20240409.404.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
475
|
+
mct_nightly-2.0.0.20240409.404.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
476
|
+
mct_nightly-2.0.0.20240409.404.dist-info/RECORD,,
|
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
|
27
27
|
from model_compression_toolkit import pruning
|
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
29
29
|
|
|
30
|
-
__version__ = "2.0.0.
|
|
30
|
+
__version__ = "2.0.0.20240409.000404"
|
|
@@ -30,7 +30,8 @@ from model_compression_toolkit.logger import Logger
|
|
|
30
30
|
|
|
31
31
|
def analyzer_model_quantization(representative_data_gen: Callable,
|
|
32
32
|
tb_w: TensorboardWriter,
|
|
33
|
-
|
|
33
|
+
float_graph: Graph,
|
|
34
|
+
quantized_graph: Graph,
|
|
34
35
|
fw_impl: FrameworkImplementation,
|
|
35
36
|
fw_info: FrameworkInfo):
|
|
36
37
|
"""
|
|
@@ -41,23 +42,32 @@ def analyzer_model_quantization(representative_data_gen: Callable,
|
|
|
41
42
|
Args:
|
|
42
43
|
representative_data_gen: Dataset used for calibration.
|
|
43
44
|
tb_w: TensorBoardWriter object to log events.
|
|
44
|
-
|
|
45
|
+
float_graph: Graph of float model.
|
|
46
|
+
quantized_graph: Graph of quantized model.
|
|
45
47
|
fw_impl: FrameworkImplementation object with a specific framework methods implementation.
|
|
46
48
|
fw_info: Information needed for quantization about the specific framework.
|
|
47
49
|
|
|
48
50
|
"""
|
|
49
51
|
if tb_w is not None:
|
|
50
|
-
visual = NNVisualizer(
|
|
52
|
+
visual = NNVisualizer(float_graph,
|
|
53
|
+
quantized_graph,
|
|
51
54
|
fw_impl=fw_impl,
|
|
52
55
|
fw_info=fw_info)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
if i >= NUM_SAMPLES_DISTANCE_TENSORBOARD:
|
|
56
|
-
break
|
|
57
|
-
figure = visual.plot_distance_graph(_data,
|
|
58
|
-
distance_fn=compute_cs,
|
|
59
|
-
convert_to_range=lambda a: 1 - 2 * a)
|
|
60
|
-
tb_w.add_figure(figure, f'similarity_distance_sample_{i}')
|
|
56
|
+
if not visual.has_compare_points():
|
|
57
|
+
Logger.error(f'No comparing points were found to plot analyze similarity.')
|
|
61
58
|
else:
|
|
62
|
-
|
|
59
|
+
visualized_samples = 0
|
|
60
|
+
for _data in representative_data_gen():
|
|
61
|
+
batch_size = _data[0].shape[0]
|
|
62
|
+
for sample_index in range(batch_size):
|
|
63
|
+
if visualized_samples >= NUM_SAMPLES_DISTANCE_TENSORBOARD:
|
|
64
|
+
break
|
|
65
|
+
figure = visual.plot_distance_graph(_data,
|
|
66
|
+
sample_index=sample_index,
|
|
67
|
+
distance_fn=compute_cs,
|
|
68
|
+
convert_to_range=lambda a: 1 - 2 * a)
|
|
69
|
+
tb_w.add_figure(figure, f'similarity_distance_sample_{visualized_samples}')
|
|
70
|
+
visualized_samples += 1
|
|
71
|
+
if visualized_samples < NUM_SAMPLES_DISTANCE_TENSORBOARD:
|
|
72
|
+
Logger.error(f'Not enough batches in representative dataset to generate {NUM_SAMPLES_DISTANCE_TENSORBOARD} figures')
|
|
63
73
|
tb_w.close()
|
|
@@ -41,24 +41,24 @@ class BaseNodeQuantizationConfig(object):
|
|
|
41
41
|
Base class for node quantization configuration
|
|
42
42
|
"""
|
|
43
43
|
|
|
44
|
-
def set_quant_config_attr(self,
|
|
44
|
+
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
|
|
45
45
|
*args: List[Any], **kwargs: Dict[str, Any]):
|
|
46
46
|
"""
|
|
47
47
|
Changes a BaseNodeQuantizationConfig's parameter.
|
|
48
48
|
Note that arg and kwargs are only to allow clean override in the child classes.
|
|
49
49
|
|
|
50
50
|
Args:
|
|
51
|
-
|
|
52
|
-
|
|
51
|
+
config_parameter_name: parameter name to change.
|
|
52
|
+
config_parameter_value: parameter value to change.
|
|
53
53
|
args: A list of additional arguments.
|
|
54
54
|
kwargs: A dictionary with additional key arguments.
|
|
55
55
|
|
|
56
56
|
"""
|
|
57
57
|
|
|
58
|
-
if hasattr(self,
|
|
59
|
-
setattr(self,
|
|
58
|
+
if hasattr(self, config_parameter_name):
|
|
59
|
+
setattr(self, config_parameter_name, config_parameter_value)
|
|
60
60
|
else:
|
|
61
|
-
Logger.warning(f"Parameter {
|
|
61
|
+
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
|
|
62
62
|
f"was not updated!")
|
|
63
63
|
|
|
64
64
|
def __repr__(self) -> str:
|
|
@@ -521,7 +521,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
521
521
|
f"{list(attrs_with_name.keys())}.")
|
|
522
522
|
return attrs_with_name
|
|
523
523
|
|
|
524
|
-
def set_quant_config_attr(self,
|
|
524
|
+
def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any, attr_name: str = None,
|
|
525
525
|
*args: List[Any], **kwargs: Dict[str, Any]):
|
|
526
526
|
"""
|
|
527
527
|
This method overrides the parent class set_quant_config_attr to enable setting a specific weights
|
|
@@ -529,26 +529,27 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
529
529
|
|
|
530
530
|
Args:
|
|
531
531
|
attr_name: attribute name to change.
|
|
532
|
-
|
|
533
|
-
|
|
532
|
+
config_parameter_name: parameter name to change.
|
|
533
|
+
config_parameter_value: parameter value to change.
|
|
534
534
|
args: A list of additional arguments.
|
|
535
535
|
kwargs: A dictionary with additional key arguments.
|
|
536
536
|
|
|
537
537
|
"""
|
|
538
538
|
|
|
539
539
|
if attr_name is None:
|
|
540
|
-
super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(
|
|
540
|
+
super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(config_parameter_name,
|
|
541
|
+
config_parameter_value,
|
|
541
542
|
*args, **kwargs)
|
|
542
543
|
else:
|
|
543
544
|
if self.has_attribute_config(attr_name):
|
|
544
545
|
attr_cfg = self.get_attr_config(attr_name)
|
|
545
|
-
if hasattr(attr_cfg,
|
|
546
|
-
setattr(attr_cfg,
|
|
546
|
+
if hasattr(attr_cfg, config_parameter_name):
|
|
547
|
+
setattr(attr_cfg, config_parameter_name, config_parameter_value)
|
|
547
548
|
else:
|
|
548
|
-
Logger.warning(f"Parameter {
|
|
549
|
+
Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
|
|
549
550
|
f"weights attribute {attr_name} and was not updated!")
|
|
550
551
|
else:
|
|
551
|
-
Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {
|
|
552
|
+
Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
|
|
552
553
|
|
|
553
554
|
def __eq__(self, other: Any) -> bool:
|
|
554
555
|
"""
|
|
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.quantization.quantize_node import get
|
|
|
23
23
|
from model_compression_toolkit.logger import Logger
|
|
24
24
|
|
|
25
25
|
|
|
26
|
-
def quantize_graph_weights(
|
|
26
|
+
def quantize_graph_weights(graph_to_quantize: Graph) -> Graph:
|
|
27
27
|
"""
|
|
28
28
|
Get a graph representing a model, and quantize its nodes' weights.
|
|
29
29
|
Each node is quantized according to the passed framework info and quantization configuration.
|
|
@@ -31,12 +31,13 @@ def quantize_graph_weights(graph: Graph) -> Graph:
|
|
|
31
31
|
is calculated and subtracted from the original node's bias. The graph is quantized in-place.
|
|
32
32
|
|
|
33
33
|
Args:
|
|
34
|
-
|
|
34
|
+
graph_to_quantize: Graph to quantize its nodes.
|
|
35
35
|
|
|
36
36
|
"""
|
|
37
|
+
_quantized_graph = copy.deepcopy(graph_to_quantize)
|
|
37
38
|
# Iterate over nodes in the graph and quantize each node's weights and activations
|
|
38
39
|
# (according to operators groups in framework info).
|
|
39
|
-
for n in
|
|
40
|
+
for n in _quantized_graph.nodes():
|
|
40
41
|
for attr in n.get_node_weights_attributes():
|
|
41
42
|
if n.is_weights_quantization_enabled(attr):
|
|
42
43
|
quantized_attr, io_channels_axes = \
|
|
@@ -51,4 +52,4 @@ def quantize_graph_weights(graph: Graph) -> Graph:
|
|
|
51
52
|
# Set the attribute to be the quantized attribute.
|
|
52
53
|
n.set_weights_by_keys(attr, quantized_attr)
|
|
53
54
|
|
|
54
|
-
return
|
|
55
|
+
return _quantized_graph
|
|
@@ -146,7 +146,10 @@ def compute_mae(float_tensor: np.ndarray,
|
|
|
146
146
|
return error
|
|
147
147
|
|
|
148
148
|
|
|
149
|
-
def compute_cs(float_tensor: np.ndarray,
|
|
149
|
+
def compute_cs(float_tensor: np.ndarray,
|
|
150
|
+
fxp_tensor: np.ndarray,
|
|
151
|
+
eps: float = 1e-8,
|
|
152
|
+
batch: bool = False,
|
|
150
153
|
axis: int = None) -> float:
|
|
151
154
|
"""
|
|
152
155
|
Compute the similarity between two tensor using cosine similarity.
|
|
@@ -12,19 +12,20 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
|
|
15
16
|
from typing import Tuple, List, Callable
|
|
16
17
|
|
|
17
18
|
import numpy as np
|
|
18
19
|
from matplotlib import pyplot as plt
|
|
19
20
|
from matplotlib.figure import Figure
|
|
20
21
|
|
|
21
|
-
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
22
22
|
from model_compression_toolkit.core.common import Graph
|
|
23
23
|
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
|
|
24
24
|
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
|
|
25
25
|
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
26
26
|
from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
|
|
27
27
|
from model_compression_toolkit.core.common.similarity_analyzer import compute_cs
|
|
28
|
+
from model_compression_toolkit.logger import Logger
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
def _get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str]]:
|
|
@@ -57,17 +58,21 @@ class NNVisualizer:
|
|
|
57
58
|
|
|
58
59
|
def __init__(self,
|
|
59
60
|
graph_float: Graph,
|
|
61
|
+
graph_quantized: Graph,
|
|
60
62
|
fw_impl: FrameworkImplementation,
|
|
61
63
|
fw_info: FrameworkInfo):
|
|
62
64
|
"""
|
|
63
65
|
Initialize a NNVisualizer object.
|
|
64
66
|
Args:
|
|
65
67
|
graph_float: Float version of the graph.
|
|
68
|
+
graph_quantized: Quantized version of the graph.
|
|
69
|
+
fw_impl: Framework implementation with framework-specific methods implementation.
|
|
70
|
+
fw_info: Framework info with framework-specific information.
|
|
66
71
|
|
|
67
72
|
"""
|
|
68
73
|
|
|
69
74
|
self.graph_float = graph_float
|
|
70
|
-
self.graph_quantized =
|
|
75
|
+
self.graph_quantized = graph_quantized
|
|
71
76
|
self.fw_impl = fw_impl
|
|
72
77
|
self.fw_info = fw_info
|
|
73
78
|
|
|
@@ -75,6 +80,16 @@ class NNVisualizer:
|
|
|
75
80
|
self.compare_points, self.compare_points_name = _get_compare_points(self.graph_quantized)
|
|
76
81
|
self.compare_points_float, self.compare_points_name_float = _get_compare_points(self.graph_float)
|
|
77
82
|
|
|
83
|
+
if len(self.compare_points) != len(self.compare_points_float):
|
|
84
|
+
Logger.critical(f"Number of compare points in float and quantized models must be equal but "
|
|
85
|
+
f"num of quantized compare points: {len(self.compare_points)} and "
|
|
86
|
+
f"num of float compare points: {len(self.compare_points_float)}")
|
|
87
|
+
if len(self.compare_points_name) != len(self.compare_points_name_float):
|
|
88
|
+
Logger.critical(f"Number of compare points in float and quantized models must be equal "
|
|
89
|
+
f"but num of quantized compare points: {len(self.compare_points_name)}"
|
|
90
|
+
f" and num of float compare points: "
|
|
91
|
+
f"{len(self.compare_points_name_float)}")
|
|
92
|
+
|
|
78
93
|
self.quantized_model, _ = self.fw_impl.model_builder(self.graph_quantized,
|
|
79
94
|
mode=ModelBuilderMode.QUANTIZED,
|
|
80
95
|
append2output=self.compare_points,
|
|
@@ -85,8 +100,19 @@ class NNVisualizer:
|
|
|
85
100
|
append2output=self.compare_points_float,
|
|
86
101
|
fw_info=self.fw_info)
|
|
87
102
|
|
|
103
|
+
def has_compare_points(self) -> bool:
|
|
104
|
+
"""
|
|
105
|
+
|
|
106
|
+
Returns: Whether or not compare points were found.
|
|
107
|
+
|
|
108
|
+
"""
|
|
109
|
+
return len(self.compare_points_float) > 0 and len(self.compare_points) > 0 and len(
|
|
110
|
+
self.compare_points_name_float) > 0 and len(self.compare_points_name) > 0
|
|
111
|
+
|
|
112
|
+
|
|
88
113
|
def plot_distance_graph(self,
|
|
89
114
|
input_image: np.ndarray,
|
|
115
|
+
sample_index: int,
|
|
90
116
|
distance_fn: Callable = compute_cs,
|
|
91
117
|
convert_to_range: Callable = lambda a: a) -> Figure:
|
|
92
118
|
"""
|
|
@@ -95,6 +121,7 @@ class NNVisualizer:
|
|
|
95
121
|
|
|
96
122
|
Args:
|
|
97
123
|
input_image: Image to use as input to the networks.
|
|
124
|
+
sample_index: The index of the sample from input_image to use for comparison.
|
|
98
125
|
distance_fn: Distance function to calculate the distance between two tensors.
|
|
99
126
|
convert_to_range: Optional function to move the distance values into a specific range, e.g., when using
|
|
100
127
|
cosine similarity for distance, use 'lambda a: 1 - 2 * a' to convert the distance values to the range
|
|
@@ -108,7 +135,7 @@ class NNVisualizer:
|
|
|
108
135
|
# to make the difference more noticeable when exists
|
|
109
136
|
new_inputs = []
|
|
110
137
|
for single_input in input_image:
|
|
111
|
-
img = single_input[
|
|
138
|
+
img = single_input[sample_index]
|
|
112
139
|
new_inputs.append(np.expand_dims(img, axis=0))
|
|
113
140
|
|
|
114
141
|
# Get outputs
|
|
@@ -123,7 +150,7 @@ class NNVisualizer:
|
|
|
123
150
|
|
|
124
151
|
# Display the result: distance at every layer's output.
|
|
125
152
|
fig = plt.figure()
|
|
126
|
-
plt.plot(distance_array)
|
|
153
|
+
plt.plot(list(range(len(distance_array))), distance_array)
|
|
127
154
|
eps = 0.5
|
|
128
155
|
y_limits = (min(distance_array) - eps, max(distance_array) + eps)
|
|
129
156
|
plt.ylim(y_limits)
|
|
@@ -337,12 +337,16 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
337
337
|
node = node[0]
|
|
338
338
|
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
|
|
339
339
|
fw_info=self.fw_info)
|
|
340
|
+
# TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
|
|
341
|
+
# To enable GPTQ for other attributes, this code needs to be modified.
|
|
340
342
|
weights, weight_quant_config, activation_quant_config = \
|
|
341
343
|
layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
|
|
342
344
|
for weight_attr, weight in weights.items():
|
|
343
345
|
node.set_weights_by_keys(weight_attr, weight.numpy())
|
|
344
|
-
for
|
|
345
|
-
node.final_weights_quantization_cfg.set_quant_config_attr(
|
|
346
|
+
for config_parameter_name, config_parameter_value in weight_quant_config.items():
|
|
347
|
+
node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
|
|
348
|
+
config_parameter_value,
|
|
349
|
+
attr_name=kernel_attribute)
|
|
346
350
|
for config_attr, config_value in activation_quant_config.items():
|
|
347
351
|
node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
|
|
348
352
|
if self.gptq_config.train_bias:
|
|
@@ -12,10 +12,12 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import copy
|
|
15
16
|
|
|
16
17
|
from typing import Callable, Tuple
|
|
17
18
|
from packaging import version
|
|
18
19
|
|
|
20
|
+
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
19
21
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
20
22
|
from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
|
|
21
23
|
from model_compression_toolkit.logger import Logger
|
|
@@ -210,6 +212,8 @@ if FOUND_TF:
|
|
|
210
212
|
target_resource_utilization=target_resource_utilization,
|
|
211
213
|
tb_w=tb_w)
|
|
212
214
|
|
|
215
|
+
float_graph = copy.deepcopy(tg)
|
|
216
|
+
|
|
213
217
|
tg_gptq = gptq_runner(tg,
|
|
214
218
|
core_config,
|
|
215
219
|
gptq_config,
|
|
@@ -223,7 +227,12 @@ if FOUND_TF:
|
|
|
223
227
|
del hessian_info_service
|
|
224
228
|
|
|
225
229
|
if core_config.debug_config.analyze_similarity:
|
|
226
|
-
analyzer_model_quantization(representative_data_gen,
|
|
230
|
+
analyzer_model_quantization(representative_data_gen,
|
|
231
|
+
tb_w,
|
|
232
|
+
float_graph,
|
|
233
|
+
tg_gptq,
|
|
234
|
+
fw_impl,
|
|
235
|
+
DEFAULT_KERAS_INFO)
|
|
227
236
|
|
|
228
237
|
return get_exportable_keras_model(tg_gptq)
|
|
229
238
|
|
|
@@ -284,12 +284,16 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
284
284
|
node = node[0]
|
|
285
285
|
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
|
|
286
286
|
fw_info=self.fw_info)
|
|
287
|
+
# TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
|
|
288
|
+
# To enable GPTQ for other attributes, this code needs to be modified.
|
|
287
289
|
weights, weight_quant_config, activation_quant_config = \
|
|
288
290
|
layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
|
|
289
291
|
for weight_attr, weight in weights.items():
|
|
290
292
|
node.set_weights_by_keys(weight_attr, self.fw_impl.to_numpy(weight))
|
|
291
|
-
for
|
|
292
|
-
node.final_weights_quantization_cfg.set_quant_config_attr(
|
|
293
|
+
for config_parameter_name, config_parameter_value in weight_quant_config.items():
|
|
294
|
+
node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
|
|
295
|
+
config_parameter_value,
|
|
296
|
+
attr_name=kernel_attribute)
|
|
293
297
|
for config_attr, config_value in activation_quant_config.items():
|
|
294
298
|
node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
|
|
295
299
|
if self.gptq_config.train_bias and hasattr(layer.layer, BIAS):
|
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import copy
|
|
16
|
+
|
|
15
17
|
from typing import Callable
|
|
16
18
|
from model_compression_toolkit.core import common
|
|
17
19
|
from model_compression_toolkit.constants import FOUND_TORCH
|
|
@@ -177,6 +179,7 @@ if FOUND_TORCH:
|
|
|
177
179
|
tpc=target_platform_capabilities,
|
|
178
180
|
target_resource_utilization=target_resource_utilization,
|
|
179
181
|
tb_w=tb_w)
|
|
182
|
+
float_graph = copy.deepcopy(graph)
|
|
180
183
|
|
|
181
184
|
# ---------------------- #
|
|
182
185
|
# GPTQ Runner
|
|
@@ -192,7 +195,12 @@ if FOUND_TORCH:
|
|
|
192
195
|
hessian_info_service=hessian_info_service)
|
|
193
196
|
|
|
194
197
|
if core_config.debug_config.analyze_similarity:
|
|
195
|
-
analyzer_model_quantization(representative_data_gen,
|
|
198
|
+
analyzer_model_quantization(representative_data_gen,
|
|
199
|
+
tb_w,
|
|
200
|
+
float_graph,
|
|
201
|
+
graph_gptq,
|
|
202
|
+
fw_impl,
|
|
203
|
+
DEFAULT_PYTORCH_INFO)
|
|
196
204
|
|
|
197
205
|
return get_exportable_pytorch_model(graph_gptq)
|
|
198
206
|
|
|
@@ -12,11 +12,13 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import copy
|
|
15
16
|
|
|
16
17
|
from typing import Callable
|
|
17
18
|
|
|
18
19
|
from model_compression_toolkit.core import CoreConfig
|
|
19
20
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
21
|
+
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
20
22
|
from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
|
|
21
23
|
from model_compression_toolkit.logger import Logger
|
|
22
24
|
from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
|
|
@@ -122,8 +124,8 @@ if FOUND_TF:
|
|
|
122
124
|
if core_config.mixed_precision_enable:
|
|
123
125
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
124
126
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
|
125
|
-
|
|
126
|
-
|
|
127
|
+
"MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
|
|
128
|
+
"API, or pass a valid mixed precision configuration.") # pragma: no cover
|
|
127
129
|
|
|
128
130
|
tb_w = init_tensorboard_writer(fw_info)
|
|
129
131
|
|
|
@@ -139,15 +141,30 @@ if FOUND_TF:
|
|
|
139
141
|
target_resource_utilization=target_resource_utilization,
|
|
140
142
|
tb_w=tb_w)
|
|
141
143
|
|
|
142
|
-
tg
|
|
144
|
+
# At this point, tg is a graph that went through substitutions (such as BN folding) and is
|
|
145
|
+
# ready for quantization (namely, it holds quantization params, etc.) but the weights are
|
|
146
|
+
# not quantized yet. For this reason, we use it to create a graph that acts as a "float" graph
|
|
147
|
+
# for things like similarity analyzer (because the quantized and float graph should have the same
|
|
148
|
+
# architecture to find the appropriate compare points for similarity computation).
|
|
149
|
+
similarity_baseline_graph = copy.deepcopy(tg)
|
|
150
|
+
|
|
151
|
+
graph_with_stats_correction = ptq_runner(tg,
|
|
152
|
+
representative_data_gen,
|
|
153
|
+
core_config,
|
|
154
|
+
fw_info,
|
|
155
|
+
fw_impl,
|
|
156
|
+
tb_w)
|
|
143
157
|
|
|
144
158
|
if core_config.debug_config.analyze_similarity:
|
|
159
|
+
quantized_graph = quantize_graph_weights(graph_with_stats_correction)
|
|
145
160
|
analyzer_model_quantization(representative_data_gen,
|
|
146
|
-
tb_w,
|
|
161
|
+
tb_w,
|
|
162
|
+
similarity_baseline_graph,
|
|
163
|
+
quantized_graph,
|
|
147
164
|
fw_impl,
|
|
148
165
|
fw_info)
|
|
149
166
|
|
|
150
|
-
return get_exportable_keras_model(
|
|
167
|
+
return get_exportable_keras_model(graph_with_stats_correction)
|
|
151
168
|
|
|
152
169
|
|
|
153
170
|
|
|
@@ -12,6 +12,8 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
import copy
|
|
16
|
+
|
|
15
17
|
from typing import Callable
|
|
16
18
|
|
|
17
19
|
from model_compression_toolkit.core import common
|
|
@@ -26,6 +28,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
|
|
|
26
28
|
from model_compression_toolkit.core.runner import core_runner
|
|
27
29
|
from model_compression_toolkit.ptq.runner import ptq_runner
|
|
28
30
|
from model_compression_toolkit.core.analyzer import analyzer_model_quantization
|
|
31
|
+
from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
|
|
29
32
|
|
|
30
33
|
|
|
31
34
|
if FOUND_TORCH:
|
|
@@ -90,14 +93,16 @@ if FOUND_TORCH:
|
|
|
90
93
|
|
|
91
94
|
"""
|
|
92
95
|
|
|
96
|
+
fw_info = DEFAULT_PYTORCH_INFO
|
|
97
|
+
|
|
93
98
|
if core_config.mixed_precision_enable:
|
|
94
99
|
if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
|
|
95
100
|
Logger.critical("Given quantization config to mixed-precision facade is not of type "
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
101
|
+
"MixedPrecisionQuantizationConfig. Please use "
|
|
102
|
+
"pytorch_post_training_quantization API, or pass a valid mixed precision "
|
|
103
|
+
"configuration.") # pragma: no cover
|
|
99
104
|
|
|
100
|
-
tb_w = init_tensorboard_writer(
|
|
105
|
+
tb_w = init_tensorboard_writer(fw_info)
|
|
101
106
|
|
|
102
107
|
fw_impl = PytorchImplementation()
|
|
103
108
|
|
|
@@ -105,22 +110,36 @@ if FOUND_TORCH:
|
|
|
105
110
|
tg, bit_widths_config, _ = core_runner(in_model=in_module,
|
|
106
111
|
representative_data_gen=representative_data_gen,
|
|
107
112
|
core_config=core_config,
|
|
108
|
-
fw_info=
|
|
113
|
+
fw_info=fw_info,
|
|
109
114
|
fw_impl=fw_impl,
|
|
110
115
|
tpc=target_platform_capabilities,
|
|
111
116
|
target_resource_utilization=target_resource_utilization,
|
|
112
117
|
tb_w=tb_w)
|
|
113
118
|
|
|
114
|
-
tg
|
|
119
|
+
# At this point, tg is a graph that went through substitutions (such as BN folding) and is
|
|
120
|
+
# ready for quantization (namely, it holds quantization params, etc.) but the weights are
|
|
121
|
+
# not quantized yet. For this reason, we use it to create a graph that acts as a "float" graph
|
|
122
|
+
# for things like similarity analyzer (because the quantized and float graph should have the same
|
|
123
|
+
# architecture to find the appropriate compare points for similarity computation).
|
|
124
|
+
similarity_baseline_graph = copy.deepcopy(tg)
|
|
125
|
+
|
|
126
|
+
graph_with_stats_correction = ptq_runner(tg,
|
|
127
|
+
representative_data_gen,
|
|
128
|
+
core_config,
|
|
129
|
+
fw_info,
|
|
130
|
+
fw_impl,
|
|
131
|
+
tb_w)
|
|
115
132
|
|
|
116
133
|
if core_config.debug_config.analyze_similarity:
|
|
134
|
+
quantized_graph = quantize_graph_weights(graph_with_stats_correction)
|
|
117
135
|
analyzer_model_quantization(representative_data_gen,
|
|
118
136
|
tb_w,
|
|
119
|
-
|
|
137
|
+
similarity_baseline_graph,
|
|
138
|
+
quantized_graph,
|
|
120
139
|
fw_impl,
|
|
121
|
-
|
|
140
|
+
fw_info)
|
|
122
141
|
|
|
123
|
-
return get_exportable_pytorch_model(
|
|
142
|
+
return get_exportable_pytorch_model(graph_with_stats_correction)
|
|
124
143
|
|
|
125
144
|
|
|
126
145
|
else:
|
{mct_nightly-2.0.0.20240407.442.dist-info → mct_nightly-2.0.0.20240409.404.dist-info}/LICENSE.md
RENAMED
|
File without changes
|
|
File without changes
|
{mct_nightly-2.0.0.20240407.442.dist-info → mct_nightly-2.0.0.20240409.404.dist-info}/top_level.txt
RENAMED
|
File without changes
|