mct-nightly 2.4.0.20250616.616__py3-none-any.whl → 2.4.0.20250618.606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/METADATA +1 -1
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/RECORD +120 -120
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/analyzer.py +2 -5
- model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
- model_compression_toolkit/core/common/framework_implementation.py +10 -22
- model_compression_toolkit/core/common/framework_info.py +105 -68
- model_compression_toolkit/core/common/graph/base_graph.py +15 -42
- model_compression_toolkit/core/common/graph/base_node.py +103 -42
- model_compression_toolkit/core/common/graph/functional_node.py +18 -1
- model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
- model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
- model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
- model_compression_toolkit/core/common/model_collector.py +10 -20
- model_compression_toolkit/core/common/model_validation.py +1 -4
- model_compression_toolkit/core/common/network_editors/actions.py +14 -38
- model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
- model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
- model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
- model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
- model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
- model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
- model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
- model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
- model_compression_toolkit/core/common/pruning/pruner.py +1 -6
- model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
- model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
- model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
- model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
- model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
- model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
- model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
- model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
- model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
- model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
- model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
- model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
- model_compression_toolkit/core/graph_prep_runner.py +2 -16
- model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
- model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/keras/default_framework_info.py +138 -87
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
- model_compression_toolkit/core/keras/keras_implementation.py +15 -35
- model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
- model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
- model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
- model_compression_toolkit/core/keras/resource_utilization_data_facade.py +0 -2
- model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
- model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
- model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
- model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
- model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
- model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
- model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
- model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
- model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
- model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
- model_compression_toolkit/core/quantization_prep_runner.py +4 -9
- model_compression_toolkit/core/runner.py +5 -15
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
- model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
- model_compression_toolkit/gptq/common/gptq_training.py +1 -8
- model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
- model_compression_toolkit/gptq/keras/graph_info.py +4 -6
- model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
- model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
- model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
- model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
- model_compression_toolkit/gptq/runner.py +1 -7
- model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
- model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
- model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
- model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
- model_compression_toolkit/ptq/runner.py +1 -4
- model_compression_toolkit/qat/common/qat_config.py +2 -6
- model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
- model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
- model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
- model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
- model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
- model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
- model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
- model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
- model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
- model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/WHEEL +0 -0
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.4.0.20250616.616.dist-info → mct_nightly-2.4.0.20250618.606.dist-info}/top_level.txt +0 -0
```diff
--- a/model_compression_toolkit/core/common/graph/base_node.py
+++ b/model_compression_toolkit/core/common/graph/base_node.py
@@ -14,10 +14,11 @@
 # ==============================================================================
 
 import copy
-from typing import Dict, Any, Tuple, List, Type, Union
+from typing import Dict, Any, Tuple, List, Type, Union, NamedTuple
 
 import numpy as np
 
+from model_compression_toolkit.core.common.framework_info import get_fw_info, ChannelAxisMapping
 from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
     ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
 from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig, \
@@ -34,11 +35,21 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework
 WeightAttrT = Union[str, int]
 
 
+class NodeFrameworkInfo(NamedTuple):
+    """
+    Node's specific framework information.
+    """
+    channel_axis: ChannelAxisMapping
+    out_channel_axis: int
+    minmax: Tuple[float, float]
+    kernel_attr: str
+    is_kernel_op: bool
+
+
 class BaseNode:
     """
     Class to represent a node in a graph that represents the model.
     """
-
     def __init__(self,
                  name: str,
                  framework_attr: Dict[str, Any],
@@ -88,6 +99,78 @@ class BaseNode:
         self.prior_info = None
         self.has_activation = has_activation
         self.is_custom = is_custom
+        self.node_fw_info = self._get_fw_node_attrs(layer_class, framework_attr)
+
+    def _get_fw_node_attrs(self, node_type, framework_attr):
+        fw_info = get_fw_info()
+        return None if fw_info is None else NodeFrameworkInfo(
+            fw_info.get_kernel_channels(node_type),
+            fw_info.get_out_channel_axis(node_type),
+            fw_info.get_layer_min_max(node_type, framework_attr),
+            fw_info.get_kernel_op_attribute(node_type),
+            fw_info.is_kernel_op(node_type)
+        )
+
+    def _assert_fw_info_exists(self):
+        """
+        Verify NodeFrameworkInfo was initialized.
+        """
+        assert self.node_fw_info is not None, f"NodeFrameworkInfo not initialized for node {self.name}"  # pragma: no cover
+
+    @property
+    def channel_axis(self) -> ChannelAxisMapping:
+        """
+        Extract channels axis from node's NodeFrameworkInfo.
+
+        Returns:
+            Channels axis named tuple.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.channel_axis
+
+    @property
+    def out_channel_axis(self) -> int:
+        """
+        Extract output channel axis from node's NodeFrameworkInfo.
+
+        Returns:
+            Output channel axis.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.out_channel_axis
+
+    @property
+    def minmax(self) -> Tuple[float, float]:
+        """
+        Extract expected min-max activation values from node's NodeFrameworkInfo.
+
+        Returns:
+            A tuple of min-max values.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.minmax
+
+    @property
+    def kernel_attr(self) -> str:
+        """
+        Extract kernel name from node's NodeFrameworkInfo.
+
+        Returns:
+            Kernel name.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.kernel_attr
+
+    @property
+    def is_kernel_op(self) -> bool:
+        """
+        Check if kernel exists for the node.
+
+        Returns:
+            Whether the node has a kernel or not.
+        """
+        self._assert_fw_info_exists()
+        return self.node_fw_info.is_kernel_op
 
     @property
     def type(self):
@@ -298,14 +381,11 @@ class BaseNode:
 
         return input_tensors
 
-    def get_num_parameters(self, fw_info) -> Tuple[int,int]:
+    def get_num_parameters(self) -> Tuple[int,int]:
         """
         Compute the number of parameters the node holds.
         It returns a tuple: Number of quantized parameters, number of float parameters.
 
-        Args:
-            fw_info: Framework info to decide which attributes should be quantized.
-
         Returns:
             A tuple of (Number of quantized parameters, number of float parameters).
 
@@ -314,11 +394,10 @@ class BaseNode:
 
         q_node_num_params = 0
 
-        for attr in fw_info.get_kernel_op_attributes(self.type):
-            if attr is not None:
-                w = self.get_weights_by_keys(attr)
-                if w is not None:
-                    q_node_num_params += w.flatten().shape[0]
+        if self.kernel_attr is not None:
+            w = self.get_weights_by_keys(self.kernel_attr)
+            if w is not None:
+                q_node_num_params += w.flatten().shape[0]
 
         f_node_num_params = total_node_params - q_node_num_params
 
@@ -326,22 +405,19 @@ class BaseNode:
         assert int(f_node_num_params) == f_node_num_params
         return int(q_node_num_params), int(f_node_num_params)
 
-    def get_memory_bytes(self, fw_info) -> float:
+    def get_memory_bytes(self) -> float:
         """
         Compute the number of bytes the node's memory requires.
 
-        Args:
-            fw_info: Framework info to decide which attributes should be quantized.
-
         Returns: Number of bytes the node's memory requires.
 
         """
         # TODO: this method is used for tensorboard only. If we want to enable logging of other attributes memory
         # then it needs to be modified. But, it might be better to remove this method from the BaseNode completely.
-        kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
+        kernel_attr = self.kernel_attr
         if kernel_attr is None:
             return 0
-        q_params, f_params = self.get_num_parameters(fw_info)
+        q_params, f_params = self.get_num_parameters()
         if self.final_weights_quantization_cfg is None:  # float coefficients
             memory = (f_params+q_params) * FP32_BYTES_PER_PARAMETER
         else:
@@ -351,15 +427,12 @@ class BaseNode:
 
         return memory
 
-    def get_unified_weights_candidates_dict(self, fw_info) -> Dict[str, Any]:
+    def get_unified_weights_candidates_dict(self) -> Dict[str, Any]:
         """
         In Mixed-Precision, a node's kernel can have multiple candidates for weights quantization configuration.
         In order to display a single view of a node (for example, for logging in TensorBoard) we need a way
         to create a single dictionary from all candidates.
-        This method is aimed to build such dictionary.
-
-        Args:
-            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
+        This method is aimed to build such a unified dictionary for a node.
 
         Returns: A dictionary containing information from node's weight quantization configuration candidates.
 
@@ -369,7 +442,7 @@ class BaseNode:
         # We assume that only the kernel attribute have more than one candidate, since we only allow to
         # quantize the kernel using mixed precision
         # TODO: need to modify if we want to present a unified config for other attributes
-        kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
+        kernel_attr = self.kernel_attr
         if kernel_attr is None:
             # This node doesn't have a kernel attribute
             return {}
@@ -437,20 +510,13 @@ class BaseNode:
         candidates = self.get_all_weights_attr_candidates(attr)
         return all(candidate == candidates[0] for candidate in candidates[1:])
 
-    def has_kernel_weight_to_quantize(self, fw_info):
+    def has_kernel_weight_to_quantize(self):
         """
-        Checks whether the node has kernel attribute that need to be quantized according to the framework info.
+        Checks whether the node has kernel attribute that need to be quantized according to the node's framework info.
 
-        Args:
-            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
-
-        Returns: Whether the node has weights that need to be quantized.
+        Returns: Whether the node's kernel need to be quantized.
         """
-        attrs = fw_info.get_kernel_op_attributes(self.type)
-        for attr in attrs:
-            if attr and self.get_weights_by_keys(attr) is not None:
-                return True
-        return False
+        return self.kernel_attr and self.get_weights_by_keys(self.kernel_attr) is not None
 
     def has_any_weight_attr_to_quantize(self) -> bool:
         """
@@ -724,7 +790,7 @@ class BaseNode:
             Logger.critical(f"SIMD is expected to be a non-positive integer but found: {_simd}")
         return _simd
 
-    def sort_node_candidates(self, fw_info):
+    def sort_node_candidates(self):
         """
         Sorts the node candidates.
         We assume that the candidates are ordered in the following way (for mixed precision purposes):
@@ -733,16 +799,11 @@ class BaseNode:
         - If the node doesn't have a kernel we only consider the candidate activation number of bits to sort
           the candidates in descending order.
         The operation is done inplace.
-
-        Args:
-            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
-
         """
         if self.candidates_quantization_cfg is not None:
-            kernel_attr = fw_info.get_kernel_op_attributes(self.type)[0]
-            if kernel_attr is not None:
+            if self.kernel_attr is not None:
                 self.candidates_quantization_cfg.sort(
-                    key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
+                    key=lambda c: (c.weights_quantization_cfg.get_attr_config(self.kernel_attr).weights_n_bits,
                                    c.activation_quantization_cfg.activation_n_bits), reverse=True)
             else:
                 self.candidates_quantization_cfg.sort(key=lambda c: c.activation_quantization_cfg.activation_n_bits,
```
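Taken together, the `base_node.py` hunks above replace the `fw_info` argument that used to be threaded through `BaseNode` methods with a `NodeFrameworkInfo` NamedTuple, resolved once per node at construction time from a process-wide registry (`get_fw_info()`). Below is a minimal, self-contained sketch of that caching pattern; `set_fw_info`, `DemoFwInfo`, and the lookup bodies are illustrative stand-ins, not MCT's actual API.

```python
from typing import NamedTuple, Optional, Tuple

_FW_INFO = None  # process-wide framework info, set once per framework (illustrative)


def set_fw_info(fw_info) -> None:
    """Register the framework-info singleton (stand-in for MCT's registration)."""
    global _FW_INFO
    _FW_INFO = fw_info


def get_fw_info():
    return _FW_INFO


class NodeFrameworkInfo(NamedTuple):
    """Per-node snapshot of framework-specific attributes, as in the diff."""
    kernel_attr: Optional[str]
    out_channel_axis: int
    minmax: Tuple[float, float]


class DemoFwInfo:
    """Illustrative framework-info object; MCT's real lookups are richer."""

    def get_kernel_op_attribute(self, node_type) -> Optional[str]:
        return 'kernel' if node_type == 'Conv2D' else None

    def get_out_channel_axis(self, node_type) -> int:
        return -1  # e.g. channels-last

    def get_layer_min_max(self, node_type, framework_attr) -> Tuple[float, float]:
        return (0.0, 6.0) if node_type == 'ReLU6' else (float('-inf'), float('inf'))


class DemoNode:
    def __init__(self, name: str, node_type: str):
        self.name = name
        # Resolve framework attributes once at construction instead of
        # threading fw_info through every method call.
        fw_info = get_fw_info()
        self.node_fw_info = None if fw_info is None else NodeFrameworkInfo(
            fw_info.get_kernel_op_attribute(node_type),
            fw_info.get_out_channel_axis(node_type),
            fw_info.get_layer_min_max(node_type, None))

    @property
    def kernel_attr(self) -> Optional[str]:
        assert self.node_fw_info is not None, f'not initialized for {self.name}'
        return self.node_fw_info.kernel_attr


set_fw_info(DemoFwInfo())
assert DemoNode('conv1', 'Conv2D').kernel_attr == 'kernel'
assert DemoNode('relu1', 'ReLU6').kernel_attr is None
```

Call sites shrink accordingly: `node.get_num_parameters(fw_info)` becomes `node.get_num_parameters()`, which is why most of the remaining hunks in this diff are one-line parameter removals.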
```diff
--- a/model_compression_toolkit/core/common/graph/functional_node.py
+++ b/model_compression_toolkit/core/common/graph/functional_node.py
@@ -1,6 +1,21 @@
+# Copyright 2021 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from typing import Dict, Any, Tuple, Type, List, Union
 
-from model_compression_toolkit.
+from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 import numpy as np
 
@@ -45,6 +60,7 @@ class FunctionalNode(BaseNode):
             inputs_as_list: Whether to pass the node its input tensors as a list or not when calling the layer.
             has_activation: Whether the node has activations that we might want to quantize.
             tensor_input_allocs: A list of indices and strings for allocations input tensors in the node's args and kwargs.
+
         """
 
         super().__init__(name,
@@ -63,6 +79,7 @@ class FunctionalNode(BaseNode):
         self.op_call_args = list(op_call_args)
         self.functional_op = functional_op
         self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs
+        self.node_fw_info = self._get_fw_node_attrs(functional_op, framework_attr)
 
     @property
     def type(self):
```
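`FunctionalNode` reuses the same mechanism but re-resolves the cached info against `functional_op` rather than the wrapping layer class, since for functional calls the op itself is what determines kernel and channel semantics. Continuing the sketch above, in the same session (with the hypothetical `DemoNode`, `NodeFrameworkInfo`, `DemoFwInfo`, and `get_fw_info` defined there):

```python
import operator


class DemoFunctionalNode(DemoNode):
    """Sketch of the FunctionalNode override: same lookup, keyed on the op."""

    def __init__(self, name: str, node_type: str, functional_op):
        super().__init__(name, node_type)
        self.functional_op = functional_op
        # Overwrite the inherited cache using the functional op as the lookup
        # key, mirroring `self._get_fw_node_attrs(functional_op, framework_attr)`.
        fw_info = get_fw_info()
        self.node_fw_info = None if fw_info is None else NodeFrameworkInfo(
            fw_info.get_kernel_op_attribute(functional_op),
            fw_info.get_out_channel_axis(functional_op),
            fw_info.get_layer_min_max(functional_op, None))


fn = DemoFunctionalNode('add1', 'Add', operator.add)
assert fn.kernel_attr is None  # a functional add has no kernel weight
```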
```diff
--- a/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py
+++ b/model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py
@@ -15,14 +15,13 @@
 import abc
 import uuid
 
-from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.constants import VIRTUAL_ACTIVATION_WEIGHTS_NODE_PREFIX, \
     VIRTUAL_WEIGHTS_SUFFIX, VIRTUAL_ACTIVATION_SUFFIX, FLOAT_BITWIDTH
-from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTES
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 from model_compression_toolkit.core.common.quantization.candidate_node_quantization_config import \
     CandidateNodeQuantizationConfig
 from model_compression_toolkit.core.common.quantization.node_quantization_config import ActivationQuantizationMode
+from model_compression_toolkit.core.common.framework_info import DEFAULT_KERNEL_ATTRIBUTE
 
 
 class VirtualNode(BaseNode, abc.ABC):
@@ -128,28 +127,23 @@ class VirtualActivationWeightsNode(VirtualNode):
 
     def __init__(self,
                  act_node: BaseNode,
-                 weights_node: BaseNode,
-                 fw_info: FrameworkInfo):
+                 weights_node: BaseNode):
         """
         Init a VirtualActivationWeightsNode object.
 
         Args:
             act_node: The original activation node.
             weights_node: The original weights node.
-            fw_info: A FrameworkInfo object with framework specific information.
         """
         # Validate weights node
-        kernel_attrs = fw_info.get_kernel_op_attributes(weights_node.type)
-        assert len(kernel_attrs) == 1 and kernel_attrs[0] is not None, f'Expected exactly one kernel attr, {kernel_attrs}'
-        kernel_attr = kernel_attrs[0]
         conf_weights = [attr for attr in weights_node.weights if weights_node.is_configurable_weight(attr)]
-        if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(kernel_attr):
+        if len(conf_weights) > 1 or len(conf_weights) == 1 and not weights_node.is_configurable_weight(weights_node.kernel_attr):
             raise NotImplementedError(f'Only kernel weight can be configurable. Got configurable {conf_weights}.')
 
         weights = weights_node.weights.copy()
         act_node_w_rename = {}
         if act_node.weights:
-            if fw_info.get_kernel_op_attributes(act_node.type) != DEFAULT_KERNEL_ATTRIBUTES:
+            if act_node.kernel_attr != DEFAULT_KERNEL_ATTRIBUTE:
                 raise NotImplementedError(f'Node {act_node} with kernel cannot be used as activation for '
                                           f'VirtualActivationWeightsNode.')
             if act_node.has_any_configurable_weight():
@@ -157,7 +151,7 @@ class VirtualActivationWeightsNode(VirtualNode):
                                           'VirtualActivationWeightsNode.')
             # combine weights from activation and weights
             for w_id, w in act_node.weights.items():
-                if w_id not in weights and not (isinstance(w_id, str) and kernel_attr in w_id):
+                if w_id not in weights and not (isinstance(w_id, str) and weights_node.kernel_attr in w_id):
                     weights[w_id] = w
                     continue
                 # if same identifier is used as in weight nodes (or contains the kernel substring), generate a new
@@ -185,7 +179,7 @@ class VirtualActivationWeightsNode(VirtualNode):
         self.original_weights_node = weights_node
 
         v_candidates = []
-        weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(kernel_attr)
+        weights_candidates_quantization_cfg = weights_node.get_unique_weights_candidates(weights_node.kernel_attr)
         for c_a in act_node.candidates_quantization_cfg:
             for c_w in weights_candidates_quantization_cfg:
                 composed_candidate = CandidateNodeQuantizationConfig(activation_quantization_cfg=c_a.activation_quantization_cfg,
@@ -203,7 +197,7 @@ class VirtualActivationWeightsNode(VirtualNode):
                 v_candidates.append(composed_candidate)
 
         # sorting the candidates by weights number of bits first and then by activation number of bits (reversed order)
-        v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits,
+        v_candidates.sort(key=lambda c: (c.weights_quantization_cfg.get_attr_config(weights_node.kernel_attr).weights_n_bits,
                                          c.activation_quantization_cfg.activation_n_bits), reverse=True)
 
         self.candidates_quantization_cfg = v_candidates
```
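Note the ordering contract at the end of this file: composed candidates are sorted by kernel weight bit-width first and activation bit-width second, both descending, matching `BaseNode.sort_node_candidates`. A toy illustration, with plain tuples standing in for candidate configs:

```python
# (weights_n_bits, activation_n_bits) tuples stand in for candidate configs.
candidates = [(4, 8), (8, 8), (8, 4), (4, 4), (2, 8)]
candidates.sort(key=lambda c: (c[0], c[1]), reverse=True)
assert candidates == [(8, 8), (8, 4), (4, 8), (4, 4), (2, 8)]
```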
```diff
--- a/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
+++ b/model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py
@@ -37,20 +37,18 @@ def set_bit_widths(mixed_precision_enable: bool,
     """
     if mixed_precision_enable:
         assert all([len(n.candidates_quantization_cfg) > 0
-                    for n in graph.get_configurable_sorted_nodes(graph.fw_info)]), \
+                    for n in graph.get_configurable_sorted_nodes()]), \
             "All configurable nodes in graph should have at least one candidate configuration in mixed precision mode"
 
         # Get a list of nodes' names we need to finalize (that they have at least one weight qc candidate).
-        sorted_nodes_names = graph.get_configurable_sorted_nodes_names(graph.fw_info)
+        sorted_nodes_names = graph.get_configurable_sorted_nodes_names()
 
         for node in graph.nodes:  # set a specific node qc for each node final qc
             # If it's reused, take the configuration that the base node has
             node_name = node.name if not node.reuse else '_'.join(node.name.split('_')[:-2])
             if node_name in sorted_nodes_names:  # only configurable nodes are in this list
                 node_index_in_graph = sorted_nodes_names.index(node_name)
-                _set_node_final_qc(bit_widths_config[node_index_in_graph],
-                                   node,
-                                   graph.fw_info)
+                _set_node_final_qc(bit_widths_config[node_index_in_graph], node)
             else:
                 if node.is_activation_quantization_enabled():
                     # If we are here, this means that we are in weights-only mixed-precision
@@ -83,8 +81,7 @@ def set_bit_widths(mixed_precision_enable: bool,
 
 
 def _get_node_qc_by_bit_widths(node: BaseNode,
-                               node_bit_width_cfg: int,
-                               fw_info) -> Any:
+                               node_bit_width_cfg: int) -> Any:
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg.
@@ -93,21 +90,18 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
     Args:
         node: Node to get its quantization configuration candidate.
         node_bit_width_cfg: Configuration which determines the node's desired bit width.
-        fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     Returns:
         Node quantization configuration if it was found, or None otherwise.
     """
     # only the weights kernel attribute is quantized in weights mixed precision at the moment
-    kernel_attr = fw_info.get_kernel_op_attributes(node.type)
-
     if node.is_activation_quantization_enabled():
         qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
         return qc
 
-    elif kernel_attr is not None:
-        if node.is_weights_quantization_enabled(kernel_attr):
+    elif node.kernel_attr is not None:
+        if node.is_weights_quantization_enabled(node.kernel_attr):
             qc = node.candidates_quantization_cfg[node_bit_width_cfg]
 
             return qc
@@ -116,8 +110,7 @@ def _get_node_qc_by_bit_widths(node: BaseNode,
 
 
 def _set_node_final_qc(node_bit_width_cfg: int,
-                       node: BaseNode,
-                       fw_info):
+                       node: BaseNode):
     """
     Get the node's quantization configuration that
     matches to the bit width index as in the MP configuration bit_width_cfg, and use it to finalize the node's
@@ -127,12 +120,9 @@ def _set_node_final_qc(node_bit_width_cfg: int,
     Args:
         node_bit_width_cfg: Configuration which determines the node's desired bit width.
         node: Node to set its node quantization configuration.
-        fw_info: Information relevant to a specific framework about how layers should be quantized.
 
     """
-    node_qc = _get_node_qc_by_bit_widths(node,
-                                         node_bit_width_cfg,
-                                         fw_info)
+    node_qc = _get_node_qc_by_bit_widths(node, node_bit_width_cfg)
 
     if node_qc is None:
         Logger.critical(f'Node {node.name} quantization configuration from configuration file'  # pragma: no cover
```
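The mechanism `set_bit_widths` relies on is unchanged by this refactor: `bit_widths_config` is a list of candidate indices aligned with the topologically sorted configurable node names, and a reused node inherits the index of its base node. A toy version of the index lookup (all names and values here are illustrative):

```python
sorted_nodes_names = ['conv1', 'conv2']           # configurable nodes, topological order
bit_widths_config = [0, 2]                        # chosen candidate index per node
candidates = {'conv1': [(8, 8), (4, 8)],          # (weights_n_bits, activation_n_bits)
              'conv2': [(8, 8), (8, 4), (4, 4)]}

final_cfg = {name: candidates[name][bit_widths_config[sorted_nodes_names.index(name)]]
             for name in sorted_nodes_names}
assert final_cfg == {'conv1': (8, 8), 'conv2': (4, 4)}
```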
```diff
--- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
+++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py
@@ -22,7 +22,6 @@ from model_compression_toolkit.target_platform_capabilities.targetplatform2framework
 
 def filter_candidates_for_mixed_precision(graph: Graph,
                                           target_resource_utilization: ResourceUtilization,
-                                          fw_info: FrameworkInfo,
                                           fqc: FrameworkQuantizationCapabilities):
     """
     Filters out candidates in case of mixed precision search for only weights or activation compression.
@@ -35,7 +34,6 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     Args:
         graph: A graph representation of the model to be quantized.
         target_resource_utilization: The resource utilization of the target device.
-        fw_info: fw_info: Information needed for quantization about the specific framework.
         fqc: FrameworkQuantizationCapabilities object that describes the desired inference target platform.
 
     """
@@ -59,11 +57,10 @@ def filter_candidates_for_mixed_precision(graph: Graph,
     elif tru.activation_restricted() and not tru.weight_restricted():
         # Running mixed precision for activation compression only -
         # filter out candidates weights only configurable node
-        weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes(fw_info)]
+        weight_configurable_nodes = [n for n in graph.get_weights_configurable_nodes()]
         for n in weight_configurable_nodes:
-            kernel_attr = fw_info.get_kernel_op_attributes(n.type)[0]
-            base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[kernel_attr].weights_n_bits
+            base_cfg_nbits = n.get_qco(fqc).base_config.attr_weights_configs_mapping[n.kernel_attr].weights_n_bits
             filtered_conf = [c for c in n.candidates_quantization_cfg if
-                             c.weights_quantization_cfg.get_attr_config(kernel_attr).enable_weights_quantization and
-                             c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == base_cfg_nbits]
+                             c.weights_quantization_cfg.get_attr_config(n.kernel_attr).enable_weights_quantization and
+                             c.weights_quantization_cfg.get_attr_config(n.kernel_attr).weights_n_bits == base_cfg_nbits]
             n.candidates_quantization_cfg = filtered_conf
```
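For activation-only compression, the filter keeps just the candidates whose kernel weights bit-width equals the base configuration's, so the search varies activation bits alone. In toy form (tuples of (weights_n_bits, activation_n_bits) stand in for candidate configs):

```python
base_cfg_nbits = 8
candidates = [(8, 8), (8, 4), (4, 8), (4, 4)]
filtered = [c for c in candidates if c[0] == base_cfg_nbits]
assert filtered == [(8, 8), (8, 4)]  # only activation bits still vary
```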
```diff
--- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py
+++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py
@@ -30,11 +30,10 @@ from model_compression_toolkit.core.common.quantization.node_quantization_config
 class MixedPrecisionRUHelper:
     """ Helper class for resource utilization computations for mixed precision optimization. """
 
-    def __init__(self, graph: Graph, fw_info: FrameworkInfo, fw_impl: FrameworkImplementation):
+    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation):
        self.graph = graph
-        self.fw_info = fw_info
         self.fw_impl = fw_impl
-        self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl, fw_info)
+        self.ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
 
     def compute_utilization(self, ru_targets: Set[RUTarget], mp_cfg: Dict[BaseNode, int]) -> Dict[RUTarget, np.ndarray]:
         """
```
```diff
--- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py
+++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py
@@ -35,7 +35,6 @@ class BitWidthSearchMethod(Enum):
 
 
 def search_bit_width(graph: Graph,
-                     fw_info: FrameworkInfo,
                      fw_impl: FrameworkImplementation,
                      target_resource_utilization: ResourceUtilization,
                      mp_config: MixedPrecisionQuantizationConfig,
@@ -52,7 +51,6 @@ def search_bit_width(graph: Graph,
 
     Args:
         graph: Graph to search a MP configuration for.
-        fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
         fw_impl: FrameworkImplementation object with specific framework methods implementation.
         target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
         mp_config: Mixed-precision quantization configuration.
@@ -79,7 +77,7 @@ def search_bit_width(graph: Graph,
 
     # Set Sensitivity Evaluator for MP search. It should always work with the original MP graph,
     # even if a virtual graph was created (and is used only for BOPS utilization computation purposes)
-    se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen,
+    se = SensitivityEvaluation(graph, mp_config, representative_data_gen=representative_data_gen,
                                fw_impl=fw_impl, disable_activation_for_metric=disable_activation_for_metric,
                                hessian_info_service=hessian_info_service)
 
@@ -93,7 +91,6 @@ def search_bit_width(graph: Graph,
 
     # Search manager and LP are highly coupled, so LP search method was moved inside search manager.
     search_manager = MixedPrecisionSearchManager(graph,
-                                                 fw_info=fw_info,
                                                  fw_impl=fw_impl,
                                                  sensitivity_evaluator=se,
                                                  target_resource_utilization=target_resource_utilization,
@@ -105,6 +102,6 @@ def search_bit_width(graph: Graph,
     if mp_config.refine_mp_solution:
         nodes_bit_cfg = greedy_solution_refinement_procedure(nodes_bit_cfg, search_manager, target_resource_utilization)
 
-    topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes(fw_info)]
+    topo_bit_cfg = [nodes_bit_cfg[n] for n in graph.get_configurable_sorted_nodes()]
     assert len(topo_bit_cfg) == len(nodes_bit_cfg)
     return topo_bit_cfg
```
```diff
--- a/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
+++ b/model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py
@@ -53,7 +53,6 @@ class MixedPrecisionSearchManager:
 
     def __init__(self,
                  graph: Graph,
-                 fw_info: FrameworkInfo,
                  fw_impl: FrameworkImplementation,
                  sensitivity_evaluator: SensitivityEvaluation,
                  target_resource_utilization: ResourceUtilization,
@@ -62,14 +61,12 @@ class MixedPrecisionSearchManager:
 
         Args:
             graph: Graph to search for its MP configuration.
-            fw_info: FrameworkInfo object about the specific framework (e.g., attributes of different layers' weights to quantize).
             fw_impl: FrameworkImplementation object with specific framework methods implementation.
             sensitivity_evaluator: A SensitivityEvaluation which provides a function that evaluates the sensitivity of
                 a bit-width configuration for the MP model.
             target_resource_utilization: Target Resource Utilization to bound our feasible solution space s.t the configuration does not violate it.
         """
 
-        self.fw_info = fw_info
         self.fw_impl = fw_impl
 
         self.original_graph = graph
@@ -81,12 +78,12 @@ class MixedPrecisionSearchManager:
         self.target_resource_utilization = target_resource_utilization
         self.mp_config = mp_config
 
-        self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes(fw_info)
+        self.mp_topo_configurable_nodes = self.mp_graph.get_configurable_sorted_nodes()
 
         self.ru_targets = target_resource_utilization.get_restricted_targets()
-        self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_info, fw_impl)
+        self.orig_graph_ru_helper = MixedPrecisionRUHelper(self.original_graph, fw_impl)
 
-        self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config(fw_info)
+        self.min_ru_config: Dict[BaseNode, int] = self.mp_graph.get_min_candidates_config()
 
         self.config_reconstructor = None
         orig_min_config = self.min_ru_config
```
```diff
--- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
+++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py
@@ -124,10 +124,9 @@ class ResourceUtilizationCalculator:
     unexpected_qc_error = 'Custom quantization configuration is not expected for non-custom bit mode.'
     unexpected_qc_nodes_error = 'Custom quantization configuration contains unexpected node names.'
 
-    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation, fw_info: FrameworkInfo):
+    def __init__(self, graph: Graph, fw_impl: FrameworkImplementation):
         self.graph = graph
         self.fw_impl = fw_impl
-        self.fw_info = fw_info
 
         # Currently we go over the full graph even if utilization won't be requested for all nodes.
         # We could fill the cache on the fly only for requested nodes, but it's probably negligible.
@@ -544,14 +543,10 @@ class ResourceUtilizationCalculator:
             self._validate_custom_qcs(w_qc, bitwidth_mode)
 
         # check if the node has kernel
-        kernel_attrs = self.fw_info.get_kernel_op_attributes(n.type)
-        if len(kernel_attrs) > 1:  # pragma: no cover
-            raise NotImplementedError('Multiple kernel attributes are not supported for BOPS computation.')
-        if not kernel_attrs or not kernel_attrs[0]:
+        if not n.kernel_attr:
             return 0
 
-        kernel_attr = kernel_attrs[0]
-        node_mac = self.fw_impl.get_node_mac_operations(n, self.fw_info)
+        node_mac = self.fw_impl.get_node_mac_operations(n)
         if node_mac == 0:
             return node_mac
 
@@ -559,12 +554,12 @@ class ResourceUtilizationCalculator:
         assert len(prev_nodes) == 1, f'Weights node is expected to have exactly one input, {n} has {len(prev_nodes)}'
         a_node = prev_nodes[0]
         if (target_criterion == TargetInclusionCriterion.AnyQuantized and
-                not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(kernel_attr))):
+                not (a_node.is_activation_quantization_enabled() or n.is_weights_quantization_enabled(n.kernel_attr))):
             return 0
 
         act_qc = self._extract_qc(a_node, act_qcs)
         a_nbits = self._get_activation_nbits(a_node, bitwidth_mode, act_qc)
-        w_nbits = self._get_weight_nbits(n, kernel_attr, bitwidth_mode, w_qc)
+        w_nbits = self._get_weight_nbits(n, n.kernel_attr, bitwidth_mode, w_qc)
         node_bops = a_nbits * w_nbits * node_mac
         return node_bops
 
```
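The BOPS accounting itself is untouched: `node_bops = a_nbits * w_nbits * node_mac`, with the kernel attribute now read off the node instead of `fw_info`. A worked example with assumed numbers:

```python
# Hypothetical 3x3 conv, 64 -> 128 channels, producing a 14x14 output map.
node_mac = 3 * 3 * 64 * 128 * 14 * 14   # multiply-accumulate operations
a_nbits, w_nbits = 8, 4                 # 8-bit activations, 4-bit weights
node_bops = a_nbits * w_nbits * node_mac
assert node_bops == 32 * node_mac       # 32 bit-operations per MAC
```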
```diff
--- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
+++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py
@@ -15,7 +15,7 @@
 import copy
 from typing import Callable, Any
 
-from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig, QuantizationErrorMethod
+from model_compression_toolkit.core import ResourceUtilization, CoreConfig, QuantizationErrorMethod
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
     ResourceUtilizationCalculator, BitwidthMode, TargetInclusionCriterion
@@ -27,7 +27,6 @@ def compute_resource_utilization_data(in_model: Any,
                                       representative_data_gen: Callable,
                                       core_config: CoreConfig,
                                       fqc: FrameworkQuantizationCapabilities,
-                                      fw_info: FrameworkInfo,
                                       fw_impl: FrameworkImplementation) -> ResourceUtilization:
     """
     Compute Resource Utilization of a model with the default single precision quantization.
@@ -39,7 +38,6 @@ def compute_resource_utilization_data(in_model: Any,
         core_config: CoreConfig containing parameters of how the model should be quantized.
         fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
             the attached framework operator's information.
-        fw_info: Information needed for quantization about the specific framework.
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
 
     Returns:
@@ -55,12 +53,11 @@ def compute_resource_utilization_data(in_model: Any,
     transformed_graph = graph_preparation_runner(in_model,
                                                  representative_data_gen=representative_data_gen,
                                                  quantization_config=core_config.quantization_config,
-                                                 fw_info=fw_info,
                                                  fw_impl=fw_impl,
                                                  fqc=fqc,
                                                  bit_width_config=core_config.bit_width_config,
                                                  mixed_precision_enable=False,
                                                  running_gptq=False)
 
-    ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl, fw_info)
+    ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl)
     return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused, BitwidthMode.QDefaultSP)
```