mct-nightly 2.0.0.20240407.442__py3-none-any.whl → 2.0.0.20240409.404__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.0.0.20240407.442
3
+ Version: 2.0.0.20240409.404
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,9 +1,9 @@
1
- model_compression_toolkit/__init__.py,sha256=D8v4eaxAbnPL5dY8Em333z-gdWTtL6LPQkUT3wBaQdQ,1573
1
+ model_compression_toolkit/__init__.py,sha256=ALvOQYWLrTHNtxDnpNxy7lyftsvgDpzcoW-wTFtMedY,1573
2
2
  model_compression_toolkit/constants.py,sha256=KW_HUEPmQEYqCvWGyORqkYxpvO7w5LViB5J5D-pm_6o,3648
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
5
5
  model_compression_toolkit/core/__init__.py,sha256=TrRgkWpT1AN2Faw1M_1HXyJkJnbxfn9p-RigDZl7pg0,1982
6
- model_compression_toolkit/core/analyzer.py,sha256=dbsD61pakp_9JXNyAScLdtJvcXny9jr_cMbET0Bd3Sg,2975
6
+ model_compression_toolkit/core/analyzer.py,sha256=X-2ZpkH1xdXnISnw1yJvXnvV-ssoUh-9LkLISSWNqiY,3691
7
7
  model_compression_toolkit/core/graph_prep_runner.py,sha256=Ftqm59hT5TGWmSNkY9bFZkVfCacpGyZfCe-6yZR5WY0,10100
8
8
  model_compression_toolkit/core/quantization_prep_runner.py,sha256=hFhDkS8GwzXZ7Ho_9qbbb8DAAWs3OONOfMSD5OU_b0o,6153
9
9
  model_compression_toolkit/core/runner.py,sha256=NKSC6ujfQPy6dKtJVwxyK2zNDd64eyR5csYy9lBrCPA,11836
@@ -16,7 +16,7 @@ model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3u
16
16
  model_compression_toolkit/core/common/model_collector.py,sha256=ofcepKtxc3j2Ouz6BpAKXTzPgjABnpRP47ndmJCXAkk,8352
17
17
  model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
18
18
  model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
19
- model_compression_toolkit/core/common/similarity_analyzer.py,sha256=2w-q7guEb5bpLY4Vk_TMjR8TzLYEymR3tPFlrVq7K68,8515
19
+ model_compression_toolkit/core/common/similarity_analyzer.py,sha256=98l9ttnXHf6VYxBW4852h2CPJKg3A6nLOovpHn-tnKs,8560
20
20
  model_compression_toolkit/core/common/user_info.py,sha256=dSRMnT-oewmdOziIpEuW-s9K7vTSeyUBxT4z9neXurI,1648
21
21
  model_compression_toolkit/core/common/back2framework/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
22
22
  model_compression_toolkit/core/common/back2framework/base_model_builder.py,sha256=V1oShKzbSkdcTvREn8VnQQBzvm-tTHkWMXqMkYozF2s,2023
@@ -100,11 +100,11 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
100
100
  model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
101
101
  model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
102
102
  model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
103
- model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=lXONxIOSvYMgkN9M1st4tV1V5JSpijUGxF0hZWRvtUI,26737
103
+ model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=h_cgpvT50gYgO8T363-Zw_b2jfqo3uoa7TqnSuig7I4,26947
104
104
  model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=hQMKm55EXS1oV-Upt6IQtsYhpuhMvYeWRJhh6lhv_Ko,6699
105
105
  model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=T1nVWdRJfBQ_iuMQYQSIkjfkR-2n3lAOKGAz_rUZZN0,2190
106
106
  model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
107
- model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=xnM9O9LshYw3dprqfsnK9mw7ipOEAkI85o20auyfswg,2626
107
+ model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
108
108
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
109
109
  model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=9BEv2l0z2trDEsr40VB8tO3ToBA_b2sd_jH9uqZ5Wo8,11503
110
110
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
@@ -142,7 +142,7 @@ model_compression_toolkit/core/common/substitutions/virtual_activation_weights_c
142
142
  model_compression_toolkit/core/common/substitutions/weights_activation_split.py,sha256=h85L2VlDOqbLd-N98wA3SdYWiblBgSsPceNuLanJd70,4737
