mct-nightly 2.2.0.20241222.533__py3-none-any.whl → 2.2.0.20241224.532__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/RECORD +29 -28
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_graph.py +1 -1
- model_compression_toolkit/core/common/graph/base_node.py +3 -3
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +4 -4
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +2 -2
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -0
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +4 -5
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +66 -172
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +0 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2fw.py +56 -0
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2keras.py +107 -0
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/attach2pytorch.py +91 -0
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +1 -1
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +7 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +50 -51
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +54 -52
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +57 -53
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +52 -51
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +53 -51
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +59 -57
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +54 -52
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +90 -83
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +26 -24
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +57 -55
- model_compression_toolkit/target_platform_capabilities/target_platform/current_tp_model.py +0 -67
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +0 -30
- {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241222.533.dist-info → mct_nightly-2.2.0.20241224.532.dist-info}/top_level.txt +0 -0
@@ -21,13 +21,11 @@ from mct_quantizers import QuantizationMethod
|
|
21
21
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
22
22
|
from model_compression_toolkit.logger import Logger
|
23
23
|
from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
|
24
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \
|
25
|
-
_current_tp_model
|
26
24
|
|
27
25
|
class OperatorSetNames(Enum):
|
28
26
|
OPSET_CONV = "Conv"
|
29
27
|
OPSET_DEPTHWISE_CONV = "DepthwiseConv2D"
|
30
|
-
OPSET_CONV_TRANSPOSE = "
|
28
|
+
OPSET_CONV_TRANSPOSE = "ConvTranspose"
|
31
29
|
OPSET_FULLY_CONNECTED = "FullyConnected"
|
32
30
|
OPSET_CONCATENATE = "Concatenate"
|
33
31
|
OPSET_STACK = "Stack"
|
@@ -43,7 +41,8 @@ class OperatorSetNames(Enum):
|
|
43
41
|
OPSET_SUB = "Sub"
|
44
42
|
OPSET_MUL = "Mul"
|
45
43
|
OPSET_DIV = "Div"
|
46
|
-
|
44
|
+
OPSET_MIN = "Min"
|
45
|
+
OPSET_MAX = "Max"
|
47
46
|
OPSET_PRELU = "PReLU"
|
48
47
|
OPSET_SWISH = "Swish"
|
49
48
|
OPSET_SIGMOID = "Sigmoid"
|
@@ -61,7 +60,6 @@ class OperatorSetNames(Enum):
|
|
61
60
|
OPSET_DROPOUT = "Dropout"
|
62
61
|
OPSET_SPLIT = "Split"
|
63
62
|
OPSET_CHUNK = "Chunk"
|
64
|
-
OPSET_UNBIND = "Unbind"
|
65
63
|
OPSET_MAXPOOL = "MaxPool"
|
66
64
|
OPSET_SIZE = "Size"
|
67
65
|
OPSET_SHAPE = "Shape"
|
@@ -74,6 +72,7 @@ class OperatorSetNames(Enum):
|
|
74
72
|
OPSET_ZERO_PADDING2d = "ZeroPadding2D"
|
75
73
|
OPSET_CAST = "Cast"
|
76
74
|
OPSET_STRIDED_SLICE = "StridedSlice"
|
75
|
+
OPSET_SSD_POST_PROCESS = "SSDPostProcess"
|
77
76
|
|
78
77
|
@classmethod
|
79
78
|
def get_values(cls):
|
@@ -225,10 +224,10 @@ class QuantizationConfigOptions:
|
|
225
224
|
QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator.
|
226
225
|
|
227
226
|
Attributes:
|
228
|
-
|
227
|
+
quantization_configurations (Tuple[OpQuantizationConfig]): Tuple of possible OpQuantizationConfig to gather.
|
229
228
|
base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner.
|
230
229
|
"""
|
231
|
-
|
230
|
+
quantization_configurations: Tuple[OpQuantizationConfig]
|
232
231
|
base_config: Optional[OpQuantizationConfig] = None
|
233
232
|
|
234
233
|
def __post_init__(self):
|
@@ -236,32 +235,32 @@ class QuantizationConfigOptions:
|
|
236
235
|
Post-initialization processing for input validation.
|
237
236
|
|
238
237
|
Raises:
|
239
|
-
Logger critical if
|
238
|
+
Logger critical if quantization_configurations is not a tuple, contains invalid elements, or if base_config is not set correctly.
|
240
239
|
"""
|
241
|
-
# Validate `
|
242
|
-
if not isinstance(self.
|
240
|
+
# Validate `quantization_configurations`
|
241
|
+
if not isinstance(self.quantization_configurations, tuple):
|
243
242
|
Logger.critical(
|
244
|
-
f"'
|
245
|
-
for cfg in self.
|
243
|
+
f"'quantization_configurations' must be a tuple, but received: {type(self.quantization_configurations)}.") # pragma: no cover
|
244
|
+
for cfg in self.quantization_configurations:
|
246
245
|
if not isinstance(cfg, OpQuantizationConfig):
|
247
246
|
Logger.critical(
|
248
247
|
f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") # pragma: no cover
|
249
248
|
|
250
249
|
# Handle base_config
|
251
|
-
if len(self.
|
250
|
+
if len(self.quantization_configurations) > 1:
|
252
251
|
if self.base_config is None:
|
253
252
|
Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") # pragma: no cover
|
254
|
-
if not any(self.base_config == cfg for cfg in self.
|
255
|
-
Logger.critical(f"'base_config' must be included in the quantization config options
|
256
|
-
elif len(self.
|
253
|
+
if not any(self.base_config == cfg for cfg in self.quantization_configurations):
|
254
|
+
Logger.critical(f"'base_config' must be included in the quantization config options.") # pragma: no cover
|
255
|
+
elif len(self.quantization_configurations) == 1:
|
257
256
|
if self.base_config is None:
|
258
|
-
object.__setattr__(self, 'base_config', self.
|
259
|
-
elif self.base_config != self.
|
257
|
+
object.__setattr__(self, 'base_config', self.quantization_configurations[0])
|
258
|
+
elif self.base_config != self.quantization_configurations[0]:
|
260
259
|
Logger.critical(
|
261
|
-
"'base_config' should be the same as the sole item in '
|
260
|
+
"'base_config' should be the same as the sole item in 'quantization_configurations'.") # pragma: no cover
|
262
261
|
|
263
|
-
elif len(self.
|
264
|
-
Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided
|
262
|
+
elif len(self.quantization_configurations) == 0:
|
263
|
+
Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided configurations is empty.") # pragma: no cover
|
265
264
|
|
266
265
|
def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions':
|
267
266
|
"""
|
@@ -274,10 +273,10 @@ class QuantizationConfigOptions:
|
|
274
273
|
A new instance of QuantizationConfigOptions with updated configurations.
|
275
274
|
"""
|
276
275
|
updated_base_config = replace(self.base_config, **kwargs)
|
277
|
-
|
278
|
-
replace(cfg, **kwargs) for cfg in self.
|
276
|
+
updated_configs = [
|
277
|
+
replace(cfg, **kwargs) for cfg in self.quantization_configurations
|
279
278
|
]
|
280
|
-
return replace(self, base_config=updated_base_config,
|
279
|
+
return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
|
281
280
|
|
282
281
|
def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> 'QuantizationConfigOptions':
|
283
282
|
"""
|
@@ -292,7 +291,7 @@ class QuantizationConfigOptions:
|
|
292
291
|
"""
|
293
292
|
updated_base_config = self.base_config
|
294
293
|
updated_configs = []
|
295
|
-
for qc in self.
|
294
|
+
for qc in self.quantization_configurations:
|
296
295
|
if attrs is None:
|
297
296
|
attrs_to_update = list(qc.attr_weights_configs_mapping.keys())
|
298
297
|
else:
|
@@ -300,7 +299,7 @@ class QuantizationConfigOptions:
|
|
300
299
|
# Ensure all attributes exist in the config
|
301
300
|
for attr in attrs_to_update:
|
302
301
|
if attr not in qc.attr_weights_configs_mapping:
|
303
|
-
Logger.critical(f"{attr} does not exist in {qc}.")
|
302
|
+
Logger.critical(f"{attr} does not exist in {qc}.") # pragma: no cover
|
304
303
|
updated_attr_mapping = {
|
305
304
|
attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs)
|
306
305
|
for attr in attrs_to_update
|
@@ -308,7 +307,7 @@ class QuantizationConfigOptions:
|
|
308
307
|
if qc == updated_base_config:
|
309
308
|
updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping)
|
310
309
|
updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping))
|
311
|
-
return replace(self, base_config=updated_base_config,
|
310
|
+
return replace(self, base_config=updated_base_config, quantization_configurations=tuple(updated_configs))
|
312
311
|
|
313
312
|
def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions':
|
314
313
|
"""
|
@@ -322,7 +321,7 @@ class QuantizationConfigOptions:
|
|
322
321
|
"""
|
323
322
|
updated_configs = []
|
324
323
|
new_base_config = self.base_config
|
325
|
-
for qc in self.
|
324
|
+
for qc in self.quantization_configurations:
|
326
325
|
if layer_attrs_mapping is None:
|
327
326
|
new_attr_mapping = {}
|
328
327
|
else:
|
@@ -333,7 +332,7 @@ class QuantizationConfigOptions:
|
|
333
332
|
if qc == self.base_config:
|
334
333
|
new_base_config = replace(qc, attr_weights_configs_mapping=new_attr_mapping)
|
335
334
|
updated_configs.append(replace(qc, attr_weights_configs_mapping=new_attr_mapping))
|
336
|
-
return replace(self, base_config=new_base_config,
|
335
|
+
return replace(self, base_config=new_base_config, quantization_configurations=tuple(updated_configs))
|
337
336
|
|
338
337
|
def get_info(self) -> Dict[str, Any]:
|
339
338
|
"""
|
@@ -342,7 +341,7 @@ class QuantizationConfigOptions:
|
|
342
341
|
Returns:
|
343
342
|
dict: Information about the quantization configuration options as a dictionary.
|
344
343
|
"""
|
345
|
-
return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.
|
344
|
+
return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
|
346
345
|
|
347
346
|
|
348
347
|
@dataclass(frozen=True)
|
@@ -350,22 +349,7 @@ class TargetPlatformModelComponent:
|
|
350
349
|
"""
|
351
350
|
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.).
|
352
351
|
"""
|
353
|
-
|
354
|
-
def __post_init__(self):
|
355
|
-
"""
|
356
|
-
Post-initialization to register the component with the current TargetPlatformModel.
|
357
|
-
"""
|
358
|
-
_current_tp_model.get().append_component(self)
|
359
|
-
|
360
|
-
def get_info(self) -> Dict[str, Any]:
|
361
|
-
"""
|
362
|
-
Get information about the component to display.
|
363
|
-
|
364
|
-
Returns:
|
365
|
-
Dict[str, Any]: Returns an empty dictionary. The actual component should override
|
366
|
-
this method to provide relevant information.
|
367
|
-
"""
|
368
|
-
return {}
|
352
|
+
pass
|
369
353
|
|
370
354
|
|
371
355
|
@dataclass(frozen=True)
|
@@ -374,12 +358,7 @@ class OperatorsSetBase(TargetPlatformModelComponent):
|
|
374
358
|
Base class to represent a set of a target platform model component of operator set types.
|
375
359
|
Inherits from TargetPlatformModelComponent.
|
376
360
|
"""
|
377
|
-
|
378
|
-
"""
|
379
|
-
Post-initialization to ensure the component is registered with the TargetPlatformModel.
|
380
|
-
Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel.
|
381
|
-
"""
|
382
|
-
super().__post_init__()
|
361
|
+
pass
|
383
362
|
|
384
363
|
|
385
364
|
@dataclass(frozen=True)
|
@@ -394,23 +373,9 @@ class OperatorsSet(OperatorsSetBase):
|
|
394
373
|
is_default (bool): Indicates whether this set is the default quantization configuration
|
395
374
|
for the TargetPlatformModel or a fusing set.
|
396
375
|
"""
|
397
|
-
name: str
|
376
|
+
name: Union[str, OperatorSetNames]
|
398
377
|
qc_options: QuantizationConfigOptions = None
|
399
378
|
|
400
|
-
def __post_init__(self):
|
401
|
-
"""
|
402
|
-
Post-initialization processing to mark the operator set as default if applicable.
|
403
|
-
|
404
|
-
Calls the parent class's __post_init__ method and sets `is_default` to True
|
405
|
-
if this set corresponds to the default quantization configuration for the
|
406
|
-
TargetPlatformModel or if it is a fusing set.
|
407
|
-
|
408
|
-
"""
|
409
|
-
super().__post_init__()
|
410
|
-
is_fusing_set = self.qc_options is None
|
411
|
-
is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
|
412
|
-
object.__setattr__(self, 'is_default', is_default)
|
413
|
-
|
414
379
|
def get_info(self) -> Dict[str, Any]:
|
415
380
|
"""
|
416
381
|
Get information about the set as a dictionary.
|
@@ -419,83 +384,67 @@ class OperatorsSet(OperatorsSetBase):
|
|
419
384
|
Dict[str, Any]: A dictionary containing the set name and
|
420
385
|
whether it is the default quantization configuration.
|
421
386
|
"""
|
422
|
-
return {"name": self.name
|
423
|
-
"is_default_qc": self.is_default}
|
387
|
+
return {"name": self.name}
|
424
388
|
|
425
389
|
|
426
390
|
@dataclass(frozen=True)
|
427
391
|
class OperatorSetConcat(OperatorsSetBase):
|
428
392
|
"""
|
429
|
-
Concatenate a
|
393
|
+
Concatenate a tuple of operator sets to treat them similarly in different places (like fusing).
|
430
394
|
|
431
395
|
Attributes:
|
432
|
-
|
396
|
+
operators_set (Tuple[OperatorsSet]): Tuple of operator sets to group.
|
433
397
|
qc_options (None): Configuration options for the set, always None for concatenated sets.
|
434
|
-
name (str): Concatenated name generated from the names of the operator sets
|
398
|
+
name (str): Concatenated name generated from the names of the operator sets.
|
435
399
|
"""
|
436
|
-
|
400
|
+
operators_set: Tuple[OperatorsSet]
|
437
401
|
qc_options: None = field(default=None, init=False)
|
438
|
-
name: str = None
|
439
402
|
|
440
403
|
def __post_init__(self):
|
441
404
|
"""
|
442
405
|
Post-initialization processing to generate the concatenated name and set it as the `name` attribute.
|
443
406
|
|
444
407
|
Calls the parent class's __post_init__ method and creates a concatenated name
|
445
|
-
by joining the names of all operator sets in `
|
408
|
+
by joining the names of all operator sets in `operators_set`.
|
446
409
|
"""
|
447
|
-
super().__post_init__()
|
448
410
|
# Generate the concatenated name from the operator sets
|
449
|
-
concatenated_name = "_".join([op.name for op in self.
|
411
|
+
concatenated_name = "_".join([op.name.value if hasattr(op.name, "value") else op.name for op in self.operators_set])
|
450
412
|
# Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen
|
451
413
|
object.__setattr__(self, "name", concatenated_name)
|
452
414
|
|
453
|
-
def get_info(self) -> Dict[str, Any]:
|
454
|
-
"""
|
455
|
-
Get information about the concatenated set as a dictionary.
|
456
|
-
|
457
|
-
Returns:
|
458
|
-
Dict[str, Any]: A dictionary containing the concatenated name and
|
459
|
-
the list of names of the operator sets in `op_set_list`.
|
460
|
-
"""
|
461
|
-
return {"name": self.name,
|
462
|
-
OPS_SET_LIST: [s.name for s in self.op_set_list]}
|
463
|
-
|
464
415
|
|
465
416
|
@dataclass(frozen=True)
|
466
417
|
class Fusing(TargetPlatformModelComponent):
|
467
418
|
"""
|
468
|
-
Fusing defines a
|
419
|
+
Fusing defines a tuple of operators that should be combined and treated as a single operator,
|
469
420
|
hence no quantization is applied between them.
|
470
421
|
|
471
422
|
Attributes:
|
472
|
-
|
423
|
+
operator_groups (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A tuple of operator groups,
|
473
424
|
each being either an OperatorSetConcat or an OperatorsSet.
|
474
425
|
name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names.
|
475
426
|
"""
|
476
|
-
|
477
|
-
name: str = None
|
427
|
+
operator_groups: Tuple[Union[OperatorsSet, OperatorSetConcat]]
|
478
428
|
|
479
429
|
def __post_init__(self):
|
480
430
|
"""
|
481
431
|
Post-initialization processing for input validation and name generation.
|
482
432
|
|
483
|
-
Calls the parent class's __post_init__ method, validates the
|
433
|
+
Calls the parent class's __post_init__ method, validates the operator_groups,
|
484
434
|
and generates the name if not explicitly provided.
|
485
435
|
|
486
436
|
Raises:
|
487
|
-
Logger critical if
|
437
|
+
Logger critical if operator_groups is not a tuple or if it contains fewer than two operators.
|
488
438
|
"""
|
489
|
-
|
490
|
-
|
491
|
-
if not isinstance(self.operator_groups_list, list):
|
439
|
+
# Validate the operator_groups
|
440
|
+
if not isinstance(self.operator_groups, tuple):
|
492
441
|
Logger.critical(
|
493
|
-
f"
|
494
|
-
if len(self.
|
442
|
+
f"Operator groups should be of type 'tuple' but is {type(self.operator_groups)}.") # pragma: no cover
|
443
|
+
if len(self.operator_groups) < 2:
|
495
444
|
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
|
496
445
|
|
497
446
|
# Generate the name from the operator groups if not provided
|
498
|
-
generated_name = '_'.join([x.name for x in self.
|
447
|
+
generated_name = '_'.join([x.name.value if hasattr(x.name, 'value') else x.name for x in self.operator_groups])
|
499
448
|
object.__setattr__(self, 'name', generated_name)
|
500
449
|
|
501
450
|
def contains(self, other: Any) -> bool:
|
@@ -512,11 +461,11 @@ class Fusing(TargetPlatformModelComponent):
|
|
512
461
|
return False
|
513
462
|
|
514
463
|
# Check for containment by comparing operator groups
|
515
|
-
for i in range(len(self.
|
516
|
-
for j in range(len(other.
|
517
|
-
if self.
|
518
|
-
isinstance(self.
|
519
|
-
other.
|
464
|
+
for i in range(len(self.operator_groups) - len(other.operator_groups) + 1):
|
465
|
+
for j in range(len(other.operator_groups)):
|
466
|
+
if self.operator_groups[i + j] != other.operator_groups[j] and not (
|
467
|
+
isinstance(self.operator_groups[i + j], OperatorSetConcat) and (
|
468
|
+
other.operator_groups[j] in self.operator_groups[i + j].operators_set)):
|
520
469
|
break
|
521
470
|
else:
|
522
471
|
# If all checks pass, the other Fusing instance is contained
|
@@ -534,8 +483,8 @@ class Fusing(TargetPlatformModelComponent):
|
|
534
483
|
or just the sequence of operator groups if no name is set.
|
535
484
|
"""
|
536
485
|
if self.name is not None:
|
537
|
-
return {self.name: ' -> '.join([x.name for x in self.
|
538
|
-
return ' -> '.join([x.name for x in self.
|
486
|
+
return {self.name: ' -> '.join([x.name for x in self.operator_groups])}
|
487
|
+
return ' -> '.join([x.name for x in self.operator_groups])
|
539
488
|
|
540
489
|
|
541
490
|
@dataclass(frozen=True)
|
@@ -550,8 +499,8 @@ class TargetPlatformModel:
|
|
550
499
|
tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
|
551
500
|
add_metadata (bool): Flag to determine if metadata should be added.
|
552
501
|
name (str): Name of the Target Platform Model.
|
553
|
-
operator_set (
|
554
|
-
fusing_patterns (
|
502
|
+
operator_set (Tuple[OperatorsSetBase]): Tuple of operator sets within the model.
|
503
|
+
fusing_patterns (Tuple[Fusing]): Tuple of fusing patterns for the model.
|
555
504
|
is_simd_padding (bool): Indicates if SIMD padding is applied.
|
556
505
|
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
|
557
506
|
"""
|
@@ -561,8 +510,8 @@ class TargetPlatformModel:
|
|
561
510
|
tpc_platform_type: Optional[str]
|
562
511
|
add_metadata: bool = True
|
563
512
|
name: str = "default_tp_model"
|
564
|
-
operator_set:
|
565
|
-
fusing_patterns:
|
513
|
+
operator_set: Tuple[OperatorsSetBase] = None
|
514
|
+
fusing_patterns: Tuple[Fusing] = None
|
566
515
|
is_simd_padding: bool = False
|
567
516
|
|
568
517
|
SCHEMA_VERSION: int = 1
|
@@ -578,26 +527,12 @@ class TargetPlatformModel:
|
|
578
527
|
# Validate `default_qco`
|
579
528
|
if not isinstance(self.default_qco, QuantizationConfigOptions):
|
580
529
|
Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover
|
581
|
-
if len(self.default_qco.
|
530
|
+
if len(self.default_qco.quantization_configurations) != 1:
|
582
531
|
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
|
583
532
|
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
Args:
|
589
|
-
tp_model_component (TargetPlatformModelComponent): Component to attach to the model.
|
590
|
-
|
591
|
-
Raises:
|
592
|
-
Logger critical if the component is not an instance of Fusing or OperatorsSetBase.
|
593
|
-
"""
|
594
|
-
if isinstance(tp_model_component, Fusing):
|
595
|
-
self.fusing_patterns.append(tp_model_component)
|
596
|
-
elif isinstance(tp_model_component, OperatorsSetBase):
|
597
|
-
self.operator_set.append(tp_model_component)
|
598
|
-
else:
|
599
|
-
Logger.critical(
|
600
|
-
f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") # pragma: no cover
|
533
|
+
opsets_names = [op.name.value if hasattr(op.name, "value") else op.name for op in self.operator_set] if self.operator_set else []
|
534
|
+
if len(set(opsets_names)) != len(opsets_names):
|
535
|
+
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
|
601
536
|
|
602
537
|
def get_info(self) -> Dict[str, Any]:
|
603
538
|
"""
|
@@ -608,51 +543,10 @@ class TargetPlatformModel:
|
|
608
543
|
"""
|
609
544
|
return {
|
610
545
|
"Model name": self.name,
|
611
|
-
"Operators sets": [o.get_info() for o in self.operator_set],
|
612
|
-
"Fusing patterns": [f.get_info() for f in self.fusing_patterns],
|
546
|
+
"Operators sets": [o.get_info() for o in self.operator_set] if self.operator_set else [],
|
547
|
+
"Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
|
613
548
|
}
|
614
549
|
|
615
|
-
def __validate_model(self):
|
616
|
-
"""
|
617
|
-
Validate the model's configuration to ensure its integrity.
|
618
|
-
|
619
|
-
Raises:
|
620
|
-
Logger critical if the model contains multiple operator sets with the same name.
|
621
|
-
"""
|
622
|
-
opsets_names = [op.name for op in self.operator_set]
|
623
|
-
if len(set(opsets_names)) != len(opsets_names):
|
624
|
-
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
|
625
|
-
|
626
|
-
def __enter__(self) -> 'TargetPlatformModel':
|
627
|
-
"""
|
628
|
-
Start defining the TargetPlatformModel using a 'with' statement.
|
629
|
-
|
630
|
-
Returns:
|
631
|
-
TargetPlatformModel: The initialized TargetPlatformModel object.
|
632
|
-
"""
|
633
|
-
_current_tp_model.set(self)
|
634
|
-
return self
|
635
|
-
|
636
|
-
def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel':
|
637
|
-
"""
|
638
|
-
Finalize and validate the TargetPlatformModel at the end of the 'with' clause.
|
639
|
-
|
640
|
-
Args:
|
641
|
-
exc_type: Exception type, if any occurred.
|
642
|
-
exc_value: Exception value, if any occurred.
|
643
|
-
tb: Traceback object, if an exception occurred.
|
644
|
-
|
645
|
-
Raises:
|
646
|
-
The exception raised in the 'with' block, if any.
|
647
|
-
|
648
|
-
Returns:
|
649
|
-
TargetPlatformModel: The validated TargetPlatformModel object.
|
650
|
-
"""
|
651
|
-
if exc_value is not None:
|
652
|
-
raise exc_value
|
653
|
-
self.__validate_model()
|
654
|
-
_current_tp_model.reset()
|
655
|
-
return self
|
656
550
|
|
657
551
|
def show(self):
|
658
552
|
"""
|
@@ -15,7 +15,6 @@
|
|
15
15
|
|
16
16
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
|
17
17
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
|
18
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options
|
19
18
|
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSet, \
|
20
19
|
OperatorSetConcat, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
|
21
20
|
|
@@ -0,0 +1,56 @@
|
|
1
|
+
from typing import Dict, Tuple, List, Any, Optional
|
2
|
+
|
3
|
+
from model_compression_toolkit import DefaultDict
|
4
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
|
5
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, \
|
6
|
+
OperationsSetToLayers
|
7
|
+
|
8
|
+
|
9
|
+
class AttachTpModelToFw:
|
10
|
+
|
11
|
+
def __init__(self):
|
12
|
+
self._opset2layer = None
|
13
|
+
|
14
|
+
# A mapping that associates each layer type in the operation set (with weight attributes and a quantization
|
15
|
+
# configuration in the target platform model) to its framework-specific attribute name. If not all layer types
|
16
|
+
# in the operation set are provided in the mapping, a DefaultDict should be supplied to handle missing entries.
|
17
|
+
self._opset2attr_mapping = None # Mapping of operation sets to their corresponding framework-specific layers
|
18
|
+
|
19
|
+
def attach(self, tpc_model: TargetPlatformModel,
|
20
|
+
custom_opset2layer: Dict[str, Tuple[List[Any], Optional[Dict[str, DefaultDict]]]] = None
|
21
|
+
) -> TargetPlatformCapabilities:
|
22
|
+
"""
|
23
|
+
Attaching a TargetPlatformModel which includes a platform capabilities description to specific
|
24
|
+
framework's operators.
|
25
|
+
|
26
|
+
Args:
|
27
|
+
tpc_model: a TargetPlatformModel object.
|
28
|
+
custom_opset2layer: optional set of custom operator sets which allows to add/override the built-in set
|
29
|
+
of framework operator, to define a specific behavior for those operators. This dictionary should map
|
30
|
+
an operator set unique name to a pair of: a list of framework operators and an optional
|
31
|
+
operator's attributes names mapping.
|
32
|
+
|
33
|
+
Returns: a TargetPlatformCapabilities object.
|
34
|
+
|
35
|
+
"""
|
36
|
+
|
37
|
+
tpc = TargetPlatformCapabilities(tpc_model)
|
38
|
+
|
39
|
+
with tpc:
|
40
|
+
for opset_name, operators in self._opset2layer.items():
|
41
|
+
attr_mapping = self._opset2attr_mapping.get(opset_name)
|
42
|
+
OperationsSetToLayers(opset_name, operators, attr_mapping=attr_mapping)
|
43
|
+
|
44
|
+
if custom_opset2layer is not None:
|
45
|
+
for opset_name, operators in custom_opset2layer.items():
|
46
|
+
if len(operators) == 1:
|
47
|
+
OperationsSetToLayers(opset_name, operators[0])
|
48
|
+
elif len(operators) == 2:
|
49
|
+
OperationsSetToLayers(opset_name, operators[0], attr_mapping=operators[1])
|
50
|
+
else:
|
51
|
+
raise ValueError(f"Custom operator set to layer mapping should include up to 2 elements - "
|
52
|
+
f"a list of layers to attach to the operator and an optional mapping of "
|
53
|
+
f"attributes names, but given a mapping contains {len(operators)} elements.")
|
54
|
+
|
55
|
+
return tpc
|
56
|
+
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import tensorflow as tf
|
17
|
+
from packaging import version
|
18
|
+
|
19
|
+
from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
|
20
|
+
|
21
|
+
if FOUND_SONY_CUSTOM_LAYERS:
|
22
|
+
from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
|
23
|
+
|
24
|
+
if version.parse(tf.__version__) >= version.parse("2.13"):
|
25
|
+
from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
|
26
|
+
MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
|
27
|
+
Conv2DTranspose, Identity, Concatenate, BatchNormalization, Minimum, Maximum
|
28
|
+
else:
|
29
|
+
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
|
30
|
+
MaxPooling2D, Activation, ReLU, Add, Subtract, Multiply, PReLU, Flatten, Cropping2D, LeakyReLU, Permute, \
|
31
|
+
Conv2DTranspose, Concatenate, BatchNormalization, Minimum, Maximum
|
32
|
+
|
33
|
+
from model_compression_toolkit import DefaultDict
|
34
|
+
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS, \
|
35
|
+
BIAS_ATTR, KERAS_KERNEL, KERAS_DEPTHWISE_KERNEL
|
36
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorSetNames
|
37
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform import LayerFilterParams
|
38
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attach2fw import \
|
39
|
+
AttachTpModelToFw
|
40
|
+
|
41
|
+
|
42
|
+
class AttachTpModelToKeras(AttachTpModelToFw):
    """Attach a framework-independent target-platform model to Keras/TensorFlow.

    Populates the two mappings that the base class (AttachTpModelToFw) consumes:

    * ``_opset2layer`` - operator-set name (string) -> list of Keras layer classes,
      TF ops, or ``LayerFilterParams`` filters implementing that operator set.
    * ``_opset2attr_mapping`` - operator-set name (string) -> mapping of canonical
      attribute names (kernel/bias) to the framework-specific attribute names.
    """

    def __init__(self):
        super().__init__()

        # Keys are the enum members' ``.value`` strings so lookups by operator-set
        # name (as stored in the TP model) hit the dictionary directly.
        self._opset2layer = {
            OperatorSetNames.OPSET_CONV.value: [Conv2D, tf.nn.conv2d],
            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: [DepthwiseConv2D, tf.nn.depthwise_conv2d],
            OperatorSetNames.OPSET_CONV_TRANSPOSE.value: [Conv2DTranspose, tf.nn.conv2d_transpose],
            OperatorSetNames.OPSET_FULLY_CONNECTED.value: [Dense],
            OperatorSetNames.OPSET_CONCATENATE.value: [tf.concat, Concatenate],
            OperatorSetNames.OPSET_STACK.value: [tf.stack],
            OperatorSetNames.OPSET_UNSTACK.value: [tf.unstack],
            OperatorSetNames.OPSET_GATHER.value: [tf.gather, tf.compat.v1.gather],
            # No Keras/TF equivalent is registered for expand; kept empty on purpose.
            OperatorSetNames.OPSET_EXPAND.value: [],
            OperatorSetNames.OPSET_BATCH_NORM.value: [BatchNormalization],
            OperatorSetNames.OPSET_RELU.value: [tf.nn.relu, ReLU],
            OperatorSetNames.OPSET_RELU6.value: [tf.nn.relu6],
            OperatorSetNames.OPSET_LEAKY_RELU.value: [tf.nn.leaky_relu, LeakyReLU],
            OperatorSetNames.OPSET_HARD_TANH.value: [LayerFilterParams(Activation, activation="hard_tanh")],
            OperatorSetNames.OPSET_ADD.value: [tf.add, Add],
            OperatorSetNames.OPSET_SUB.value: [tf.subtract, Subtract],
            OperatorSetNames.OPSET_MUL.value: [tf.math.multiply, Multiply],
            OperatorSetNames.OPSET_DIV.value: [tf.math.divide, tf.math.truediv],
            OperatorSetNames.OPSET_MIN.value: [tf.math.minimum, Minimum],
            OperatorSetNames.OPSET_MAX.value: [tf.math.maximum, Maximum],
            OperatorSetNames.OPSET_PRELU.value: [PReLU],
            OperatorSetNames.OPSET_SWISH.value: [tf.nn.swish, LayerFilterParams(Activation, activation="swish")],
            OperatorSetNames.OPSET_SIGMOID.value: [tf.nn.sigmoid, LayerFilterParams(Activation, activation="sigmoid")],
            OperatorSetNames.OPSET_TANH.value: [tf.nn.tanh, LayerFilterParams(Activation, activation="tanh")],
            OperatorSetNames.OPSET_GELU.value: [tf.nn.gelu, LayerFilterParams(Activation, activation="gelu")],
            OperatorSetNames.OPSET_HARDSIGMOID.value: [tf.keras.activations.hard_sigmoid,
                                                      LayerFilterParams(Activation, activation="hard_sigmoid")],
            OperatorSetNames.OPSET_FLATTEN.value: [Flatten],
            OperatorSetNames.OPSET_GET_ITEM.value: [tf.__operators__.getitem],
            OperatorSetNames.OPSET_RESHAPE.value: [Reshape, tf.reshape],
            OperatorSetNames.OPSET_PERMUTE.value: [Permute],
            OperatorSetNames.OPSET_TRANSPOSE.value: [tf.transpose],
            OperatorSetNames.OPSET_DROPOUT.value: [Dropout],
            OperatorSetNames.OPSET_SPLIT.value: [tf.split],
            OperatorSetNames.OPSET_MAXPOOL.value: [MaxPooling2D],
            OperatorSetNames.OPSET_SHAPE.value: [tf.shape, tf.compat.v1.shape],
            OperatorSetNames.OPSET_EQUAL.value: [tf.math.equal],
            OperatorSetNames.OPSET_ARGMAX.value: [tf.math.argmax],
            OperatorSetNames.OPSET_TOPK.value: [tf.nn.top_k],
            OperatorSetNames.OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS.value: [tf.quantization.fake_quant_with_min_max_vars],
            OperatorSetNames.OPSET_COMBINED_NON_MAX_SUPPRESSION.value: [tf.image.combined_non_max_suppression],
            OperatorSetNames.OPSET_CROPPING2D.value: [Cropping2D],
            OperatorSetNames.OPSET_ZERO_PADDING2d.value: [ZeroPadding2D],
            OperatorSetNames.OPSET_CAST.value: [tf.cast],
            OperatorSetNames.OPSET_STRIDED_SLICE.value: [tf.strided_slice]
        }

        if FOUND_SONY_CUSTOM_LAYERS:
            # FIX: key must be the enum's .value string like every other entry above;
            # the original used the raw enum member, so a string-keyed lookup for the
            # post-process operator set would miss this entry.
            self._opset2layer[OperatorSetNames.OPSET_POST_PROCESS.value] = [SSDPostProcess]

        # Per-opset mapping from canonical attribute names to the Keras attribute
        # names used for weight quantization. Depthwise layers store their kernel
        # under a different attribute name, hence the explicit per-layer entries
        # with KERAS_KERNEL as the fallback default.
        self._opset2attr_mapping = {
            OperatorSetNames.OPSET_CONV.value: {
                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)},
            OperatorSetNames.OPSET_DEPTHWISE_CONV.value: {
                KERNEL_ATTR: DefaultDict({
                    DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL,
                    tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)},
            OperatorSetNames.OPSET_FULLY_CONNECTED.value: {
                KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                BIAS_ATTR: DefaultDict(default_value=BIAS)}}
|