mct-nightly 2.1.0.20240811.503__py3-none-any.whl → 2.1.0.20240813.442__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240813.442.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240813.442.dist-info}/RECORD +64 -62
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/constants.py +0 -7
  5. model_compression_toolkit/core/__init__.py +1 -0
  6. model_compression_toolkit/core/common/graph/functional_node.py +1 -1
  7. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -0
  8. model_compression_toolkit/core/common/quantization/bit_width_config.py +91 -0
  9. model_compression_toolkit/core/common/quantization/core_config.py +8 -4
  10. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -1
  11. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +88 -22
  12. model_compression_toolkit/core/graph_prep_runner.py +16 -9
  13. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +4 -3
  14. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +1 -1
  15. model_compression_toolkit/core/runner.py +1 -0
  16. model_compression_toolkit/data_generation/__init__.py +1 -1
  17. model_compression_toolkit/data_generation/keras/keras_data_generation.py +7 -3
  18. model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +1 -1
  19. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +4 -3
  20. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -1
  21. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +1 -1
  22. model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +4 -3
  23. model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py +4 -3
  24. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +1 -1
  25. model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py +1 -2
  26. model_compression_toolkit/gptq/keras/quantization_facade.py +8 -5
  27. model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py +4 -3
  28. model_compression_toolkit/gptq/pytorch/quantization_facade.py +2 -1
  29. model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py +1 -1
  30. model_compression_toolkit/pruning/keras/pruning_facade.py +6 -3
  31. model_compression_toolkit/pruning/pytorch/pruning_facade.py +3 -1
  32. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -3
  33. model_compression_toolkit/ptq/pytorch/quantization_facade.py +2 -1
  34. model_compression_toolkit/qat/keras/quantization_facade.py +7 -5
  35. model_compression_toolkit/qat/keras/quantizer/base_keras_qat_quantizer.py +4 -3
  36. model_compression_toolkit/qat/pytorch/quantization_facade.py +2 -1
  37. model_compression_toolkit/qat/pytorch/quantizer/base_pytorch_qat_quantizer.py +1 -1
  38. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/latest/__init__.py +1 -1
  39. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/target_platform_capabilities.py +2 -1
  40. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +1 -1
  41. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +1 -1
  42. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +1 -1
  43. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +1 -1
  44. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +1 -1
  45. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +1 -1
  46. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +1 -1
  47. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +1 -0
  48. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +5 -3
  49. model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +2 -0
  50. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/latest/__init__.py +1 -1
  51. model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/target_platform_capabilities.py +2 -1
  52. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/latest/__init__.py +1 -1
  53. model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/target_platform_capabilities.py +2 -1
  54. model_compression_toolkit/trainable_infrastructure/keras/base_keras_quantizer.py +4 -3
  55. model_compression_toolkit/trainable_infrastructure/keras/load_model.py +4 -3
  56. model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py +4 -3
  57. model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py +1 -1
  58. model_compression_toolkit/verify_packages.py +33 -0
  59. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -0
  60. model_compression_toolkit/xquant/keras/facade_xquant_report.py +4 -3
  61. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +1 -1
  62. {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240813.442.dist-info}/LICENSE.md +0 -0
  63. {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240813.442.dist-info}/WHEEL +0 -0
  64. {mct_nightly-2.1.0.20240811.503.dist-info → mct_nightly-2.1.0.20240813.442.dist-info}/top_level.txt +0 -0
@@ -15,9 +15,11 @@
15
15
 
16
16
 
17
17
  import copy
18
- from typing import List, Tuple
18
+ from typing import List, Tuple, Optional
19
19
 
20
+ from mct_quantizers.common.constants import ACTIVATION_N_BITS
20
21
  from model_compression_toolkit.core.common import BaseNode
22
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
21
23
  from model_compression_toolkit.logger import Logger
22
24
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
25
  from model_compression_toolkit.core.common.graph.base_graph import Graph
@@ -37,19 +39,21 @@ from model_compression_toolkit.target_platform_capabilities.target_platform.op_q
37
39
 
38
40
  def set_quantization_configuration_to_graph(graph: Graph,
39
41
  quant_config: QuantizationConfig,
42
+ bit_width_config: BitWidthConfig = None,
40
43
  mixed_precision_enable: bool = False,
41
44
  running_gptq: bool = False) -> Graph:
42
45
  """
43
46
  Add quantization configuration for each graph node.
44
47
 
45
48
  Args:
46
- graph: Graph for which to add quantization info to each node.
47
- quant_config: Quantization configuration containing parameters for how the graph should be quantized.
48
- mixed_precision_enable: is mixed precision enabled.
49
- running_gptq: Whether or not a GPTQ optimization is planned to run after the PTQ process.
49
+ graph (Graph): Graph for which to add quantization info to each node.
50
+ quant_config (QuantizationConfig): Quantization configuration containing parameters for how the graph should be quantized.
51
+ bit_width_config (BitWidthConfig): Configuration for manual bit width selection. Defaults to None.
52
+ mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
53
+ running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process. Defaults to False.
50
54
 
51
55
  Returns:
52
- The graph with quantization configurations attached to each node in it.
56
+ Graph: The graph with quantization configurations attached to each node in it.
53
57
  """
54
58
 
55
59
  if quant_config.weights_error_method == QuantizationErrorMethod.HMSE:
@@ -62,13 +66,16 @@ def set_quantization_configuration_to_graph(graph: Graph,
62
66
  Logger.warning("Using the HMSE error method for weights quantization parameters search. "
63
67
  "Note: This method may significantly increase runtime during the parameter search process.")
64
68
 
69
+ nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph)
70
+
65
71
  for n in graph.nodes:
66
72
  set_quantization_configs_to_node(node=n,
67
73
  graph=graph,
68
74
  quant_config=quant_config,
69
75
  fw_info=graph.fw_info,
70
76
  tpc=graph.tpc,
71
- mixed_precision_enable=mixed_precision_enable)
77
+ mixed_precision_enable=mixed_precision_enable,
78
+ manual_bit_width_override=nodes_to_manipulate_bit_widths.get(n))
72
79
  return graph
73
80
 
74
81
 
@@ -77,21 +84,32 @@ def set_quantization_configs_to_node(node: BaseNode,
77
84
  quant_config: QuantizationConfig,
78
85
  fw_info: FrameworkInfo,
79
86
  tpc: TargetPlatformCapabilities,
80
- mixed_precision_enable: bool = False):
87
+ mixed_precision_enable: bool = False,
88
+ manual_bit_width_override: Optional[int] = None):
81
89
  """
82
90
  Create and set quantization configurations to a node (for both weights and activation).
83
91
 
84
92
  Args:
85
- node: Node to set its quantization configurations.
86
- graph: Model's internal representation graph.
87
- quant_config: Quantization configuration to generate the node's configurations from.
88
- fw_info: Information needed for quantization about the specific framework.
89
- tpc: TargetPlatformCapabilities to get default OpQuantizationConfig.
90
- mixed_precision_enable: is mixed precision enabled.
93
+ node (BaseNode): Node to set its quantization configurations.
94
+ graph (Graph): Model's internal representation graph.
95
+ quant_config (QuantizationConfig): Quantization configuration to generate the node's configurations from.
96
+ fw_info (FrameworkInfo): Information needed for quantization about the specific framework.
97
+ tpc (TargetPlatformCapabilities): TargetPlatformCapabilities to get default OpQuantizationConfig.
98
+ mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
99
+ manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width. Defaults to None.
91
100
  """
92
101
  node_qc_options = node.get_qco(tpc)
93
102
  base_config, node_qc_options_list = node.filter_node_qco_by_graph(tpc, graph.get_next_nodes(node), node_qc_options)
94
103
 
104
+ # If a manual_bit_width_override is given, filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override,
105
+ # and update base_config accordingly.
106
+ base_config, node_qc_options_list = filter_qc_options_with_manual_bit_width(
107
+ node=node,
108
+ node_qc_options_list=node_qc_options_list,
109
+ base_config=base_config,
110
+ manual_bit_width_override=manual_bit_width_override,
111
+ mixed_precision_enable=mixed_precision_enable)
112
+
95
113
  # Create QC candidates for weights and activation combined
96
114
  weight_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