143
143
  model_compression_toolkit/core/common/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
144
144
  model_compression_toolkit/core/common/visualization/final_config_visualizer.py,sha256=6I10jKLesB-RQKaXA75Xgz2wPvylQUrnPtCcQZIynGo,6371
145
- model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=jjsOng-fLxeQFQNshIsOu_w1d5a3fJ359Hcnt85Te-o,5921
145
+ model_compression_toolkit/core/common/visualization/nn_visualizer.py,sha256=HOq7AObkmEZiDSZXUMJDAEJzUY-fSXUT0AMgwiyH7dg,7388
146
146
  model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256=4E4ZXZmqusGIJ4XQNH8FFt07htAHgT3gy5E7wPIaVBI,21951
147
147
  model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
148
148
  model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
@@ -336,9 +336,9 @@ model_compression_toolkit/gptq/common/gptq_training.py,sha256=rLA1xlOO-6gWfmc2dL
336
336
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
337
337
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
338
338
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
339
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=6TseqtBzZkLqyc3hiRVdA1dv01us6Y_Su05CBboGjjc,18438
339
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=OhYfH6zxRHrRhCde0lbcV9Hu2oeDD9RXh-O8vOPgLbs,18875
340
340
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=5IvgGlJlgOmQYmldjdCBv7tuzAoY0HazatG5Pedrg0Q,4639
341
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=yRgiikyeELar4jlsdcf5pO9HQcxiyhKiAXY3lsMixew,13913
341
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=zAkzWpWP9_aobWgMo_BlUm7-4fR5dHvoGx0sDqs2rZg,14299
342
342
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
343
343
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
344
344
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -353,9 +353,9 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
353
353
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
354
354
  model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
355
355
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
356
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=ksFXtepAsk-66xk7OPciG05kU9sgAUrWqjOgplsGSnw,15808
356
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=LN4vOwcMuSSFTSnHDACV9hX_Yd2YIXJRl7WkdODuA0k,16245
357
357
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=yXJzDd24zfGs2_vfMovxD1WSh1RxXoPxN4GztOf3P5c,3967
358
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=bImrTw9rrAVc3VD4nmrXmBo_K4fuf5m5XPPf8ybOThs,12430
358
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=-4USg-tep6EQSArcTxBowhMeAuExrBTNLOWgHFpsIy4,12699
359
359
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
360
360
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
361
361
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -375,9 +375,9 @@ model_compression_toolkit/pruning/pytorch/pruning_facade.py,sha256=cSuvHHCqgr7k9
375
375
  model_compression_toolkit/ptq/__init__.py,sha256=Z_hkmTh7aLFei1DJKV0oNVUbrv_Q_0CTw-qD85Xf8UM,904
376
376
  model_compression_toolkit/ptq/runner.py,sha256=_c1dSjlPPpsx59Vbg1buhG9bZq__OORz1VlPkwjJzoc,2552
377
377
  model_compression_toolkit/ptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
378
- model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=5Doa5Rwer84DRgJxLa2e6aX9B4yGYdmFGgiv71_wD9o,8992
378
+ model_compression_toolkit/ptq/keras/quantization_facade.py,sha256=T1_UqXmOc4I2a6IHkQAlFhGtcAYjsXSApMIdRlvgDvg,10154
379
379
  model_compression_toolkit/ptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
380
- model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=QW7lMRuuonupLPZ2w2PDIQd7qpDZ_euLInhskTc1Yes,7518
380
+ model_compression_toolkit/ptq/pytorch/quantization_facade.py,sha256=eof9bo-Mv_lLY7fFpiVeT5pIde-MuTWkIAqRKH4j9MI,8646
381
381
  model_compression_toolkit/qat/__init__.py,sha256=kj2qsZh_Ca7PncsHKcaL5EVT2H8g4hYtvaQ3KFxOkwE,1143
382
382
  model_compression_toolkit/qat/common/__init__.py,sha256=6tLZ4R4pYP6QVztLVQC_jik2nES3l4uhML0qUxZrezk,829
