mct-nightly 2.2.0.20241230.534__py3-none-any.whl → 2.2.0.20250102.111338__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/METADATA +8 -11
- {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/RECORD +19 -19
- {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/WHEEL +1 -1
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +308 -173
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +22 -22
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +22 -22
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +22 -22
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +21 -21
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +22 -22
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +25 -25
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +23 -23
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +55 -40
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +4 -6
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +2 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +10 -10
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +49 -46
- {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241230.534.dist-info → mct_nightly-2.2.0.20250102.111338.dist-info}/top_level.txt +0 -0
@@ -14,13 +14,13 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
import pprint
|
16
16
|
|
17
|
-
from dataclasses import replace, dataclass, asdict, field
|
18
17
|
from enum import Enum
|
19
|
-
from typing import Dict, Any, Union, Tuple, List, Optional
|
18
|
+
from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated
|
20
19
|
from mct_quantizers import QuantizationMethod
|
21
20
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
22
21
|
from model_compression_toolkit.logger import Logger
|
23
|
-
from
|
22
|
+
from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, PrivateAttr
|
23
|
+
|
24
24
|
|
25
25
|
class OperatorSetNames(Enum):
|
26
26
|
OPSET_CONV = "Conv"
|
@@ -92,8 +92,7 @@ class Signedness(Enum):
|
|
92
92
|
UNSIGNED = 2
|
93
93
|
|
94
94
|
|
95
|
-
|
96
|
-
class AttributeQuantizationConfig:
|
95
|
+
class AttributeQuantizationConfig(BaseModel):
|
97
96
|
"""
|
98
97
|
Holds the quantization configuration of a weight attribute of a layer.
|
99
98
|
|
@@ -103,27 +102,22 @@ class AttributeQuantizationConfig:
|
|
103
102
|
weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor.
|
104
103
|
enable_weights_quantization (bool): Indicates whether to quantize the model weights or not.
|
105
104
|
lut_values_bitwidth (Optional[int]): Number of bits to use when quantizing in a look-up table.
|
106
|
-
|
105
|
+
If None, defaults to 8 in hptq; otherwise, it uses the provided value.
|
107
106
|
"""
|
108
107
|
weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO
|
109
|
-
weights_n_bits:
|
108
|
+
weights_n_bits: PositiveInt = FLOAT_BITWIDTH
|
110
109
|
weights_per_channel_threshold: bool = False
|
111
110
|
enable_weights_quantization: bool = False
|
112
111
|
lut_values_bitwidth: Optional[int] = None
|
113
112
|
|
114
|
-
|
115
|
-
|
116
|
-
|
113
|
+
class Config:
|
114
|
+
# Makes the model immutable (frozen)
|
115
|
+
frozen = True
|
117
116
|
|
118
|
-
|
119
|
-
|
120
|
-
"""
|
121
|
-
|
122
|
-
Logger.critical("weights_n_bits must be a positive integer.") # pragma: no cover
|
123
|
-
if not isinstance(self.enable_weights_quantization, bool):
|
124
|
-
Logger.critical("enable_weights_quantization must be a boolean.") # pragma: no cover
|
125
|
-
if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int):
|
126
|
-
Logger.critical("lut_values_bitwidth must be an integer or None.") # pragma: no cover
|
117
|
+
@property
|
118
|
+
def field_names(self) -> list:
|
119
|
+
"""Return a list of field names for the model."""
|
120
|
+
return list(self.__fields__.keys())
|
127
121
|
|
128
122
|
def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig':
|
129
123
|
"""
|
@@ -135,11 +129,10 @@ class AttributeQuantizationConfig:
|
|
135
129
|
Returns:
|
136
130
|
AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes.
|
137
131
|
"""
|
138
|
-
return
|
132
|
+
return self.copy(update=kwargs)
|
139
133
|
|
140
134
|
|
141
|
-
|
142
|
-
class OpQuantizationConfig:
|
135
|
+
class OpQuantizationConfig(BaseModel):
|
143
136
|
"""
|
144
137
|
OpQuantizationConfig is a class to configure the quantization parameters of an operator.
|
145
138
|
|
@@ -148,39 +141,45 @@ class OpQuantizationConfig:
|
|
148
141
|
attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
|
149
142
|
activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
|
150
143
|
activation_n_bits (int): Number of bits to quantize the activations.
|
151
|
-
supported_input_activation_n_bits (int
|
144
|
+
supported_input_activation_n_bits (Union[int, Tuple[int, ...]]): Number of bits that operator accepts as input.
|
152
145
|
enable_activation_quantization (bool): Whether to quantize the model activations or not.
|
153
146
|
quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output.
|
154
|
-
fixed_scale (float): Scale to use for an operator quantization parameters.
|
155
|
-
fixed_zero_point (int): Zero-point to use for an operator quantization parameters.
|
156
|
-
simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
|
157
|
-
signedness (
|
158
|
-
|
147
|
+
fixed_scale (Optional[float]): Scale to use for an operator quantization parameters.
|
148
|
+
fixed_zero_point (Optional[int]): Zero-point to use for an operator quantization parameters.
|
149
|
+
simd_size (Optional[int]): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
|
150
|
+
signedness (Signedness): Set activation quantization signedness.
|
159
151
|
"""
|
160
152
|
default_weight_attr_config: AttributeQuantizationConfig
|
161
153
|
attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig]
|
162
154
|
activation_quantization_method: QuantizationMethod
|
163
155
|
activation_n_bits: int
|
164
|
-
supported_input_activation_n_bits: Union[int, Tuple[int]]
|
156
|
+
supported_input_activation_n_bits: Union[int, Tuple[int, ...]]
|
165
157
|
enable_activation_quantization: bool
|
166
158
|
quantization_preserving: bool
|
167
|
-
fixed_scale: float
|
168
|
-
fixed_zero_point: int
|
169
|
-
simd_size: int
|
159
|
+
fixed_scale: Optional[float]
|
160
|
+
fixed_zero_point: Optional[int]
|
161
|
+
simd_size: Optional[int]
|
170
162
|
signedness: Signedness
|
171
163
|
|
172
|
-
|
173
|
-
|
174
|
-
Post-initialization processing for input validation.
|
164
|
+
class Config:
|
165
|
+
frozen = True
|
175
166
|
|
176
|
-
|
177
|
-
|
167
|
+
@validator('supported_input_activation_n_bits', pre=True, allow_reuse=True)
|
168
|
+
def validate_supported_input_activation_n_bits(cls, v):
|
178
169
|
"""
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
170
|
+
Validate and process the supported_input_activation_n_bits field.
|
171
|
+
Converts an int to a tuple containing that int.
|
172
|
+
Ensures that if a tuple is provided, all elements are ints.
|
173
|
+
"""
|
174
|
+
|
175
|
+
if isinstance(v, int):
|
176
|
+
v = (v,)
|
177
|
+
|
178
|
+
# When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
|
179
|
+
if isinstance(v, list):
|
180
|
+
v = tuple(v)
|
181
|
+
|
182
|
+
return v
|
184
183
|
|
185
184
|
def get_info(self) -> Dict[str, Any]:
|
186
185
|
"""
|
@@ -189,9 +188,13 @@ class OpQuantizationConfig:
|
|
189
188
|
Returns:
|
190
189
|
dict: Information about the quantization configuration as a dictionary.
|
191
190
|
"""
|
192
|
-
return
|
191
|
+
return self.dict() # pragma: no cover
|
193
192
|
|
194
|
-
def clone_and_edit(
|
193
|
+
def clone_and_edit(
|
194
|
+
self,
|
195
|
+
attr_to_edit: Dict[str, Dict[str, Any]] = {},
|
196
|
+
**kwargs: Any
|
197
|
+
) -> 'OpQuantizationConfig':
|
195
198
|
"""
|
196
199
|
Clone the quantization config and edit some of its attributes.
|
197
200
|
|
@@ -203,64 +206,87 @@ class OpQuantizationConfig:
|
|
203
206
|
Returns:
|
204
207
|
OpQuantizationConfig: Edited quantization configuration.
|
205
208
|
"""
|
206
|
-
|
207
209
|
# Clone and update top-level attributes
|
208
|
-
updated_config =
|
210
|
+
updated_config = self.copy(update=kwargs)
|
209
211
|
|
210
212
|
# Clone and update nested immutable dataclasses in `attr_weights_configs_mapping`
|
211
213
|
updated_attr_mapping = {
|
212
214
|
attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name])
|
213
|
-
|
215
|
+
if attr_name in attr_to_edit else attr_cfg)
|
214
216
|
for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items()
|
215
217
|
}
|
216
218
|
|
217
219
|
# Return a new instance with the updated attribute mapping
|
218
|
-
return
|
220
|
+
return updated_config.copy(update={'attr_weights_configs_mapping': updated_attr_mapping})
|
219
221
|
|
220
222
|
|
221
|
-
|
222
|
-
class QuantizationConfigOptions:
|
223
|
+
class QuantizationConfigOptions(BaseModel):
|
223
224
|
"""
|
224
225
|
QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator.
|
225
226
|
|
226
227
|
Attributes:
|
227
|
-
quantization_configurations (Tuple[OpQuantizationConfig]): Tuple of possible OpQuantizationConfig to gather.
|
228
|
+
quantization_configurations (Tuple[OpQuantizationConfig, ...]): Tuple of possible OpQuantizationConfig to gather.
|
228
229
|
base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
|
229
230
|
"""
|
230
|
-
quantization_configurations: Tuple[OpQuantizationConfig]
|
231
|
+
quantization_configurations: Tuple[OpQuantizationConfig, ...]
|
231
232
|
base_config: Optional[OpQuantizationConfig] = None
|
232
233
|
|
233
|
-
|
234
|
+
class Config:
|
235
|
+
frozen = True
|
236
|
+
|
237
|
+
@root_validator(pre=True, allow_reuse=True)
|
238
|
+
def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
234
239
|
"""
|
235
|
-
|
240
|
+
Validate and set the base_config based on quantization_configurations.
|
241
|
+
|
242
|
+
Args:
|
243
|
+
values (Dict[str, Any]): Input data.
|
236
244
|
|
237
|
-
|
238
|
-
|
245
|
+
Returns:
|
246
|
+
Dict[str, Any]: Modified input data with base_config set appropriately.
|
239
247
|
"""
|
240
|
-
|
241
|
-
|
248
|
+
quantization_configurations = values.get('quantization_configurations', ())
|
249
|
+
num_configs = len(quantization_configurations)
|
250
|
+
base_config = values.get('base_config')
|
251
|
+
|
252
|
+
if not isinstance(quantization_configurations, (tuple, list)):
|
242
253
|
Logger.critical(
|
243
|
-
f"'quantization_configurations' must be a tuple, but received: {type(
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
if not any(self.base_config == cfg for cfg in self.quantization_configurations):
|
254
|
-
Logger.critical(f"'base_config' must be included in the quantization config options.") # pragma: no cover
|
255
|
-
elif len(self.quantization_configurations) == 1:
|
256
|
-
if self.base_config is None:
|
257
|
-
object.__setattr__(self, 'base_config', self.quantization_configurations[0])
|
258
|
-
elif self.base_config != self.quantization_configurations[0]:
|
254
|
+
f"'quantization_configurations' must be a list or tuple, but received: {type(quantization_configurations)}."
|
255
|
+
) # pragma: no cover
|
256
|
+
|
257
|
+
if num_configs == 0:
|
258
|
+
Logger.critical(
|
259
|
+
"'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided configurations are empty."
|
260
|
+
) # pragma: no cover
|
261
|
+
|
262
|
+
if base_config is None:
|
263
|
+
if num_configs > 1:
|
259
264
|
Logger.critical(
|
260
|
-
"
|
265
|
+
"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization."
|
266
|
+
) # pragma: no cover
|
267
|
+
else:
|
268
|
+
# Automatically set base_config to the sole configuration
|
269
|
+
base_config = quantization_configurations[0]
|
261
270
|
|
262
|
-
|
263
|
-
|
271
|
+
|
272
|
+
if base_config not in quantization_configurations:
|
273
|
+
Logger.critical(
|
274
|
+
"'base_config' must be included in the quantization config options."
|
275
|
+
) # pragma: no cover
|
276
|
+
|
277
|
+
# if num_configs == 1:
|
278
|
+
# if base_config != quantization_configurations[0]:
|
279
|
+
# Logger.critical(
|
280
|
+
# "'base_config' should be the same as the sole item in 'quantization_configurations'."
|
281
|
+
# ) # pragma: no cover
|
282
|
+
|
283
|
+
values['base_config'] = base_config
|
284
|
+
|
285
|
+
# When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
|
286
|
+
if isinstance(quantization_configurations, list):
|
287
|
+
values['quantization_configurations'] = tuple(quantization_configurations)
|
288
|
+
|
289
|
+
return values
|
264
290
|
|
265
291
|
def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions':
|
266
292
|
"""
|
@@ -270,46 +296,71 @@ class QuantizationConfigOptions:
|
|
270
296
|
**kwargs: Keyword arguments to edit in each configuration.
|
271
297
|
|
272
298
|
Returns:
|
273
|
-
A new instance
|
299
|
+
QuantizationConfigOptions: A new instance with updated configurations.
|
274
300
|
"""
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
301
|
+
# Clone and update base_config
|
302
|
+
updated_base_config = self.base_config.clone_and_edit(**kwargs) if self.base_config else None
|
303
|
+
|
304
|
+
# Clone and update all configurations
|
305
|
+
updated_configs = tuple(cfg.clone_and_edit(**kwargs) for cfg in self.quantization_configurations)
|
280
306
|
|
281
|
-
|
307
|
+
return self.copy(update={
|
308
|
+
'base_config': updated_base_config,
|
309
|
+
'quantization_configurations': updated_configs
|
310
|
+
})
|
311
|
+
|
312
|
+
def clone_and_edit_weight_attribute(
|
313
|
+
self,
|
314
|
+
attrs: Optional[List[str]] = None,
|
315
|
+
**kwargs
|
316
|
+
) -> 'QuantizationConfigOptions':
|
282
317
|
"""
|
283
318
|
Clones the quantization configurations and edits some of their attributes' parameters.
|
284
319
|
|
285
320
|
Args:
|
286
|
-
attrs (List[str]):
|
287
|
-
**kwargs: Keyword arguments to edit in the attributes configuration.
|
321
|
+
attrs (Optional[List[str]]): Attribute names to clone and edit their configurations. If None, updates all attributes.
|
322
|
+
**kwargs: Keyword arguments to edit in the attributes' configuration.
|
288
323
|
|
289
324
|
Returns:
|
290
|
-
QuantizationConfigOptions: A new instance
|
325
|
+
QuantizationConfigOptions: A new instance with edited attributes configurations.
|
291
326
|
"""
|
292
327
|
updated_base_config = self.base_config
|
293
328
|
updated_configs = []
|
329
|
+
|
294
330
|
for qc in self.quantization_configurations:
|
295
331
|
if attrs is None:
|
296
332
|
attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
|
297
333
|
else:
|
298
334
|
attrs_to_update = attrs
|
335
|
+
|
299
336
|
# Ensure all attributes exist in the config
|
300
337
|
for attr in attrs_to_update:
|
301
338
|
if attr not in qc.attr_weights_configs_mapping:
|
302
|
-
Logger.critical(f"{attr} does not exist in {qc}.")
|
339
|
+
Logger.critical(f"Attribute '{attr}' does not exist in {qc}.") # pragma: no cover
|
340
|
+
|
341
|
+
# Update the specified attributes
|
303
342
|
updated_attr_mapping = {
|
304
343
|
attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs)
|
305
344
|
for attr in attrs_to_update
|
306
345
|
}
|
307
|
-
if qc == updated_base_config:
|
308
|
-
updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping)
|
309
|
-
updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
|
310
|
-
return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
|
311
346
|
|
312
|
-
|
347
|
+
# If the current config is the base_config, update it accordingly
|
348
|
+
if qc == self.base_config:
|
349
|
+
updated_base_config = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping)
|
350
|
+
|
351
|
+
# Update the current config with the new attribute mappings
|
352
|
+
updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping)
|
353
|
+
updated_configs.append(updated_cfg)
|
354
|
+
|
355
|
+
return self.copy(update={
|
356
|
+
'base_config': updated_base_config,
|
357
|
+
'quantization_configurations': tuple(updated_configs)
|
358
|
+
})
|
359
|
+
|
360
|
+
def clone_and_map_weights_attr_keys(
|
361
|
+
self,
|
362
|
+
layer_attrs_mapping: Optional[Dict[str, str]] = None
|
363
|
+
) -> 'QuantizationConfigOptions':
|
313
364
|
"""
|
314
365
|
Clones the quantization configurations and updates keys in attribute config mappings.
|
315
366
|
|
@@ -317,22 +368,32 @@ class QuantizationConfigOptions:
|
|
317
368
|
layer_attrs_mapping (Optional[Dict[str, str]]): A mapping between attribute names.
|
318
369
|
|
319
370
|
Returns:
|
320
|
-
QuantizationConfigOptions: A new instance
|
371
|
+
QuantizationConfigOptions: A new instance with updated attribute keys.
|
321
372
|
"""
|
322
|
-
updated_configs = []
|
323
373
|
new_base_config = self.base_config
|
374
|
+
updated_configs = []
|
375
|
+
|
324
376
|
for qc in self.quantization_configurations:
|
325
377
|
if layer_attrs_mapping is None:
|
326
|
-
new_attr_mapping =
|
378
|
+
new_attr_mapping = qc.attr_weights_configs_mapping
|
327
379
|
else:
|
328
380
|
new_attr_mapping = {
|
329
381
|
layer_attrs_mapping.get(attr, attr): cfg
|
330
382
|
for attr, cfg in qc.attr_weights_configs_mapping.items()
|
331
383
|
}
|
384
|
+
|
385
|
+
# If the current config is the base_config, update it accordingly
|
332
386
|
if qc == self.base_config:
|
333
|
-
new_base_config =
|
334
|
-
|
335
|
-
|
387
|
+
new_base_config = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping)
|
388
|
+
|
389
|
+
# Update the current config with the new attribute mappings
|
390
|
+
updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping)
|
391
|
+
updated_configs.append(updated_cfg)
|
392
|
+
|
393
|
+
return self.copy(update={
|
394
|
+
'base_config': new_base_config,
|
395
|
+
'quantization_configurations': tuple(updated_configs)
|
396
|
+
})
|
336
397
|
|
337
398
|
def get_info(self) -> Dict[str, Any]:
|
338
399
|
"""
|
@@ -341,18 +402,16 @@ class QuantizationConfigOptions:
|
|
341
402
|
Returns:
|
342
403
|
dict: Information about the quantization configuration options as a dictionary.
|
343
404
|
"""
|
344
|
-
return {f'
|
345
|
-
|
405
|
+
return {f'option_{i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
|
346
406
|
|
347
|
-
|
348
|
-
class TargetPlatformModelComponent:
|
407
|
+
class TargetPlatformModelComponent(BaseModel):
|
349
408
|
"""
|
350
409
|
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
|
351
410
|
"""
|
352
|
-
|
411
|
+
class Config:
|
412
|
+
frozen = True
|
353
413
|
|
354
414
|
|
355
|
-
@dataclass(frozen=True)
|
356
415
|
class OperatorsSetBase(TargetPlatformModelComponent):
|
357
416
|
"""
|
358
417
|
Base class to represent a set of a target platform model component of operator set types.
|
@@ -361,91 +420,151 @@ class OperatorsSetBase(TargetPlatformModelComponent):
|
|
361
420
|
pass
|
362
421
|
|
363
422
|
|
364
|
-
@dataclass(frozen=True)
|
365
423
|
class OperatorsSet(OperatorsSetBase):
|
366
424
|
"""
|
367
425
|
Set of operators that are represented by a unique label.
|
368
426
|
|
369
427
|
Attributes:
|
370
|
-
name (str): The set's label (must be unique within a TargetPlatformModel).
|
371
|
-
qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
|
372
|
-
|
373
|
-
|
374
|
-
for the TargetPlatformModel or a fusing set.
|
428
|
+
name (Union[str, OperatorSetNames]): The set's label (must be unique within a TargetPlatformModel).
|
429
|
+
qc_options (Optional[QuantizationConfigOptions]): Configuration options to use for this set of operations.
|
430
|
+
If None, it represents a fusing set.
|
431
|
+
type (Literal["OperatorsSet"]): Fixed type identifier.
|
375
432
|
"""
|
376
433
|
name: Union[str, OperatorSetNames]
|
377
|
-
qc_options: QuantizationConfigOptions = None
|
434
|
+
qc_options: Optional[QuantizationConfigOptions] = None
|
435
|
+
|
436
|
+
# Define a private attribute _type
|
437
|
+
type: Literal["OperatorsSet"] = "OperatorsSet"
|
438
|
+
|
439
|
+
class Config:
|
440
|
+
frozen = True
|
378
441
|
|
379
442
|
def get_info(self) -> Dict[str, Any]:
|
380
443
|
"""
|
381
444
|
Get information about the set as a dictionary.
|
382
445
|
|
383
446
|
Returns:
|
384
|
-
Dict[str, Any]: A dictionary containing the set name
|
385
|
-
whether it is the default quantization configuration.
|
447
|
+
Dict[str, Any]: A dictionary containing the set name.
|
386
448
|
"""
|
387
449
|
return {"name": self.name}
|
388
450
|
|
389
451
|
|
390
|
-
@dataclass(frozen=True)
|
391
452
|
class OperatorSetConcat(OperatorsSetBase):
|
392
453
|
"""
|
393
454
|
Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
|
394
455
|
|
395
456
|
Attributes:
|
396
|
-
operators_set (Tuple[OperatorsSet]): Tuple of operator sets to group.
|
397
|
-
|
398
|
-
name (str): Concatenated name generated from the names of the operator sets.
|
457
|
+
operators_set (Tuple[OperatorsSet, ...]): Tuple of operator sets to group.
|
458
|
+
name (Optional[str]): Concatenated name generated from the names of the operator sets.
|
399
459
|
"""
|
400
|
-
operators_set: Tuple[OperatorsSet]
|
401
|
-
|
460
|
+
operators_set: Tuple[OperatorsSet, ...]
|
461
|
+
name: Optional[str] = None # Will be set in the validator if not given
|
402
462
|
|
403
|
-
|
463
|
+
# Define a private attribute _type
|
464
|
+
type: Literal["OperatorSetConcat"] = "OperatorSetConcat"
|
465
|
+
|
466
|
+
class Config:
|
467
|
+
frozen = True
|
468
|
+
|
469
|
+
@root_validator(pre=True, allow_reuse=True)
|
470
|
+
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
404
471
|
"""
|
405
|
-
|
472
|
+
Validate the input and set the concatenated name based on the operators_set.
|
473
|
+
|
474
|
+
Args:
|
475
|
+
values (Dict[str, Any]): Input data.
|
476
|
+
|
477
|
+
Returns:
|
478
|
+
Dict[str, Any]: Modified input data with 'name' set.
|
479
|
+
"""
|
480
|
+
operators_set = values['operators_set']
|
481
|
+
|
482
|
+
if len(operators_set) < 1:
|
483
|
+
Logger.critical("'operators_set' must contain at least one OperatorsSet") # pragma: no cover
|
406
484
|
|
407
|
-
|
408
|
-
|
485
|
+
if values.get('name') is None:
|
486
|
+
# Generate the concatenated name from the operator sets
|
487
|
+
concatenated_name = "_".join([
|
488
|
+
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
|
489
|
+
for op in operators_set
|
490
|
+
])
|
491
|
+
values['name'] = concatenated_name
|
492
|
+
|
493
|
+
return values
|
494
|
+
|
495
|
+
def get_info(self) -> Dict[str, Any]:
|
409
496
|
"""
|
410
|
-
|
411
|
-
concatenated_name = "_".join([op.name.value if hasattr(op.name, "value") else op.name for op in self.operators_set])
|
412
|
-
# Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen
|
413
|
-
object.__setattr__(self, "name", concatenated_name)
|
497
|
+
Get information about the concatenated operator sets as a dictionary.
|
414
498
|
|
499
|
+
Returns:
|
500
|
+
Dict[str, Any]: A dictionary containing the concatenated name and operator sets information.
|
501
|
+
"""
|
502
|
+
return {
|
503
|
+
"name": self.name,
|
504
|
+
"operators_set": [op.get_info() for op in self.operators_set]
|
505
|
+
}
|
415
506
|
|
416
|
-
@dataclass(frozen=True)
|
417
507
|
class Fusing(TargetPlatformModelComponent):
|
418
508
|
"""
|
419
509
|
Fusing defines a tuple of operators that should be combined and treated as a single operator,
|
420
510
|
hence no quantization is applied between them.
|
421
511
|
|
422
512
|
Attributes:
|
423
|
-
operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A tuple of operator groups,
|
513
|
+
operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat], ...]): A tuple of operator groups,
|
424
514
|
each being either an OperatorSetConcat or an OperatorsSet.
|
425
|
-
name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
|
515
|
+
name (Optional[str]): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
|
426
516
|
"""
|
427
|
-
operator_groups: Tuple[Union[OperatorsSet, OperatorSetConcat]]
|
517
|
+
operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetConcat], Field(discriminator='type')], ...]
|
518
|
+
name: Optional[str] = None # Will be set in the validator if not given.
|
519
|
+
|
520
|
+
class Config:
|
521
|
+
frozen = True
|
428
522
|
|
429
|
-
|
523
|
+
@root_validator(pre=True, allow_reuse=True)
|
524
|
+
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
430
525
|
"""
|
431
|
-
|
526
|
+
Validate the operator_groups and set the name by concatenating operator group names.
|
432
527
|
|
433
|
-
|
434
|
-
|
528
|
+
Args:
|
529
|
+
values (Dict[str, Any]): Input data.
|
435
530
|
|
436
|
-
|
437
|
-
|
531
|
+
Returns:
|
532
|
+
Dict[str, Any]: Modified input data with 'name' set.
|
438
533
|
"""
|
439
|
-
|
440
|
-
if not isinstance(self.operator_groups, tuple):
|
441
|
-
Logger.critical(
|
442
|
-
f"Operator groups should be of type 'tuple' but is {type(self.operator_groups)}.") # pragma: no cover
|
443
|
-
if len(self.operator_groups) < 2:
|
444
|
-
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
|
534
|
+
operator_groups = values.get('operator_groups')
|
445
535
|
|
446
|
-
#
|
447
|
-
|
448
|
-
|
536
|
+
# When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
|
537
|
+
if isinstance(operator_groups, list):
|
538
|
+
values['operator_groups'] = tuple(operator_groups)
|
539
|
+
|
540
|
+
if values.get('name') is None:
|
541
|
+
# Generate the concatenated name from the operator groups
|
542
|
+
concatenated_name = "_".join([
|
543
|
+
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
|
544
|
+
for op in values['operator_groups']
|
545
|
+
])
|
546
|
+
values['name'] = concatenated_name
|
547
|
+
|
548
|
+
return values
|
549
|
+
|
550
|
+
@root_validator(allow_reuse=True)
|
551
|
+
def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
552
|
+
"""
|
553
|
+
Perform validation after the model has been instantiated.
|
554
|
+
|
555
|
+
Args:
|
556
|
+
values (Dict[str, Any]): The instantiated fusing.
|
557
|
+
|
558
|
+
Returns:
|
559
|
+
Dict[str, Any]: The validated values.
|
560
|
+
"""
|
561
|
+
operator_groups = values.get('operator_groups')
|
562
|
+
|
563
|
+
# Validate that there are at least two operator groups
|
564
|
+
if len(operator_groups) < 2:
|
565
|
+
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
|
566
|
+
|
567
|
+
return values
|
449
568
|
|
450
569
|
def contains(self, other: Any) -> bool:
|
451
570
|
"""
|
@@ -483,56 +602,75 @@ class Fusing(TargetPlatformModelComponent):
|
|
483
602
|
or just the sequence of operator groups if no name is set.
|
484
603
|
"""
|
485
604
|
if self.name is not None:
|
486
|
-
return {
|
487
|
-
|
488
|
-
|
605
|
+
return {
|
606
|
+
self.name: ' -> '.join([
|
607
|
+
x.name.value if isinstance(x.name, OperatorSetNames) else x.name
|
608
|
+
for x in self.operator_groups
|
609
|
+
])
|
610
|
+
}
|
611
|
+
return ' -> '.join([
|
612
|
+
x.name.value if isinstance(x.name, OperatorSetNames) else x.name
|
613
|
+
for x in self.operator_groups
|
614
|
+
])
|
489
615
|
|
490
|
-
|
491
|
-
class TargetPlatformModel:
|
616
|
+
class TargetPlatformModel(BaseModel):
|
492
617
|
"""
|
493
618
|
Represents the hardware configuration used for quantized model inference.
|
494
619
|
|
495
620
|
Attributes:
|
496
621
|
default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
|
622
|
+
operator_set (Optional[Tuple[OperatorsSet, ...]]): Tuple of operator sets within the model.
|
623
|
+
fusing_patterns (Optional[Tuple[Fusing, ...]]): Tuple of fusing patterns for the model.
|
497
624
|
tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
|
498
625
|
tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
|
499
626
|
tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
|
500
627
|
add_metadata (bool): Flag to determine if metadata should be added.
|
501
628
|
name (str): Name of the Target Platform Model.
|
502
|
-
operator_set (Tuple[OperatorsSetBase]): Tuple of operator sets within the model.
|
503
|
-
fusing_patterns (Tuple[Fusing]): Tuple of fusing patterns for the model.
|
504
629
|
is_simd_padding (bool): Indicates if SIMD padding is applied.
|
505
630
|
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
|
506
631
|
"""
|
507
632
|
default_qco: QuantizationConfigOptions
|
633
|
+
operator_set: Optional[Tuple[OperatorsSet, ...]]
|
634
|
+
fusing_patterns: Optional[Tuple[Fusing, ...]]
|
508
635
|
tpc_minor_version: Optional[int]
|
509
636
|
tpc_patch_version: Optional[int]
|
510
637
|
tpc_platform_type: Optional[str]
|
511
638
|
add_metadata: bool = True
|
512
|
-
name: str = "default_tp_model"
|
513
|
-
operator_set: Tuple[OperatorsSetBase] = None
|
514
|
-
fusing_patterns: Tuple[Fusing] = None
|
639
|
+
name: Optional[str] = "default_tp_model"
|
515
640
|
is_simd_padding: bool = False
|
516
641
|
|
517
642
|
SCHEMA_VERSION: int = 1
|
518
643
|
|
519
|
-
|
644
|
+
class Config:
|
645
|
+
frozen = True
|
646
|
+
|
647
|
+
@root_validator(allow_reuse=True)
|
648
|
+
def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
520
649
|
"""
|
521
|
-
|
650
|
+
Perform validation after the model has been instantiated.
|
651
|
+
|
652
|
+
Args:
|
653
|
+
values (Dict[str, Any]): The instantiated target platform model.
|
522
654
|
|
523
|
-
|
524
|
-
|
525
|
-
or if it contains more than one quantization configuration.
|
655
|
+
Returns:
|
656
|
+
Dict[str, Any]: The validated values.
|
526
657
|
"""
|
527
658
|
# Validate `default_qco`
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
if
|
535
|
-
|
659
|
+
default_qco = values.get('default_qco')
|
660
|
+
if len(default_qco.quantization_configurations) != 1:
|
661
|
+
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
|
662
|
+
|
663
|
+
# Validate `operator_set` uniqueness
|
664
|
+
operator_set = values.get('operator_set')
|
665
|
+
if operator_set is not None:
|
666
|
+
opsets_names = [
|
667
|
+
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
|
668
|
+
for op in operator_set
|
669
|
+
]
|
670
|
+
if len(set(opsets_names)) != len(opsets_names):
|
671
|
+
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
|
672
|
+
|
673
|
+
return values
|
536
674
|
|
537
675
|
def get_info(self) -> Dict[str, Any]:
|
538
676
|
"""
|
@@ -547,11 +685,8 @@ class TargetPlatformModel:
|
|
547
685
|
"Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
|
548
686
|
}
|
549
687
|
|
550
|
-
|
551
688
|
def show(self):
|
552
689
|
"""
|
553
|
-
|
554
690
|
Display the TargetPlatformModel.
|
555
|
-
|
556
691
|
"""
|
557
692
|
pprint.pprint(self.get_info(), sort_dicts=False)
|