97
115
  node.candidates_quantization_cfg = _create_node_candidates_qc(quant_config,
@@ -199,16 +217,16 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
199
217
  Create a list of candidates of weights and activation quantization configurations for a node.
200
218
 
201
219
  Args:
202
- qc: Quantization configuration the quantization process should follow.
203
- fw_info: Framework information (e.g., which layers should have their kernels' quantized).
204
- weight_channel_axis: (Output, Input) channel index of the node's kernel.
205
- node_qc_options_list: List of quantization configs of node.
206
- base_config: Base quantization config for node.
207
- node: A node to set quantization configuration candidates to.
208
- mixed_precision_enable: is mixed precision enabled
220
+ qc (QuantizationConfig): Quantization configuration the quantization process should follow.
221
+ fw_info (FrameworkInfo): Framework information (e.g., which layers should have their kernels quantized).
222
+ weight_channel_axis (Tuple[int, int]): (Output, Input) channel index of the node's kernel.
223
+ node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs of node.
224
+ base_config (OpQuantizationConfig): Base quantization config for node.
225
+ node (BaseNode): A node to set quantization configuration candidates to.
226
+ mixed_precision_enable (bool): Whether mixed precision is enabled. Defaults to False.
209
227
 
210
228
  Returns:
211
- List of candidates of weights quantization configurations to set for a node.
229
+ List[CandidateNodeQuantizationConfig]: List of candidates of weights quantization configurations to set for a node.
212
230
  """
213
231
 
214
232
  candidates = []
@@ -231,3 +249,51 @@ def _create_node_candidates_qc(qc: QuantizationConfig,
231
249
  node_attrs_list))
232
250
 
233
251
  return candidates
252
+
253
+
254
+ def filter_qc_options_with_manual_bit_width(
255
+ node: BaseNode,
256
+ node_qc_options_list: List[OpQuantizationConfig],
257
+ base_config: OpQuantizationConfig,
258
+ manual_bit_width_override: Optional[int],
259
+ mixed_precision_enable: bool) -> Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]:
260
+ """
261
+ Update the quantization configurations for a node, allowing manual bit-width overrides if specified.
262
+
263
+ Args:
264
+ node (BaseNode): A node to set quantization configuration candidates to.
265
+ node_qc_options_list (List[OpQuantizationConfig]): List of quantization configs for the node.
266
+ base_config (OpQuantizationConfig): Base quantization config for the node.
267
+ manual_bit_width_override (Optional[int]): Specifies a custom bit-width to override the node's activation bit-width.
268
+ mixed_precision_enable (bool): Whether mixed precision is enabled.
269
+
270
+ Returns:
271
+ Tuple[OpQuantizationConfig, List[OpQuantizationConfig]]: The updated base configuration and the filtered list of quantization configs.
272
+ """
273
+ if manual_bit_width_override is None:
274
+ return base_config, node_qc_options_list
275
+
276
+ # Filter node_qc_options_list to retain only the options with activation bits equal to manual_bit_width_override.
277
+ node_qc_options_list = [op_cfg for op_cfg in node_qc_options_list if
278
+ manual_bit_width_override == op_cfg.activation_n_bits]
279
+
280
+ if len(node_qc_options_list) == 0:
281
+ Logger.critical(f"Manually selected activation bit-width {manual_bit_width_override} is invalid for node {node}.")
282
+ else:
283
+ # Update the base_config to one of the values from the filtered node_qc_options_list.
284
+ # First, check if a configuration similar to the original base_config but with activation bits equal to manual_bit_width_override exists.
285
+ # If it does, use it as the base_config. If not, choose a different configuration from node_qc_options_list.
286
+ Logger.info(f"Setting node {node} bit-width to manually selected bit-width: {manual_bit_width_override} bits.")
287
+ updated_base_config = base_config.clone_and_edit(**{ACTIVATION_N_BITS: manual_bit_width_override})
288
+ if updated_base_config in node_qc_options_list:
289
+ # If a base_config with the specified manual_bit_width_override exists in the node_qc_options_list,
290
+ # point the base_config to this option.
291
+ base_config = node_qc_options_list[node_qc_options_list.index(updated_base_config)]
292
+ else:
293
+ # Choose a different configuration from node_qc_options_list. If multiple options exist, log an informational message.
294
+ base_config = node_qc_options_list[0]
295
+ if not mixed_precision_enable:
296
+ Logger.info(
297
+ f"Request received to select {manual_bit_width_override} activation bits. However, the base configuration for layer type {node.type} is missing in the node_qc_options_list."
298
+ f" Overriding base_config with an option that uses {manual_bit_width_override} bit activations.") # pragma: no cover
299
+ return base_config, node_qc_options_list
@@ -20,6 +20,7 @@ from model_compression_toolkit.core.common import FrameworkInfo
20
20
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
21
21
  from model_compression_toolkit.core.common.fusion.layer_fusing import fusion
22
22
  from model_compression_toolkit.core.common.graph.base_graph import Graph
23
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
23
24
  from model_compression_toolkit.core.common.quantization.filter_nodes_candidates import filter_nodes_candidates
24
25
  from model_compression_toolkit.core.common.quantization.quantization_config import DEFAULTCONFIG
25
26
  from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
@@ -38,6 +39,7 @@ def graph_preparation_runner(in_model: Any,
38
39
  fw_info: FrameworkInfo,
39
40
  fw_impl: FrameworkImplementation,
40
41
  tpc: TargetPlatformCapabilities,
42
+ bit_width_config: BitWidthConfig = None,
41
43
  tb_w: TensorboardWriter = None,
42
44
  mixed_precision_enable: bool = False,
43
45
  running_gptq: bool = False) -> Graph:
@@ -50,17 +52,18 @@ def graph_preparation_runner(in_model: Any,
50
52
  - Apply all necessary substitutions to finalize the graph for quantization.
51
53
 
52
54
  Args:
53
- in_model: Model to quantize.
54
- representative_data_gen: Dataset used for calibration.
55
- quantization_config: QuantizationConfig containing parameters of how the model should be quantized.
56
- fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices,
55
+ in_model (Any): Model to quantize.
56
+ representative_data_gen (Callable): Dataset used for calibration.
57
+ quantization_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
58
+ fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g., kernel channels indices,
57
59
  groups of layers by how they should be quantized, etc.).
58
- fw_impl: FrameworkImplementation object with a specific framework methods implementation.
59
- tpc: TargetPlatformCapabilities object that models the inference target platform and
60
+ fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
61
+ tpc (TargetPlatformCapabilities): TargetPlatformCapabilities object that models the inference target platform and
60
62
  the attached framework operator's information.
61
- tb_w: TensorboardWriter object for logging.
62
- mixed_precision_enable: is mixed precision enabled.
63
- running_gptq: Whether or not a GPTQ optimization is planned to run after the PTQ process.
63
+ bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
64
+ tb_w (TensorboardWriter): TensorboardWriter object for logging.
65
+ mixed_precision_enable (bool): is mixed precision enabled.
66
+ running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process.
64
67
 
65
68
  Returns:
66
69
  An internal graph representation of the input model.
@@ -78,6 +81,7 @@ def graph_preparation_runner(in_model: Any,
78
81
  transformed_graph = get_finalized_graph(graph,
79
82
  tpc,
80
83
  quantization_config,
84
+ bit_width_config,
81
85
  fw_info,
82
86
  tb_w,
83
87
  fw_impl,
@@ -90,6 +94,7 @@ def graph_preparation_runner(in_model: Any,
90
94
  def get_finalized_graph(initial_graph: Graph,
91
95
  tpc: TargetPlatformCapabilities,
92
96
  quant_config: QuantizationConfig = DEFAULTCONFIG,
97
+ bit_width_config: BitWidthConfig = None,
93
98
  fw_info: FrameworkInfo = None,
94
99
  tb_w: TensorboardWriter = None,
95
100
  fw_impl: FrameworkImplementation = None,
@@ -104,6 +109,7 @@ def get_finalized_graph(initial_graph: Graph,
104
109
  tpc (TargetPlatformCapabilities): TargetPlatformCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
105
110
  quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
106
111
  quantized.
112
+ bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
107
113
  fw_info (FrameworkInfo): Information needed for quantization about the specific framework (e.g.,
108
114
  kernel channels indices, groups of layers by how they should be quantized, etc.)
109
115
  tb_w (TensorboardWriter): TensorboardWriter object to use for logging events such as graphs, histograms, etc.
@@ -147,6 +153,7 @@ def get_finalized_graph(initial_graph: Graph,
147
153
  ######################################
148
154
  transformed_graph = set_quantization_configuration_to_graph(graph=transformed_graph,
149
155
  quant_config=quant_config,
156
+ bit_width_config=bit_width_config,
150
157
  mixed_precision_enable=mixed_precision_enable,
151
158
  running_gptq=running_gptq)
152
159
 
@@ -20,7 +20,7 @@ from model_compression_toolkit.logger import Logger
20
20
  from model_compression_toolkit.constants import TENSORFLOW
21
21
  from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
22
22
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
23
- from model_compression_toolkit.constants import FOUND_TF
23
+ from model_compression_toolkit.verify_packages import FOUND_TF
24
24
 
25
25
  if FOUND_TF:
26
26
  from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
@@ -89,5 +89,6 @@ else:
89
89
  # If tensorflow is not installed,
90
90
  # we raise an exception when trying to use this function.
91
91
  def keras_resource_utilization_data(*args, **kwargs):
92
- Logger.critical("Tensorflow must be installed to use keras_resource_utilization_data. "
93
- "The 'tensorflow' package is missing.") # pragma: no cover
92
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
93
+ "keras_resource_utilization_data. The 'tensorflow' package is either not installed or is "
94
+ "installed with a version higher than 2.15.") # pragma: no cover
@@ -23,7 +23,7 @@ from model_compression_toolkit.core.common.framework_info import FrameworkInfo
23
23
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_data import compute_resource_utilization_data
24
24
  from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
25
25
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
26
- from model_compression_toolkit.constants import FOUND_TORCH
26
+ from model_compression_toolkit.verify_packages import FOUND_TORCH
27
27
 
28
28
  if FOUND_TORCH:
29
29
  from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
@@ -115,6 +115,7 @@ def core_runner(in_model: Any,
115
115
  fw_info,
116
116
  fw_impl,
117
117
  tpc,
118
+ core_config.bit_width_config,
118
119
  tb_w,
119
120
  mixed_precision_enable=core_config.mixed_precision_enable,
120
121
  running_gptq=running_gptq)
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF, FOUND_TORCHVISION
16
+ from model_compression_toolkit.verify_packages import FOUND_TORCHVISION, FOUND_TORCH, FOUND_TF
17
17
  from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
18
18
  from model_compression_toolkit.data_generation.common.enums import ImageGranularity, DataInitType, SchedulerType, BNLayerWeightingType, OutputLossType, BatchNormAlignemntLossType, ImagePipelineType, ImageNormalizationType
19
19
 
@@ -16,7 +16,7 @@ import time
16
16
  from typing import Callable, Tuple, List, Dict, Union
17
17
  from tqdm import tqdm
18
18
 
19
- from model_compression_toolkit.constants import FOUND_TF
19
+ from model_compression_toolkit.verify_packages import FOUND_TF
20
20
  from model_compression_toolkit.data_generation.common.constants import DEFAULT_N_ITER, DEFAULT_DATA_GEN_BS
21
21
  from model_compression_toolkit.data_generation.common.data_generation import get_data_generation_classes
22
22
  from model_compression_toolkit.data_generation.common.image_pipeline import image_normalization_dict
@@ -349,8 +349,12 @@ if FOUND_TF:
349
349
  else:
350
350
  def get_keras_data_generation_config(*args, **kwargs):
351
351
  Logger.critical(
352
- "Tensorflow must be installed to use get_tensorflow_data_generation_config. The 'tensorflow' package is missing.") # pragma: no cover
352
+ "Tensorflow must be installed with a version of 2.15 or lower to use "
353
+ "get_tensorflow_data_generation_config. The 'tensorflow' package is missing or is installed with a "
354
+ "version higher than 2.15.") # pragma: no cover
353
355
 
354
356
 
355
357
  def keras_data_generation_experimental(*args, **kwargs):
356
- Logger.critical("Tensorflow must be installed to use tensorflow_data_generation_experimental. The 'tensorflow' package is missing.") # pragma: no cover
358
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
359
+ "tensorflow_data_generation_experimental. The 'tensorflow' package is missing or is installed "
360
+ "with a version higher than 2.15.") # pragma: no cover
@@ -17,7 +17,7 @@ from typing import Callable, Any, Tuple, List, Union
17
17
 
18
18
  from tqdm import tqdm
19
19
 
20
- from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TORCHVISION
20
+ from model_compression_toolkit.verify_packages import FOUND_TORCHVISION, FOUND_TORCH
21
21
  from model_compression_toolkit.core.pytorch.utils import set_model
22
22
  from model_compression_toolkit.data_generation.common.constants import DEFAULT_N_ITER, DEFAULT_DATA_GEN_BS
23
23
  from model_compression_toolkit.data_generation.common.data_generation import get_data_generation_classes
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import Callable, Dict
16
16
 
17
- from model_compression_toolkit.constants import FOUND_TF
17
+ from model_compression_toolkit.verify_packages import FOUND_TF
18
18
  from model_compression_toolkit.exporter.model_exporter.fw_agonstic.quantization_format import QuantizationFormat
19
19
  from model_compression_toolkit.logger import Logger
20
20
 
@@ -101,5 +101,6 @@ if FOUND_TF:
101
101
  return exporter.get_custom_objects()
102
102
  else:
103
103
  def keras_export_model(*args, **kwargs):
104
- Logger.critical("Tensorflow must be installed to use keras_export_model. "
105
- "The 'tensorflow' package is missing.") # pragma: no cover
104
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use keras_export_model. "
105
+ "The 'tensorflow' package is missing or is installed "
106
+ "with a version higher than 2.15.") # pragma: no cover
@@ -18,7 +18,7 @@ from io import BytesIO
18
18
  import torch.nn
19
19
 
20
20
  from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
21
- from model_compression_toolkit.constants import FOUND_ONNX
21
+ from model_compression_toolkit.verify_packages import FOUND_ONNX
22
22
  from model_compression_toolkit.logger import Logger
23
23
  from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
24
24
  from model_compression_toolkit.exporter.model_exporter.pytorch.base_pytorch_exporter import BasePyTorchExporter
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  from typing import Callable
16
16
 
17
- from model_compression_toolkit.constants import FOUND_TORCH
17
+ from model_compression_toolkit.verify_packages import FOUND_TORCH
18
18
  from model_compression_toolkit.exporter.model_exporter.fw_agonstic.quantization_format import QuantizationFormat
19
19
  from model_compression_toolkit.exporter.model_exporter.pytorch.export_serialization_format import \
20
20
  PytorchExportSerializationFormat
@@ -16,7 +16,7 @@
16
16
  from typing import Tuple, Callable
17
17
  from model_compression_toolkit.core import common
18
18
  from model_compression_toolkit.core.common import Graph
19
- from model_compression_toolkit.constants import FOUND_TF
19
+ from model_compression_toolkit.verify_packages import FOUND_TF
20
20
  from model_compression_toolkit.core.common.user_info import UserInformation
21
21
  from model_compression_toolkit.logger import Logger
22
22
  import model_compression_toolkit.core as C
@@ -101,5 +101,6 @@ if FOUND_TF:
101
101
  return exportable_model, user_info
102
102
  else:
103
103
  def get_exportable_keras_model(*args, **kwargs): # pragma: no cover
104
- Logger.critical("Tensorflow must be installed to use get_exportable_keras_model. "
105
- "The 'tensorflow' package is missing.") # pragma: no cover
104
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
105
+ "get_exportable_keras_model. The 'tensorflow' package is missing or is installed with a "
106
+ "version higher than 2.15.") # pragma: no cover
@@ -15,7 +15,7 @@
15
15
  from typing import Any
16
16
 
17
17
  from mct_quantizers import BaseInferableQuantizer, KerasActivationQuantizationHolder
18
- from model_compression_toolkit.constants import FOUND_TF
18
+ from model_compression_toolkit.verify_packages import FOUND_TF
19
19
  from model_compression_toolkit.logger import Logger
20
20
 
21
21
  if FOUND_TF:
@@ -76,5 +76,6 @@ if FOUND_TF:
76
76
  return True
77
77
  else:
78
78
  def is_keras_layer_exportable(*args, **kwargs): # pragma: no cover
79
- Logger.critical("Tensorflow must be installed to use is_keras_layer_exportable. "
80
- "The 'tensorflow' package is missing.") # pragma: no cover
79
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
80
+ "is_keras_layer_exportable. The 'tensorflow' package is missing or is installed with a "
81
+ "version higher than 2.15.") # pragma: no cover
@@ -16,7 +16,7 @@
16
16
  from typing import Union, Callable
17
17
  from model_compression_toolkit.core import common
18
18
  from model_compression_toolkit.core.common import Graph
19
- from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.verify_packages import FOUND_TORCH
20
20
  from model_compression_toolkit.logger import Logger
21
21
  from model_compression_toolkit.core.common import BaseNode
22
22
  import model_compression_toolkit.core as C
@@ -15,8 +15,7 @@
15
15
  from typing import Any
16
16
 
17
17
  from model_compression_toolkit.logger import Logger
18
- from model_compression_toolkit.constants import FOUND_TORCH
19
-
18
+ from model_compression_toolkit.verify_packages import FOUND_TORCH
20
19
 
21
20
  if FOUND_TORCH:
22
21
  import torch.nn as nn
@@ -21,7 +21,8 @@ from model_compression_toolkit.core.common.quantization.quantize_graph_weights i
21
21
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
22
22
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
23
23
  from model_compression_toolkit.logger import Logger
24
- from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF, ACT_HESSIAN_DEFAULT_BATCH_SIZE
24
+ from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE
25
+ from model_compression_toolkit.verify_packages import FOUND_TF
25
26
  from model_compression_toolkit.core.common.user_info import UserInformation
26
27
  from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, GPTQHessianScoresConfig
27
28
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
@@ -251,10 +252,12 @@ else:
251
252
  # If tensorflow is not installed,
252
253
  # we raise an exception when trying to use these functions.
253
254
  def get_keras_gptq_config(*args, **kwargs):
254
- Logger.critical("Tensorflow must be installed to use get_keras_gptq_config. "
255
- "The 'tensorflow' package is missing.") # pragma: no cover
255
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
256
+ "get_keras_gptq_config. The 'tensorflow' package is missing or is "
257
+ "installed with a version higher than 2.15.") # pragma: no cover
256
258
 
257
259
 
258
260
  def keras_gradient_post_training_quantization(*args, **kwargs):
259
- Logger.critical("Tensorflow must be installed to use keras_gradient_post_training_quantization. "
260
- "The 'tensorflow' package is missing.") # pragma: no cover
261
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
262
+ "keras_gradient_post_training_quantization. The 'tensorflow' package is missing or is "
263
+ "installed with a version higher than 2.15.") # pragma: no cover
@@ -16,7 +16,7 @@ from abc import abstractmethod
16
16
  from typing import Union, Dict, List
17
17
 
18
18
  from model_compression_toolkit.logger import Logger
19
- from model_compression_toolkit.constants import FOUND_TF
19
+ from model_compression_toolkit.verify_packages import FOUND_TF
20
20
  from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
21
21
 
22
22
  from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig
@@ -105,5 +105,6 @@ if FOUND_TF:
105
105
  else:
106
106
  class BaseKerasGPTQTrainableQuantizer: # pragma: no cover
107
107
  def __init__(self, *args, **kwargs):
108
- Logger.critical("Tensorflow must be installed to use BaseKerasGPTQTrainableQuantizer. "
109
- "The 'tensorflow' package is missing.") # pragma: no cover
108
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
109
+ "BaseKerasGPTQTrainableQuantizer. The 'tensorflow' package is missing or is "
110
+ "installed with a version higher than 2.15.") # pragma: no cover
@@ -16,7 +16,8 @@ import copy
16
16
 
17
17
  from typing import Callable
18
18
  from model_compression_toolkit.core import common
19
- from model_compression_toolkit.constants import FOUND_TORCH, ACT_HESSIAN_DEFAULT_BATCH_SIZE
19
+ from model_compression_toolkit.constants import ACT_HESSIAN_DEFAULT_BATCH_SIZE
20
+ from model_compression_toolkit.verify_packages import FOUND_TORCH
20
21
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
21
22
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT
22
23
  from model_compression_toolkit.logger import Logger
@@ -16,7 +16,7 @@ from abc import abstractmethod
16
16
  from typing import Union, Dict
17
17
 
18
18
  from model_compression_toolkit.logger import Logger
19
- from model_compression_toolkit.constants import FOUND_TORCH
19
+ from model_compression_toolkit.verify_packages import FOUND_TORCH
20
20
  from model_compression_toolkit.gptq.common.gptq_constants import WEIGHTS_QUANTIZATION_PARAMS
21
21
 
22
22
  from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
@@ -16,11 +16,13 @@
16
16
  from typing import Callable, Tuple
17
17
 
18
18
  from model_compression_toolkit import get_target_platform_capabilities
19
- from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
19
+ from model_compression_toolkit.constants import TENSORFLOW
20
+ from model_compression_toolkit.verify_packages import FOUND_TF
20
21
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
21
22
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
22
23
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
23
24
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
25
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
24
26
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
25
27
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
26
28
  from model_compression_toolkit.logger import Logger
@@ -148,5 +150,6 @@ else:
148
150
  # If tensorflow is not installed,
149
151
  # we raise an exception when trying to use these functions.
150
152
  def keras_pruning_experimental(*args, **kwargs):
151
- Logger.critical("Tensorflow must be installed to use keras_pruning_experimental. "
152
- "The 'tensorflow' package is missing.") # pragma: no cover
153
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
154
+ "keras_pruning_experimental. The 'tensorflow' package is missing or is "
155
+ "installed with a version higher than 2.15.") # pragma: no cover
@@ -15,11 +15,13 @@
15
15
 
16
16
  from typing import Callable, Tuple
17
17
  from model_compression_toolkit import get_target_platform_capabilities
18
- from model_compression_toolkit.constants import FOUND_TORCH, PYTORCH
18
+ from model_compression_toolkit.constants import PYTORCH
19
+ from model_compression_toolkit.verify_packages import FOUND_TORCH
19
20
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
20
21
  from model_compression_toolkit.core.common.pruning.pruner import Pruner
21
22
  from model_compression_toolkit.core.common.pruning.pruning_config import PruningConfig
22
23
  from model_compression_toolkit.core.common.pruning.pruning_info import PruningInfo
24
+ from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
23
25
  from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_quantization_configuration_to_graph
24
26
  from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
25
27
  from model_compression_toolkit.logger import Logger
@@ -21,7 +21,8 @@ from model_compression_toolkit.core.analyzer import analyzer_model_quantization
21
21
  from model_compression_toolkit.core.common.quantization.quantize_graph_weights import quantize_graph_weights
22
22
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
23
23
  from model_compression_toolkit.logger import Logger
24
- from model_compression_toolkit.constants import TENSORFLOW, FOUND_TF
24
+ from model_compression_toolkit.constants import TENSORFLOW
25
+ from model_compression_toolkit.verify_packages import FOUND_TF
25
26
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
26
27
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
27
28
  MixedPrecisionQuantizationConfig
@@ -178,5 +179,6 @@ else:
178
179
  # If tensorflow is not installed,
179
180
  # we raise an exception when trying to use these functions.
180
181
  def keras_post_training_quantization(*args, **kwargs):
181
- Logger.critical("Tensorflow must be installed to use keras_post_training_quantization. "
182
- "The 'tensorflow' package is missing.") # pragma: no cover
182
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
183
+ "keras_post_training_quantization. The 'tensorflow' package is missing or is "
184
+ "installed with a version higher than 2.15.") # pragma: no cover
@@ -18,7 +18,8 @@ from typing import Callable
18
18
 
19
19
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
20
20
  from model_compression_toolkit.logger import Logger
21
- from model_compression_toolkit.constants import PYTORCH, FOUND_TORCH
21
+ from model_compression_toolkit.constants import PYTORCH
22
+ from model_compression_toolkit.verify_packages import FOUND_TORCH
22
23
  from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
23
24
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
24
25
  from model_compression_toolkit.core import CoreConfig
@@ -19,7 +19,7 @@ from functools import partial
19
19
  from model_compression_toolkit.core import CoreConfig
20
20
  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
21
21
  from model_compression_toolkit.logger import Logger
22
- from model_compression_toolkit.constants import FOUND_TF
22
+ from model_compression_toolkit.verify_packages import FOUND_TF
23
23
  from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import ResourceUtilization
24
24
  from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import \
25
25
  MixedPrecisionQuantizationConfig
@@ -291,10 +291,12 @@ else:
291
291
  # If tensorflow is not installed,
292
292
  # we raise an exception when trying to use these functions.
293
293
  def keras_quantization_aware_training_init_experimental(*args, **kwargs):
294
- Logger.critical("Tensorflow must be installed to use keras_quantization_aware_training_init_experimental. "
295
- "The 'tensorflow' package is missing.") # pragma: no cover
294
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
295
+ "keras_quantization_aware_training_init_experimental. The 'tensorflow' package is missing "
296
+ "or is installed with a version higher than 2.15.") # pragma: no cover
296
297
 
297
298
 
298
299
  def keras_quantization_aware_training_finalize_experimental(*args, **kwargs):
299
- Logger.critical("Tensorflow must be installed to use keras_quantization_aware_training_finalize_experimental. "
300
- "The 'tensorflow' package is missing.") # pragma: no cover
300
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
301
+ "keras_quantization_aware_training_finalize_experimental. The 'tensorflow' package is missing "
302
+ "or is installed with a version higher than 2.15.") # pragma: no cover
@@ -15,7 +15,7 @@
15
15
  from typing import Union
16
16
 
17
17
  from model_compression_toolkit.logger import Logger
18
- from model_compression_toolkit.constants import FOUND_TF
18
+ from model_compression_toolkit.verify_packages import FOUND_TF
19
19
 
20
20
  from model_compression_toolkit.trainable_infrastructure import TrainableQuantizerWeightsConfig, \
21
21
  TrainableQuantizerActivationConfig, BaseKerasTrainableQuantizer
@@ -44,5 +44,6 @@ else: # pragma: no cover
44
44
  quantization_config: Union[TrainableQuantizerWeightsConfig, TrainableQuantizerActivationConfig]):
45
45
 
46
46
  super().__init__(quantization_config)
47
- Logger.critical("Tensorflow must be installed to use BaseKerasQATTrainableQuantizer. "
48
- "The 'tensorflow' package is missing.") # pragma: no cover
47
+ Logger.critical("Tensorflow must be installed with a version of 2.15 or lower to use "
48
+ "BaseKerasQATTrainableQuantizer. The 'tensorflow' package is missing "
49
+ "or is installed with a version higher than 2.15.") # pragma: no cover