383
383
  model_compression_toolkit/qat/common/qat_config.py,sha256=zoq0Vb74vCY7WlWD8JH_KPrHDoUHSvMc3gcO53u7L2U,3394
@@ -469,8 +469,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
469
469
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
470
470
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
471
471
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
472
- mct_nightly-2.0.0.20240407.442.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
473
- mct_nightly-2.0.0.20240407.442.dist-info/METADATA,sha256=VYYUCAhW6RmEexLHQghz4uGkGFB0JlYg9R4CkDCiEs0,18795
474
- mct_nightly-2.0.0.20240407.442.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
475
- mct_nightly-2.0.0.20240407.442.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
476
- mct_nightly-2.0.0.20240407.442.dist-info/RECORD,,
472
+ mct_nightly-2.0.0.20240409.404.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
473
+ mct_nightly-2.0.0.20240409.404.dist-info/METADATA,sha256=uDkh4Eu7g8uMdBVYp8H_rPGwkuhe_aWWK86DgPSBj94,18795
474
+ mct_nightly-2.0.0.20240409.404.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
475
+ mct_nightly-2.0.0.20240409.404.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
476
+ mct_nightly-2.0.0.20240409.404.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.0.0.20240407.000442"
30
+ __version__ = "2.0.0.20240409.000404"
@@ -30,7 +30,8 @@ from model_compression_toolkit.logger import Logger
30
30
 
31
31
  def analyzer_model_quantization(representative_data_gen: Callable,
32
32
  tb_w: TensorboardWriter,
33
- tg: Graph,
33
+ float_graph: Graph,
34
+ quantized_graph: Graph,
34
35
  fw_impl: FrameworkImplementation,
35
36
  fw_info: FrameworkInfo):
36
37
  """
@@ -41,23 +42,32 @@ def analyzer_model_quantization(representative_data_gen: Callable,
41
42
  Args:
42
43
  representative_data_gen: Dataset used for calibration.
43
44
  tb_w: TensorBoardWriter object to log events.
44
- tg: Graph of quantized model.
45
+ float_graph: Graph of float model.
46
+ quantized_graph: Graph of quantized model.
45
47
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
46
48
  fw_info: Information needed for quantization about the specific framework.
47
49
 
48
50
  """
49
51
  if tb_w is not None:
50
- visual = NNVisualizer(tg,
52
+ visual = NNVisualizer(float_graph,
53
+ quantized_graph,
51
54
  fw_impl=fw_impl,
52
55
  fw_info=fw_info)
53
-
54
- for i, _data in enumerate(representative_data_gen()):
55
- if i >= NUM_SAMPLES_DISTANCE_TENSORBOARD:
56
- break
57
- figure = visual.plot_distance_graph(_data,
58
- distance_fn=compute_cs,
59
- convert_to_range=lambda a: 1 - 2 * a)
60
- tb_w.add_figure(figure, f'similarity_distance_sample_{i}')
56
+ if not visual.has_compare_points():
57
+ Logger.error(f'No comparing points were found to plot analyze similarity.')
61
58
  else:
62
- Logger.warning(f'Not enough batches in representative dataset to generate {NUM_SAMPLES_DISTANCE_TENSORBOARD} figures')
59
+ visualized_samples = 0
60
+ for _data in representative_data_gen():
61
+ batch_size = _data[0].shape[0]
62
+ for sample_index in range(batch_size):
63
+ if visualized_samples >= NUM_SAMPLES_DISTANCE_TENSORBOARD:
64
+ break
65
+ figure = visual.plot_distance_graph(_data,
66
+ sample_index=sample_index,
67
+ distance_fn=compute_cs,
68
+ convert_to_range=lambda a: 1 - 2 * a)
69
+ tb_w.add_figure(figure, f'similarity_distance_sample_{visualized_samples}')
70
+ visualized_samples += 1
71
+ if visualized_samples < NUM_SAMPLES_DISTANCE_TENSORBOARD:
72
+ Logger.error(f'Not enough batches in representative dataset to generate {NUM_SAMPLES_DISTANCE_TENSORBOARD} figures')
63
73
  tb_w.close()
@@ -41,24 +41,24 @@ class BaseNodeQuantizationConfig(object):
41
41
  Base class for node quantization configuration
42
42
  """
43
43
 
44
- def set_quant_config_attr(self, parameter_name: str, parameter_value: Any,
44
+ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any,
45
45
  *args: List[Any], **kwargs: Dict[str, Any]):
46
46
  """
47
47
  Changes a BaseNodeQuantizationConfig's parameter.
48
48
  Note that arg and kwargs are only to allow clean override in the child classes.
49
49
 
50
50
  Args:
51
- parameter_name: parameter name to change.
52
- parameter_value: parameter value to change.
51
+ config_parameter_name: parameter name to change.
52
+ config_parameter_value: parameter value to change.
53
53
  args: A list of additional arguments.
54
54
  kwargs: A dictionary with additional key arguments.
55
55
 
56
56
  """
57
57
 
58
- if hasattr(self, parameter_name):
59
- setattr(self, parameter_name, parameter_value)
58
+ if hasattr(self, config_parameter_name):
59
+ setattr(self, config_parameter_name, config_parameter_value)
60
60
  else:
61
- Logger.warning(f"Parameter {parameter_name} could not be found in the node quantization config and "
61
+ Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config and "
62
62
  f"was not updated!")
63
63
 
64
64
  def __repr__(self) -> str:
@@ -521,7 +521,7 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
521
521
  f"{list(attrs_with_name.keys())}.")
522
522
  return attrs_with_name
523
523
 
524
- def set_quant_config_attr(self, parameter_name: str, parameter_value: Any, attr_name: str = None,
524
+ def set_quant_config_attr(self, config_parameter_name: str, config_parameter_value: Any, attr_name: str = None,
525
525
  *args: List[Any], **kwargs: Dict[str, Any]):
526
526
  """
527
527
  This method overrides the parent class set_quant_config_attr to enable setting a specific weights
@@ -529,26 +529,27 @@ class NodeWeightsQuantizationConfig(BaseNodeQuantizationConfig):
529
529
 
530
530
  Args:
531
531
  attr_name: attribute name to change.
532
- parameter_name: parameter name to change.
533
- parameter_value: parameter value to change.
532
+ config_parameter_name: parameter name to change.
533
+ config_parameter_value: parameter value to change.
534
534
  args: A list of additional arguments.
535
535
  kwargs: A dictionary with additional key arguments.
536
536
 
537
537
  """
538
538
 
539
539
  if attr_name is None:
540
- super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(parameter_name, parameter_value,
540
+ super(NodeWeightsQuantizationConfig, self).set_quant_config_attr(config_parameter_name,
541
+ config_parameter_value,
541
542
  *args, **kwargs)
542
543
  else:
543
544
  if self.has_attribute_config(attr_name):
544
545
  attr_cfg = self.get_attr_config(attr_name)
545
- if hasattr(attr_cfg, parameter_name):
546
- setattr(attr_cfg, parameter_name, parameter_value)
546
+ if hasattr(attr_cfg, config_parameter_name):
547
+ setattr(attr_cfg, config_parameter_name, config_parameter_value)
547
548
  else:
548
- Logger.warning(f"Parameter {parameter_name} could not be found in the node quantization config of "
549
+ Logger.warning(f"Parameter {config_parameter_name} could not be found in the node quantization config of "
549
550
  f"weights attribute {attr_name} and was not updated!")
550
551
  else:
551
- Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {parameter_name}.")
552
+ Logger.error(f"Weights attribute {attr_name} could not be found to set parameter {config_parameter_name}.")
552
553
 
553
554
  def __eq__(self, other: Any) -> bool:
554
555
  """
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.quantization.quantize_node import get
23
23
  from model_compression_toolkit.logger import Logger
24
24
 
25
25
 
26
- def quantize_graph_weights(graph: Graph) -> Graph:
26
+ def quantize_graph_weights(graph_to_quantize: Graph) -> Graph:
27
27
  """
28
28
  Get a graph representing a model, and quantize its nodes' weights.
29
29
  Each node is quantized according to the passed framework info and quantization configuration.
@@ -31,12 +31,13 @@ def quantize_graph_weights(graph: Graph) -> Graph:
31
31
  is calculated and subtracted from the original node's bias. The graph is quantized in-place.
32
32
 
33
33
  Args:
34
- graph: Graph to quantize its nodes.
34
+ graph_to_quantize: Graph to quantize its nodes.
35
35
 
36
36
  """
37
+ _quantized_graph = copy.deepcopy(graph_to_quantize)
37
38
  # Iterate over nodes in the graph and quantize each node's weights and activations
38
39
  # (according to operators groups in framework info).
39
- for n in graph.nodes():
40
+ for n in _quantized_graph.nodes():
40
41
  for attr in n.get_node_weights_attributes():
41
42
  if n.is_weights_quantization_enabled(attr):
42
43
  quantized_attr, io_channels_axes = \
@@ -51,4 +52,4 @@ def quantize_graph_weights(graph: Graph) -> Graph:
51
52
  # Set the attribute to be the quantized attribute.
52
53
  n.set_weights_by_keys(attr, quantized_attr)
53
54
 
54
- return graph
55
+ return _quantized_graph
@@ -146,7 +146,10 @@ def compute_mae(float_tensor: np.ndarray,
146
146
  return error
147
147
 
148
148
 
149
- def compute_cs(float_tensor: np.ndarray, fxp_tensor: np.ndarray, eps: float = 1e-8, batch: bool = False,
149
+ def compute_cs(float_tensor: np.ndarray,
150
+ fxp_tensor: np.ndarray,
151
+ eps: float = 1e-8,
152
+ batch: bool = False,
150
153
  axis: int = None) -> float:
151
154
  """
152
155
  Compute the similarity between two tensor using cosine similarity.
@@ -12,19 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+
15
16
  from typing import Tuple, List, Callable
16
17
 
17
18
  import numpy as np
18
19
  from matplotlib import pyplot as plt
19
20
  from matplotlib.figure import Figure
20
21
 
21
- from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
22
22
  from model_compression_toolkit.core.common import Graph
23
23
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
24
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
25
25
  from model_compression_toolkit.core.common.graph.base_node import BaseNode
26
26
  from model_compression_toolkit.core.common.model_builder_mode import ModelBuilderMode
27
27
  from model_compression_toolkit.core.common.similarity_analyzer import compute_cs
28
+ from model_compression_toolkit.logger import Logger
28
29
 
29
30
 
30
31
  def _get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str]]:
@@ -57,17 +58,21 @@ class NNVisualizer:
57
58
 
58
59
  def __init__(self,
59
60
  graph_float: Graph,
61
+ graph_quantized: Graph,
60
62
  fw_impl: FrameworkImplementation,
61
63
  fw_info: FrameworkInfo):
62
64
  """
63
65
  Initialize a NNVisualizer object.
64
66
  Args:
65
67
  graph_float: Float version of the graph.
68
+ graph_quantized: Quantized version of the graph.
69
+ fw_impl: Framework implementation with framework-specific methods implementation.
70
+ fw_info: Framework info with framework-specific information.
66
71
 
67
72
  """
68
73
 
69
74
  self.graph_float = graph_float
70
- self.graph_quantized = quantize_graph_weights(graph_float)
75
+ self.graph_quantized = graph_quantized
71
76
  self.fw_impl = fw_impl
72
77
  self.fw_info = fw_info
73
78
 
@@ -75,6 +80,16 @@ class NNVisualizer:
75
80
  self.compare_points, self.compare_points_name = _get_compare_points(self.graph_quantized)
76
81
  self.compare_points_float, self.compare_points_name_float = _get_compare_points(self.graph_float)
77
82
 
83
+ if len(self.compare_points) != len(self.compare_points_float):
84
+ Logger.critical(f"Number of compare points in float and quantized models must be equal but "
85
+ f"num of quantized compare points: {len(self.compare_points)} and "
86
+ f"num of float compare points: {len(self.compare_points_float)}")
87
+ if len(self.compare_points_name) != len(self.compare_points_name_float):
88
+ Logger.critical(f"Number of compare points in float and quantized models must be equal "
89
+ f"but num of quantized compare points: {len(self.compare_points_name)}"
90
+ f" and num of float compare points: "
91
+ f"{len(self.compare_points_name_float)}")
92
+
78
93
  self.quantized_model, _ = self.fw_impl.model_builder(self.graph_quantized,
79
94
  mode=ModelBuilderMode.QUANTIZED,
80
95
  append2output=self.compare_points,
@@ -85,8 +100,19 @@ class NNVisualizer:
85
100
  append2output=self.compare_points_float,
86
101
  fw_info=self.fw_info)
87
102
 
103
+ def has_compare_points(self) -> bool:
104
+ """
105
+
106
+ Returns: Whether or not compare points were found.
107
+
108
+ """
109
+ return len(self.compare_points_float) > 0 and len(self.compare_points) > 0 and len(
110
+ self.compare_points_name_float) > 0 and len(self.compare_points_name) > 0
111
+
112
+
88
113
  def plot_distance_graph(self,
89
114
  input_image: np.ndarray,
115
+ sample_index: int,
90
116
  distance_fn: Callable = compute_cs,
91
117
  convert_to_range: Callable = lambda a: a) -> Figure:
92
118
  """
@@ -95,6 +121,7 @@ class NNVisualizer:
95
121
 
96
122
  Args:
97
123
  input_image: Image to use as input to the networks.
124
+ sample_index: The index of the sample from input_image to use for comparison.
98
125
  distance_fn: Distance function to calculate the distance between two tensors.
99
126
  convert_to_range: Optional function to move the distance values into a specific range, e.g., when using
100
127
  cosine similarity for distance, use 'lambda a: 1 - 2 * a' to convert the distance values to the range
@@ -108,7 +135,7 @@ class NNVisualizer:
108
135
  # to make the difference more noticeable when exists
109
136
  new_inputs = []
110
137
  for single_input in input_image:
111
- img = single_input[0]
138
+ img = single_input[sample_index]
112
139
  new_inputs.append(np.expand_dims(img, axis=0))
113
140
 
114
141
  # Get outputs
@@ -123,7 +150,7 @@ class NNVisualizer:
123
150
 
124
151
  # Display the result: distance at every layer's output.
125
152
  fig = plt.figure()
126
- plt.plot(distance_array)
153
+ plt.plot(list(range(len(distance_array))), distance_array)
127
154
  eps = 0.5
128
155
  y_limits = (min(distance_array) - eps, max(distance_array) + eps)
129
156
  plt.ylim(y_limits)
@@ -337,12 +337,16 @@ class KerasGPTQTrainer(GPTQTrainer):
337
337
  node = node[0]
338
338
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
339
339
  fw_info=self.fw_info)
340
+ # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
341
+ # To enable GPTQ for other attributes, this code needs to be modified.
340
342
  weights, weight_quant_config, activation_quant_config = \
341
343
  layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
342
344
  for weight_attr, weight in weights.items():
343
345
  node.set_weights_by_keys(weight_attr, weight.numpy())
344
- for config_attr, config_value in weight_quant_config.items():
345
- node.final_weights_quantization_cfg.set_quant_config_attr(config_attr, config_value)
346
+ for config_parameter_name, config_parameter_value in weight_quant_config.items():
347
+ node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
348
+ config_parameter_value,
349
+ attr_name=kernel_attribute)
346
350
  for config_attr, config_value in activation_quant_config.items():
347
351
  node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
348
352
  if self.gptq_config.train_bias:
@@ -12,10 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
15
16
 
16
17
  from typing import Callable, Tuple
17
18
  from packaging import version
18
19
 
20
+ from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
19
21
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
20
22
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
21
23
  from model_compression_toolkit.logger import Logger
@@ -210,6 +212,8 @@ if FOUND_TF:
210
212
  target_resource_utilization=target_resource_utilization,
211
213
  tb_w=tb_w)
