mct-nightly 2.4.0.20250617.613-py3-none-any.whl → 2.4.0.20250619.621-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
  92. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
  93. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
  94. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  95. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  96. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  97. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  98. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  100. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  101. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  102. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  103. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  104. model_compression_toolkit/gptq/runner.py +1 -7
  105. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  106. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  107. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  108. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  109. model_compression_toolkit/ptq/runner.py +1 -4
  110. model_compression_toolkit/qat/common/qat_config.py +2 -6
  111. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  112. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  113. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  114. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  115. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  116. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  117. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  118. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  119. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  120. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  121. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
  122. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
  123. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0
@@ -35,11 +35,12 @@ if FOUND_TF:
         AttachTpcToKeras
     from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
     from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
-    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
     from tensorflow.keras.models import Model

     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)

+    @set_keras_info
     def keras_pruning_experimental(model: Model,
                                    target_resource_utilization: ResourceUtilization,
                                    representative_data_gen: Callable,
@@ -123,7 +124,6 @@ if FOUND_TF:
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
                                           target_platform_capabilities,
-                                          DEFAULT_KERAS_INFO,
                                           fw_impl)

         # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
@@ -134,7 +134,6 @@ if FOUND_TF:

         # Create a Pruner object with the graph and configuration.
         pruner = Pruner(float_graph_with_compression_config,
-                        DEFAULT_KERAS_INFO,
                         fw_impl,
                         target_resource_utilization,
                         representative_data_gen,
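The hunks above replace the DEFAULT_KERAS_INFO object that used to be threaded through internal calls with a @set_keras_info decorator applied to the public facade. The decorator's real implementation lives in model_compression_toolkit/core/keras/default_framework_info.py, which this diff changes but does not display, so the sketch below only illustrates the general pattern of registering a framework-info singleton before the wrapped API runs; every name in it (set_fw_info_sketch, _ACTIVE_FW_INFO, KerasInfoPlaceholder) is a hypothetical stand-in, not MCT's actual API.

import functools

_ACTIVE_FW_INFO = None  # hypothetical process-wide slot for the framework info


class KerasInfoPlaceholder:
    """Stand-in for a Keras FrameworkInfo class; not MCT's real class."""


def set_fw_info_sketch(func):
    """Install the framework info before running the decorated facade."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global _ACTIVE_FW_INFO
        _ACTIVE_FW_INFO = KerasInfoPlaceholder  # registered once, read internally
        return func(*args, **kwargs)
    return wrapper


@set_fw_info_sketch
def keras_pruning_experimental_sketch(model, target_resource_utilization, representative_data_gen):
    # Internal code can consult _ACTIVE_FW_INFO instead of receiving fw_info as
    # an explicit argument, which is what the removed parameters above reflect.
    return model

Whatever the real mechanism is, the visible effect in this diff is the same: public signatures stay unchanged while the fw_info parameter disappears from internal call chains.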
@@ -36,7 +36,7 @@ if FOUND_TORCH:
     from model_compression_toolkit.core.pytorch.back2framework.float_model_builder import FloatPyTorchModelBuilder
     from model_compression_toolkit.core.pytorch.pruning.pruning_pytorch_implementation import \
         PruningPytorchImplementation
-    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from torch.nn import Module
     from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2pytorch import \
         AttachTpcToPytorch
@@ -44,6 +44,7 @@ if FOUND_TORCH:
     # Set the default Target Platform Capabilities (TPC) for PyTorch.
     DEFAULT_PYOTRCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

+    @set_pytorch_info
     def pytorch_pruning_experimental(model: Module,
                                      target_resource_utilization: ResourceUtilization,
                                      representative_data_gen: Callable,
@@ -129,7 +130,6 @@ if FOUND_TORCH:
         float_graph = read_model_to_graph(model,
                                           representative_data_gen,
                                           framework_platform_capabilities,
-                                          DEFAULT_PYTORCH_INFO,
                                           fw_impl)

         # Apply quantization configuration to the graph. This step is necessary even when not quantizing,
@@ -140,7 +140,6 @@ if FOUND_TORCH:

         # Create a Pruner object with the graph and configuration.
         pruner = Pruner(float_graph_with_compression_config,
-                        DEFAULT_PYTORCH_INFO,
                         fw_impl,
                         target_resource_utilization,
                         representative_data_gen,
@@ -36,7 +36,7 @@ from model_compression_toolkit.metadata import create_model_metadata
 if FOUND_TF:
     from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
         AttachTpcToKeras
-    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
@@ -49,6 +49,7 @@ if FOUND_TF:
     DEFAULT_KERAS_TPC = get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL)


+    @set_keras_info
     def keras_post_training_quantization(in_model: Model,
                                          representative_data_gen: Callable,
                                          target_resource_utilization: ResourceUtilization = None,
@@ -128,10 +129,7 @@ if FOUND_TF:
         if core_config.debug_config.bypass:
             return in_model, None

-        fw_info = DEFAULT_KERAS_INFO
-
-        KerasModelValidation(model=in_model,
-                             fw_info=fw_info).validate()
+        KerasModelValidation(model=in_model).validate()

         if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
@@ -139,7 +137,7 @@ if FOUND_TF:
                     "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization "
                     "API, or pass a valid mixed precision configuration.")  # pragma: no cover

-        tb_w = init_tensorboard_writer(fw_info)
+        tb_w = init_tensorboard_writer()

         fw_impl = KerasImplementation()

@@ -153,7 +151,6 @@ if FOUND_TF:
         tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_model,
                 representative_data_gen=representative_data_gen,
                 core_config=core_config,
-                fw_info=fw_info,
                 fw_impl=fw_impl,
                 fqc=framework_platform_capabilities,
                 target_resource_utilization=target_resource_utilization,
@@ -169,7 +166,6 @@ if FOUND_TF:
         graph_with_stats_correction = ptq_runner(tg,
                 representative_data_gen,
                 core_config,
-                fw_info,
                 fw_impl,
                 tb_w)

@@ -179,8 +175,7 @@ if FOUND_TF:
                 tb_w,
                 similarity_baseline_graph,
                 quantized_graph,
-                fw_impl,
-                fw_info)
+                fw_impl)

         exportable_model, user_info = get_exportable_keras_model(graph_with_stats_correction)
         if framework_platform_capabilities.tpc.add_metadata:
@@ -34,7 +34,7 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
 from model_compression_toolkit.metadata import create_model_metadata

 if FOUND_TORCH:
-    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from torch.nn import Module
@@ -46,6 +46,7 @@ if FOUND_TORCH:

     DEFAULT_PYTORCH_TPC = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)

+    @set_pytorch_info
     def pytorch_post_training_quantization(in_module: Module,
                                            representative_data_gen: Callable,
                                            target_resource_utilization: ResourceUtilization = None,
@@ -102,8 +103,6 @@ if FOUND_TORCH:
         if core_config.debug_config.bypass:
             return in_module, None

-        fw_info = DEFAULT_PYTORCH_INFO
-
         if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config to mixed-precision facade is not of type "
@@ -111,7 +110,7 @@ if FOUND_TORCH:
                     "pytorch_post_training_quantization API, or pass a valid mixed precision "
                     "configuration.")  # pragma: no cover

-        tb_w = init_tensorboard_writer(fw_info)
+        tb_w = init_tensorboard_writer()

         fw_impl = PytorchImplementation()

@@ -125,7 +124,6 @@ if FOUND_TORCH:
         tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
                 representative_data_gen=representative_data_gen,
                 core_config=core_config,
-                fw_info=fw_info,
                 fw_impl=fw_impl,
                 fqc=framework_platform_capabilities,
                 target_resource_utilization=target_resource_utilization,
@@ -141,7 +139,6 @@ if FOUND_TORCH:
         graph_with_stats_correction = ptq_runner(tg,
                 representative_data_gen,
                 core_config,
-                fw_info,
                 fw_impl,
                 tb_w)

@@ -151,8 +148,7 @@ if FOUND_TORCH:
                 tb_w,
                 similarity_baseline_graph,
                 quantized_graph,
-                fw_impl,
-                fw_info)
+                fw_impl)

         exportable_model, user_info = get_exportable_pytorch_model(graph_with_stats_correction)
         if framework_platform_capabilities.tpc.add_metadata:
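None of the removed fw_info arguments appear in the public signatures, so user code that calls the PTQ facades is unaffected by this refactor. A minimal usage sketch for the PyTorch entry point; the toy model and representative dataset below are placeholders, not part of this diff:

import torch
import model_compression_toolkit as mct


def representative_data_gen():
    # Yield batches shaped like the model's input; ten random batches here.
    for _ in range(10):
        yield [torch.randn(1, 3, 32, 32)]


float_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(
    in_module=float_model,
    representative_data_gen=representative_data_gen)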
@@ -16,7 +16,6 @@

 from typing import Callable

-from model_compression_toolkit.core.common import FrameworkInfo
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
@@ -28,7 +27,6 @@ from model_compression_toolkit.core.common.visualization.tensorboard_writer impo
 def ptq_runner(tg: Graph,
                representative_data_gen: Callable,
                core_config: CoreConfig,
-               fw_info: FrameworkInfo,
                fw_impl: FrameworkImplementation,
                tb_w: TensorboardWriter) -> Graph:
     """
@@ -38,7 +36,6 @@ def ptq_runner(tg: Graph,
         tg: Graph to apply PTQ and to quantize.
         representative_data_gen (Callable): Dataset used for calibration.
         core_config: CoreConfig containing parameters of how the model should be quantized.
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
             groups of layers by how they should be quantized, etc.)
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
@@ -50,5 +47,5 @@ def ptq_runner(tg: Graph,
     #############################################
     # Statistics Correction
     #############################################
-    tg = apply_statistics_correction(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
+    tg = apply_statistics_correction(tg, representative_data_gen, core_config, fw_impl, tb_w)
     return tg
@@ -19,21 +19,17 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from model_compression_toolkit.trainable_infrastructure import TrainingMethod


-def is_qat_applicable(node: common.BaseNode,
-                      fw_info: FrameworkInfo) -> bool:
+def is_qat_applicable(node: common.BaseNode) -> bool:
     """
     A function for deciding if a layer should be fine-tuned during QAT

     Args:
         node (BaseNode): Node for quantization decision
-        fw_info (FrameworkInfo): Pytorch quantization information

     Returns:
         A boolean whether the layer is to be wrapped with a QuantizeWrapper
     """
-
-    kernel_attr = fw_info.get_kernel_op_attributes(node.type)[0]
-    return (kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)) \
+    return (node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)) \
        or node.is_activation_quantization_enabled()

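The same pattern appears here at the node level: the kernel attribute is now read from the node itself (node.kernel_attr) instead of being looked up through fw_info.get_kernel_op_attributes(node.type). Below is a hedged, self-contained illustration of the new call shape; NodeSketch is a simplified stand-in for common.BaseNode, not MCT's real class.

from dataclasses import dataclass, field
from typing import Optional, Set


@dataclass
class NodeSketch:
    """Simplified stand-in for common.BaseNode, for illustration only."""
    kernel_attr: Optional[str] = None              # e.g. "kernel" for a Conv2D node
    quantized_weight_attrs: Set[str] = field(default_factory=set)
    activation_quantization: bool = False

    def is_weights_quantization_enabled(self, attr: str) -> bool:
        return attr in self.quantized_weight_attrs

    def is_activation_quantization_enabled(self) -> bool:
        return self.activation_quantization


def is_qat_applicable_sketch(node: NodeSketch) -> bool:
    # Mirrors the new single-argument signature: no FrameworkInfo needed.
    return (node.kernel_attr is not None
            and node.is_weights_quantization_enabled(node.kernel_attr)) \
        or node.is_activation_quantization_enabled()


conv_node = NodeSketch(kernel_attr="kernel", quantized_weight_attrs={"kernel"})
relu_node = NodeSketch(kernel_attr=None, activation_quantization=True)
assert is_qat_applicable_sketch(conv_node) and is_qat_applicable_sketch(relu_node)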
 
@@ -37,10 +37,10 @@ if FOUND_TF:
     from tensorflow.keras.models import Model

     from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
-    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
+    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info

     from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder

@@ -52,7 +52,6 @@ if FOUND_TF:
     from model_compression_toolkit.constants import TENSORFLOW
     from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
-    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.qat.keras.quantizer.quantization_builder import quantization_builder, \
         get_activation_quantizer_holder
     from model_compression_toolkit.qat.common.qat_config import QATConfig
@@ -73,11 +72,11 @@ if FOUND_TF:
        Returns: Wrapped layer

        """
-       if is_qat_applicable(n, DEFAULT_KERAS_INFO):
+       if is_qat_applicable(n):
            # If we are here, then the node has a kernel attribute to quantize and training during QAT
            weights_quantizers, _ = quantization_builder(n,
                    qat_config,
-                   DEFAULT_KERAS_INFO.get_kernel_op_attributes(n.type)[0])
+                   n.kernel_attr)
            if len(weights_quantizers) > 0:
                layer.trainable = True
                return KerasTrainableQuantizationWrapper(layer, weights_quantizers)
@@ -87,6 +86,7 @@ if FOUND_TF:
            return layer


+    @set_keras_info
    def keras_quantization_aware_training_init_experimental(in_model: Model,
                                                            representative_data_gen: Callable,
                                                            target_resource_utilization: ResourceUtilization = None,
@@ -175,8 +175,7 @@ if FOUND_TF:
                f"If you encounter an issue, please open an issue in our GitHub "
                f"project https://github.com/sony/model_optimization")

-       KerasModelValidation(model=in_model,
-                            fw_info=DEFAULT_KERAS_INFO).validate()
+       KerasModelValidation(model=in_model).validate()

        if core_config.is_mixed_precision_enabled:
            if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
@@ -184,7 +183,7 @@ if FOUND_TF:
                    "MixedPrecisionQuantizationConfig. Please use keras_post_training_quantization API,"
                    "or pass a valid mixed precision configuration.")

-       tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
+       tb_w = init_tensorboard_writer()

        fw_impl = KerasImplementation()

@@ -198,17 +197,15 @@ if FOUND_TF:
        tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
                representative_data_gen=representative_data_gen,
                core_config=core_config,
-               fw_info=DEFAULT_KERAS_INFO,
                fw_impl=fw_impl,
                fqc=target_platform_capabilities,
                target_resource_utilization=target_resource_utilization,
                tb_w=tb_w)

-       tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_KERAS_INFO, fw_impl, tb_w)
+       tg = ptq_runner(tg, representative_data_gen, core_config, fw_impl, tb_w)

        _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)
        qat_model, user_info = KerasModelBuilder(graph=tg,
-               fw_info=DEFAULT_KERAS_INFO,
                wrapper=_qat_wrapper,
                get_activation_quantizer_holder_fn=partial(get_activation_quantizer_holder,
                        qat_config=qat_config)).build_model()
@@ -36,7 +36,7 @@ if FOUND_TORCH:
     import torch.nn as nn
     from torch.nn import Module
     from mct_quantizers import PytorchActivationQuantizationHolder
-    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
     from model_compression_toolkit.qat.common.qat_config import is_qat_applicable
@@ -62,10 +62,10 @@ if FOUND_TORCH:
        Returns: Wrapped layer

        """
-       if is_qat_applicable(n, DEFAULT_PYTORCH_INFO):
+       if is_qat_applicable(n):
            # If we are here, then the node has a kernel attribute to quantize and training during QAT
            weights_quantizers, _ = quantization_builder(n, qat_config,
-                   DEFAULT_PYTORCH_INFO.get_kernel_op_attributes(n.type)[0])
+                   n.kernel_attr)
            if len(weights_quantizers) > 0:
                return PytorchQuantizationWrapper(module, weights_quantizers)

@@ -74,6 +74,7 @@ if FOUND_TORCH:
            return module


+    @set_pytorch_info
    def pytorch_quantization_aware_training_init_experimental(in_model: Module,
                                                              representative_data_gen: Callable,
                                                              target_resource_utilization: ResourceUtilization = None,
@@ -149,7 +150,7 @@ if FOUND_TORCH:
                    "MixedPrecisionQuantizationConfig. Please use pytorch_post_training_quantization API,"
                    "or pass a valid mixed precision configuration.")

-       tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
+       tb_w = init_tensorboard_writer()
        fw_impl = PytorchImplementation()

        target_platform_capabilities = load_target_platform_capabilities(target_platform_capabilities)
@@ -162,18 +163,16 @@ if FOUND_TORCH:
        tg, bit_widths_config, _, _ = core_runner(in_model=in_model,
                representative_data_gen=representative_data_gen,
                core_config=core_config,
-               fw_info=DEFAULT_PYTORCH_INFO,
                fw_impl=fw_impl,
                fqc=framework_platform_capabilities,
                target_resource_utilization=target_resource_utilization,
                tb_w=tb_w)

-       tg = ptq_runner(tg, representative_data_gen, core_config, DEFAULT_PYTORCH_INFO, fw_impl, tb_w)
+       tg = ptq_runner(tg, representative_data_gen, core_config, fw_impl, tb_w)

        _qat_wrapper = partial(qat_wrapper, qat_config=qat_config)

        qat_model, user_info = PyTorchModelBuilder(graph=tg,
-               fw_info=DEFAULT_PYTORCH_INFO,
                wrapper=_qat_wrapper,
                get_activation_quantizer_holder_fn=partial(
                    get_activation_quantizer_holder,
@@ -181,9 +180,6 @@ if FOUND_TORCH:

        user_info.mixed_precision_cfg = bit_widths_config

-       # Remove fw_info from graph to enable saving the pytorch model (fw_info can not be pickled)
-       delattr(qat_model.graph, 'fw_info')
-
        return qat_model, user_info

@@ -50,7 +50,7 @@ def core_report_generator(float_model: Any,

     # Collect histograms on the float model.
     float_graph = fw_report_utils.model_folding_utils.create_float_folded_graph(float_model, repr_dataset)
-    mi = ModelCollector(float_graph, fw_report_utils.fw_impl, fw_report_utils.fw_info)
+    mi = ModelCollector(float_graph, fw_report_utils.fw_impl)
     for _data in tqdm(repr_dataset(), desc="Collecting Histograms"):
         mi.infer(_data)

@@ -34,7 +34,6 @@ class FrameworkReportUtils:
     """

     def __init__(self,
-                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation,
                  similarity_calculator: SimilarityCalculator,
                  dataset_utils: DatasetUtils,
@@ -45,7 +44,6 @@ class FrameworkReportUtils:
         Initializes the FrameworkReportUtils class with various utility components required for generating the report.

         Args:
-            fw_info (FrameworkInfo): Information about the framework being used.
             fw_impl (FrameworkImplementation): The implemented functions of the framework.
             similarity_calculator (SimilarityCalculator): A utility for calculating similarity metrics.
             dataset_utils (DatasetUtils): Utilities for handling datasets.
@@ -53,7 +51,6 @@ class FrameworkReportUtils:
             tb_utils (TensorboardUtils): Utilities for TensorBoard operations.
             get_metadata_fn (Callable): Function to retrieve the metadata from the quantized model.
         """
-        self.fw_info = fw_info
        self.fw_impl = fw_impl
        self.similarity_calculator = similarity_calculator
        self.dataset_utils = dataset_utils
@@ -34,7 +34,6 @@ class ModelFoldingUtils:
     """

     def __init__(self,
-                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation,
                  fw_default_fqc: FrameworkQuantizationCapabilities):
         """
@@ -42,11 +41,9 @@ class ModelFoldingUtils:
         and default FQC.

         Args:
-            fw_info: Framework-specific information.
             fw_impl: Implementation functions for the framework.
             fw_default_fqc: Default target platform capabilities for the handled framework.
         """
-        self.fw_info = fw_info
         self.fw_impl = fw_impl
         self.fw_default_fqc = fw_default_fqc

@@ -69,8 +66,7 @@ class ModelFoldingUtils:
         float_folded_model, _ = self.fw_impl.model_builder(
             float_graph,
             mode=ModelBuilderMode.FLOAT,
-            append2output=None,
-            fw_info=self.fw_info
+            append2output=None
         )
         return float_folded_model

@@ -100,7 +96,6 @@ class ModelFoldingUtils:
         graph = graph_preparation_runner(in_model=model,
                 representative_data_gen=repr_dataset,
                 fw_impl=self.fw_impl,
-                fw_info=self.fw_info,
                 quantization_config=DEFAULTCONFIG,
                 fqc=self.fw_default_fqc)
         return graph
@@ -36,19 +36,16 @@ class TensorboardUtils:

     def __init__(self,
                  report_dir: str,
-                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation):
         """
         Initialize the TensorboardUtils.

         Args:
             report_dir (str): Directory where Tensorboard logs will be stored.
-            fw_info (FrameworkInfo): Framework-specific information.
             fw_impl (FrameworkImplementation): Framework-specific implementation.
         """
         self.fw_impl = fw_impl
-        self.fw_info = fw_info
-        self.tb_writer = TensorboardWriter(report_dir, fw_info)
+        self.tb_writer = TensorboardWriter(report_dir)
         Logger.info(f"Please run: tensorboard --logdir {self.tb_writer.dir_path}")

     def get_graph_for_tensorboard_display(self,
@@ -15,7 +15,6 @@

 from model_compression_toolkit import get_target_platform_capabilities
 from model_compression_toolkit.constants import TENSORFLOW
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
 from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2keras import \
     AttachTpcToKeras
@@ -40,7 +39,6 @@ class KerasReportUtils(FrameworkReportUtils):
         Args:
             report_dir: Logging dir path.
         """
-        fw_info = DEFAULT_KERAS_INFO
         fw_impl = KerasImplementation()

         # Set the default Target Platform Capabilities (TPC) for Keras.
@@ -49,8 +47,7 @@ class KerasReportUtils(FrameworkReportUtils):
         framework_platform_capabilities = attach2pytorch.attach(default_tpc)

         dataset_utils = KerasDatasetUtils()
-        model_folding = ModelFoldingUtils(fw_info=fw_info,
-                                          fw_impl=fw_impl,
+        model_folding = ModelFoldingUtils(fw_impl=fw_impl,
                                           fw_default_fqc=framework_platform_capabilities)

         similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
@@ -59,10 +56,8 @@ class KerasReportUtils(FrameworkReportUtils):
                                                      model_analyzer_utils=KerasModelAnalyzer())

         tb_utils = KerasTensorboardUtils(report_dir=report_dir,
-                                         fw_impl=fw_impl,
-                                         fw_info=fw_info)
-        super().__init__(fw_info,
-                         fw_impl,
+                                         fw_impl=fw_impl)
+        super().__init__(fw_impl,
                          similarity_calculator,
                          dataset_utils,
                          model_folding,
@@ -40,18 +40,15 @@ class KerasTensorboardUtils(TensorboardUtils):
     """

     def __init__(self, report_dir: str,
-                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation):
         """
         Initialize the KerasTensorboardUtils class with the given parameters.

         Args:
             report_dir (str): Directory where the TensorBoard files will be stored.
-            fw_info (FrameworkInfo): Information about the framework being used.
             fw_impl (FrameworkImplementation): Implementation functions for the framework.
         """
         super().__init__(report_dir,
-                         fw_info,
                          fw_impl)

     def get_graph_for_tensorboard_display(self,
@@ -20,7 +20,7 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2frame
     AttachTpcToPytorch

 from model_compression_toolkit.xquant.common.framework_report_utils import FrameworkReportUtils
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+from model_compression_toolkit.core.pytorch.default_framework_info import PyTorchInfo
 from model_compression_toolkit.core.pytorch.pytorch_implementation import PytorchImplementation
 from model_compression_toolkit.xquant.common.model_folding_utils import ModelFoldingUtils
 from model_compression_toolkit.xquant.common.similarity_calculator import SimilarityCalculator
@@ -30,6 +30,7 @@ from model_compression_toolkit.xquant.pytorch.similarity_functions import Pytorc
 from model_compression_toolkit.xquant.pytorch.tensorboard_utils import PytorchTensorboardUtils
 from mct_quantizers.pytorch.metadata import get_metadata

+
 class PytorchReportUtils(FrameworkReportUtils):
     """
     Class with various utility components required for generating the report for a Pytorch model.
@@ -39,7 +40,6 @@ class PytorchReportUtils(FrameworkReportUtils):
         Args:
             report_dir: Logging dir path.
         """
-        fw_info = DEFAULT_PYTORCH_INFO
         fw_impl = PytorchImplementation()
         # Set the default Target Platform Capabilities (TPC) for PyTorch.
         default_tpc = get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL)
@@ -47,8 +47,7 @@ class PytorchReportUtils(FrameworkReportUtils):
         framework_quantization_capabilities = attach2pytorch.attach(default_tpc)

         dataset_utils = PytorchDatasetUtils()
-        model_folding = ModelFoldingUtils(fw_info=fw_info,
-                                          fw_impl=fw_impl,
+        model_folding = ModelFoldingUtils(fw_impl=fw_impl,
                                           fw_default_fqc=framework_quantization_capabilities)

         similarity_calculator = SimilarityCalculator(dataset_utils=dataset_utils,
@@ -58,11 +57,9 @@ class PytorchReportUtils(FrameworkReportUtils):
                                                      device=get_working_device())

         tb_utils = PytorchTensorboardUtils(report_dir=report_dir,
-                                           fw_impl=fw_impl,
-                                           fw_info=fw_info)
+                                           fw_impl=fw_impl)

-        super().__init__(fw_info=fw_info,
-                         fw_impl=fw_impl,
+        super().__init__(fw_impl=fw_impl,
                          tb_utils=tb_utils,
                          dataset_utils=dataset_utils,
                          similarity_calculator=similarity_calculator,
@@ -41,18 +41,15 @@ class PytorchTensorboardUtils(TensorboardUtils):

     def __init__(self,
                  report_dir: str,
-                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation):
         """
         Initialize the PytorchTensorboardUtils instance.

         Args:
             report_dir: Directory where the reports are stored.
-            fw_info: Information about the framework being used.
             fw_impl: Implementation methods for the framework.
         """
         super().__init__(report_dir,
-                         fw_info,
                          fw_impl)

     def get_graph_for_tensorboard_display(self,