212
214
 
215
+ float_graph = copy.deepcopy(tg)
216
+
213
217
  tg_gptq = gptq_runner(tg,
214
218
  core_config,
215
219
  gptq_config,
@@ -223,7 +227,12 @@ if FOUND_TF:
223
227
  del hessian_info_service
224
228
 
225
229
  if core_config.debug_config.analyze_similarity:
226
- analyzer_model_quantization(representative_data_gen, tb_w, tg_gptq, fw_impl, fw_info)
230
+ analyzer_model_quantization(representative_data_gen,
231
+ tb_w,
232
+ float_graph,
233
+ tg_gptq,
234
+ fw_impl,
235
+ DEFAULT_KERAS_INFO)
227
236
 
228
237
  return get_exportable_keras_model(tg_gptq)
229
238
 
@@ -284,12 +284,16 @@ class PytorchGPTQTrainer(GPTQTrainer):
284
284
  node = node[0]
285
285
  kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
286
286
  fw_info=self.fw_info)
287
+ # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
288
+ # To enable GPTQ for other attributes, this code needs to be modified.
287
289
  weights, weight_quant_config, activation_quant_config = \
288
290
  layer.weights_quantizers[kernel_attribute].update_layer_quantization_params(layer)
289
291
  for weight_attr, weight in weights.items():
290
292
  node.set_weights_by_keys(weight_attr, self.fw_impl.to_numpy(weight))
291
- for config_attr, config_value in weight_quant_config.items():
292
- node.final_weights_quantization_cfg.set_quant_config_attr(config_attr, config_value)
293
+ for config_parameter_name, config_parameter_value in weight_quant_config.items():
294
+ node.final_weights_quantization_cfg.set_quant_config_attr(config_parameter_name,
295
+ config_parameter_value,
296
+ attr_name=kernel_attribute)
293
297
  for config_attr, config_value in activation_quant_config.items():
294
298
  node.final_activation_quantization_cfg.set_quant_config_attr(config_attr, config_value)
295
299
  if self.gptq_config.train_bias and hasattr(layer.layer, BIAS):
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
16
+
15
17
  from typing import Callable
16
18
  from model_compression_toolkit.core import common
17
19
  from model_compression_toolkit.constants import FOUND_TORCH
@@ -177,6 +179,7 @@ if FOUND_TORCH:
177
179
  tpc=target_platform_capabilities,
178
180
  target_resource_utilization=target_resource_utilization,
179
181
  tb_w=tb_w)
182
+ float_graph = copy.deepcopy(graph)
180
183
 
181
184
  # ---------------------- #
182
185
  # GPTQ Runner
@@ -192,7 +195,12 @@ if FOUND_TORCH:
192
195
  hessian_info_service=hessian_info_service)
193
196
 
194
197
  if core_config.debug_config.analyze_similarity:
195
- analyzer_model_quantization(representative_data_gen, tb_w, graph_gptq, fw_impl, DEFAULT_PYTORCH_INFO)
198
+ analyzer_model_quantization(representative_data_gen,
199
+ tb_w,
200
+ float_graph,
201
+ graph_gptq,
202
+ fw_impl,
203
+ DEFAULT_PYTORCH_INFO)
196
204
 
197
205
  return get_exportable_pytorch_model(graph_gptq)
198
206
 
@@ -12,11 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
15
16
 
16
17
  from typing import Callable
17
18
 
18
19
  from model_compression_toolkit.core import CoreConfig
19
20
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
21
+ from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
20
22
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
21
23
  from model_compression_toolkit.logger import Logger
22
24
  from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
@@ -122,8 +124,8 @@ if FOUND_TF:
122
124
  if core_config.mixed_precision_enable:
123
125
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
124
126
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
125
- "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
126
- "API, or pass a valid mixed precision configuration.") # pragma: no cover
127
+ "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
128
+ "API, or pass a valid mixed precision configuration.") # pragma: no cover
127
129
 
128
130
  tb_w = init_tensorboard_writer(fw_info)
129
131
 
@@ -139,15 +141,30 @@ if FOUND_TF:
139
141
  target_resource_utilization=target_resource_utilization,
140
142
  tb_w=tb_w)
141
143
 
142
- tg = ptq_runner(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
144
+ # At this point, tg is a graph that went through substitutions (such as BN folding) and is
145
+ # ready for quantization (namely, it holds quantization params, etc.) but the weights are
146
+ # not quantized yet. For this reason, we use it to create a graph that acts as a "float" graph
147
+ # for things like similarity analyzer (because the quantized and float graph should have the same
148
+ # architecture to find the appropriate compare points for similarity computation).
149
+ similarity_baseline_graph = copy.deepcopy(tg)
150
+
151
+ graph_with_stats_correction = ptq_runner(tg,
152
+ representative_data_gen,
153
+ core_config,
154
+ fw_info,
155
+ fw_impl,
156
+ tb_w)
143
157
 
144
158
  if core_config.debug_config.analyze_similarity:
159
+ quantized_graph = quantize_graph_weights(graph_with_stats_correction)
145
160
  analyzer_model_quantization(representative_data_gen,
146
- tb_w, tg,
161
+ tb_w,
162
+ similarity_baseline_graph,
163
+ quantized_graph,
147
164
  fw_impl,
148
165
  fw_info)
149
166
 
150
- return get_exportable_keras_model(tg)
167
+ return get_exportable_keras_model(graph_with_stats_correction)
151
168
 
152
169
 
153
170
 
@@ -12,6 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
16
+
15
17
  from typing import Callable
16
18
 
17
19
  from model_compression_toolkit.core import common
@@ -26,6 +28,7 @@ from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quant
26
28
  from model_compression_toolkit.core.runner import core_runner
27
29
  from model_compression_toolkit.ptq.runner import ptq_runner
28
30
  from model_compression_toolkit.core.analyzer import analyzer_model_quantization
31
+ from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
29
32
 
30
33
 
31
34
  if FOUND_TORCH:
@@ -90,14 +93,16 @@ if FOUND_TORCH:
90
93
 
91
94
  """
92
95
 
96
+ fw_info = DEFAULT_PYTORCH_INFO
97
+
93
98
  if core_config.mixed_precision_enable:
94
99
  if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
95
100
  Logger.critical("Given quantization config to mixed-precision facade is not of type "
96
- "MixedPrecisionQuantizationConfig. Please use "
97
- "pytorch_post_training_quantization API, or pass a valid mixed precision "
98
- "configuration.") # pragma: no cover
101
+ "MixedPrecisionQuantizationConfig. Please use "
102
+ "pytorch_post_training_quantization API, or pass a valid mixed precision "
103
+ "configuration.") # pragma: no cover
99
104
 
100
- tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
105
+ tb_w = init_tensorboard_writer(fw_info)
101
106
 
102
107
  fw_impl = PytorchImplementation()
103
108
 
@@ -105,22 +110,36 @@ if FOUND_TORCH:
105
110
  tg, bit_widths_config, _ = core_runner(in_model=in_module,
106
111
  representative_data_gen=representative_data_gen,
107
112
  core_config=core_config,
108
- fw_info=DEFAULT_PYTORCH_INFO,
113
+ fw_info=fw_info,
109
114
  fw_impl=fw_impl,
110
115
  tpc=target_platform_capabilities,
111
116
  target_resource_utilization=target_resource_utilization,
112
117
  tb_w=tb_w)
113
118
 
114
- tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
119
+ # At this point, tg is a graph that went through substitutions (such as BN folding) and is
120
+ # ready for quantization (namely, it holds quantization params, etc.) but the weights are
121
+ # not quantized yet. For this reason, we use it to create a graph that acts as a "float" graph
122
+ # for things like similarity analyzer (because the quantized and float graph should have the same
123
+ # architecture to find the appropriate compare points for similarity computation).
124
+ similarity_baseline_graph = copy.deepcopy(tg)
125
+
126
+ graph_with_stats_correction = ptq_runner(tg,
127
+ representative_data_gen,
128
+ core_config,
129
+ fw_info,
130
+ fw_impl,
131
+ tb_w)
115
132
 
116
133
  if core_config.debug_config.analyze_similarity:
134
+ quantized_graph = quantize_graph_weights(graph_with_stats_correction)
117
135
  analyzer_model_quantization(representative_data_gen,
118
136
  tb_w,
119
- tg,
137
+ similarity_baseline_graph,
138
+ quantized_graph,
120
139
  fw_impl,
121
- DEFAULT_PYTORCH_INFO)
140
+ fw_info)
122
141
 
123
- return get_exportable_pytorch_model(tg)
142
+ return get_exportable_pytorch_model(graph_with_stats_correction)
124
143
 
125
144
 
126
145
  else: