mct-nightly 2.2.0.20241201.617__py3-none-any.whl → 2.2.0.20241202.131715__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/RECORD +58 -58
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/constants.py +0 -3
- model_compression_toolkit/core/common/graph/base_node.py +7 -5
- model_compression_toolkit/core/common/graph/functional_node.py +1 -1
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -2
- model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -2
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +2 -1
- model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +1 -1
- model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -2
- model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +2 -2
- model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -1
- model_compression_toolkit/metadata.py +14 -5
- model_compression_toolkit/target_platform_capabilities/schema/__init__.py +14 -0
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +11 -0
- model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +37 -0
- model_compression_toolkit/target_platform_capabilities/{target_platform/op_quantization_config.py → schema/v1.py} +377 -24
- model_compression_toolkit/target_platform_capabilities/target_platform/__init__.py +3 -5
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py +2 -214
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +1 -2
- model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +6 -10
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +39 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tpc_pytorch.py +3 -5
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +36 -31
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +37 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +39 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +36 -31
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +45 -38
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +37 -32
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +70 -62
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py +3 -2
- model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tp_model.py +22 -17
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_keras.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/qnnpack_tpc/v1/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +56 -51
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_keras.py +3 -4
- model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tpc_pytorch.py +3 -4
- model_compression_toolkit/target_platform_capabilities/target_platform/fusing.py +0 -85
- model_compression_toolkit/target_platform_capabilities/target_platform/operators.py +0 -87
- model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model_component.py +0 -40
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241201.617.dist-info → mct_nightly-2.2.0.20241202.131715.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -12,14 +12,23 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
15
|
import copy
|
17
|
-
|
16
|
+
|
18
17
|
from enum import Enum
|
19
18
|
|
19
|
+
import pprint
|
20
|
+
|
21
|
+
from typing import Dict, Any, Union, Tuple, List, Optional
|
22
|
+
|
20
23
|
from mct_quantizers import QuantizationMethod
|
21
24
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
25
|
+
|
22
26
|
from model_compression_toolkit.logger import Logger
|
27
|
+
from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST
|
28
|
+
from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
|
29
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \
|
30
|
+
get_current_tp_model, _current_tp_model
|
31
|
+
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import clone_and_edit_object_params
|
23
32
|
|
24
33
|
|
25
34
|
class Signedness(Enum):
|
@@ -35,27 +44,6 @@ class Signedness(Enum):
|
|
35
44
|
UNSIGNED = 2
|
36
45
|
|
37
46
|
|
38
|
-
def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
|
39
|
-
"""
|
40
|
-
Clones the given object and edit some of its parameters.
|
41
|
-
|
42
|
-
Args:
|
43
|
-
obj: An object to clone.
|
44
|
-
**kwargs: Keyword arguments to edit in the cloned object.
|
45
|
-
|
46
|
-
Returns:
|
47
|
-
Edited copy of the given object.
|
48
|
-
"""
|
49
|
-
|
50
|
-
obj_copy = copy.deepcopy(obj)
|
51
|
-
for k, v in kwargs.items():
|
52
|
-
assert hasattr(obj_copy,
|
53
|
-
k), f'Edit parameter is possible only for existing parameters in the given object, ' \
|
54
|
-
f'but {k} is not a parameter of {obj_copy}.'
|
55
|
-
setattr(obj_copy, k, v)
|
56
|
-
return obj_copy
|
57
|
-
|
58
|
-
|
59
47
|
class AttributeQuantizationConfig:
|
60
48
|
"""
|
61
49
|
Hold the quantization configuration of a weight attribute of a layer.
|
@@ -376,3 +364,368 @@ class QuantizationConfigOptions:
|
|
376
364
|
def get_info(self):
|
377
365
|
return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)}
|
378
366
|
|
367
|
+
|
368
|
+
class TargetPlatformModelComponent:
|
369
|
+
"""
|
370
|
+
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.)
|
371
|
+
"""
|
372
|
+
def __init__(self, name: str):
|
373
|
+
"""
|
374
|
+
|
375
|
+
Args:
|
376
|
+
name: Name of component.
|
377
|
+
"""
|
378
|
+
self.name = name
|
379
|
+
_current_tp_model.get().append_component(self)
|
380
|
+
|
381
|
+
def get_info(self) -> Dict[str, Any]:
|
382
|
+
"""
|
383
|
+
|
384
|
+
Returns: Get information about the component to display (return an empty dictionary.
|
385
|
+
the actual component should fill it with info).
|
386
|
+
|
387
|
+
"""
|
388
|
+
return {}
|
389
|
+
|
390
|
+
|
391
|
+
class OperatorsSetBase(TargetPlatformModelComponent):
|
392
|
+
"""
|
393
|
+
Base class to represent a set of operators.
|
394
|
+
"""
|
395
|
+
def __init__(self, name: str):
|
396
|
+
"""
|
397
|
+
|
398
|
+
Args:
|
399
|
+
name: Name of OperatorsSet.
|
400
|
+
"""
|
401
|
+
super().__init__(name=name)
|
402
|
+
|
403
|
+
|
404
|
+
class OperatorsSet(OperatorsSetBase):
|
405
|
+
def __init__(self,
|
406
|
+
name: str,
|
407
|
+
qc_options: QuantizationConfigOptions = None):
|
408
|
+
"""
|
409
|
+
Set of operators that are represented by a unique label.
|
410
|
+
|
411
|
+
Args:
|
412
|
+
name (str): Set's label (must be unique in a TargetPlatformModel).
|
413
|
+
qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations.
|
414
|
+
"""
|
415
|
+
|
416
|
+
super().__init__(name)
|
417
|
+
self.qc_options = qc_options
|
418
|
+
is_fusing_set = qc_options is None
|
419
|
+
self.is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set
|
420
|
+
|
421
|
+
|
422
|
+
def get_info(self) -> Dict[str,Any]:
|
423
|
+
"""
|
424
|
+
|
425
|
+
Returns: Info about the set as a dictionary.
|
426
|
+
|
427
|
+
"""
|
428
|
+
return {"name": self.name,
|
429
|
+
"is_default_qc": self.is_default}
|
430
|
+
|
431
|
+
|
432
|
+
class OperatorSetConcat(OperatorsSetBase):
|
433
|
+
"""
|
434
|
+
Concatenate a list of operator sets to treat them similarly in different places (like fusing).
|
435
|
+
"""
|
436
|
+
def __init__(self, *opsets: OperatorsSet):
|
437
|
+
"""
|
438
|
+
Group a list of operation sets.
|
439
|
+
|
440
|
+
Args:
|
441
|
+
*opsets (OperatorsSet): List of operator sets to group.
|
442
|
+
"""
|
443
|
+
name = "_".join([a.name for a in opsets])
|
444
|
+
super().__init__(name=name)
|
445
|
+
self.op_set_list = opsets
|
446
|
+
self.qc_options = None # Concat have no qc options
|
447
|
+
|
448
|
+
def get_info(self) -> Dict[str,Any]:
|
449
|
+
"""
|
450
|
+
|
451
|
+
Returns: Info about the sets group as a dictionary.
|
452
|
+
|
453
|
+
"""
|
454
|
+
return {"name": self.name,
|
455
|
+
OPS_SET_LIST: [s.name for s in self.op_set_list]}
|
456
|
+
|
457
|
+
|
458
|
+
class Fusing(TargetPlatformModelComponent):
|
459
|
+
"""
|
460
|
+
Fusing defines a list of operators that should be combined and treated as a single operator,
|
461
|
+
hence no quantization is applied between them.
|
462
|
+
"""
|
463
|
+
|
464
|
+
def __init__(self,
|
465
|
+
operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]],
|
466
|
+
name: str = None):
|
467
|
+
"""
|
468
|
+
Args:
|
469
|
+
operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet.
|
470
|
+
name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names.
|
471
|
+
"""
|
472
|
+
assert isinstance(operator_groups_list,
|
473
|
+
list), f'List of operator groups should be of type list but is {type(operator_groups_list)}'
|
474
|
+
assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group'
|
475
|
+
|
476
|
+
# Generate a name from the operator groups if no name is provided
|
477
|
+
if name is None:
|
478
|
+
name = '_'.join([x.name for x in operator_groups_list])
|
479
|
+
|
480
|
+
super().__init__(name)
|
481
|
+
self.operator_groups_list = operator_groups_list
|
482
|
+
|
483
|
+
def contains(self, other: Any) -> bool:
|
484
|
+
"""
|
485
|
+
Determines if the current Fusing instance contains another Fusing instance.
|
486
|
+
|
487
|
+
Args:
|
488
|
+
other: The other Fusing instance to check against.
|
489
|
+
|
490
|
+
Returns:
|
491
|
+
A boolean indicating whether the other instance is contained within this one.
|
492
|
+
"""
|
493
|
+
if not isinstance(other, Fusing):
|
494
|
+
return False
|
495
|
+
|
496
|
+
# Check for containment by comparing operator groups
|
497
|
+
for i in range(len(self.operator_groups_list) - len(other.operator_groups_list) + 1):
|
498
|
+
for j in range(len(other.operator_groups_list)):
|
499
|
+
if self.operator_groups_list[i + j] != other.operator_groups_list[j] and not (
|
500
|
+
isinstance(self.operator_groups_list[i + j], OperatorSetConcat) and (
|
501
|
+
other.operator_groups_list[j] in self.operator_groups_list[i + j].op_set_list)):
|
502
|
+
break
|
503
|
+
else:
|
504
|
+
# If all checks pass, the other Fusing instance is contained
|
505
|
+
return True
|
506
|
+
# Other Fusing instance is not contained
|
507
|
+
return False
|
508
|
+
|
509
|
+
def get_info(self):
|
510
|
+
"""
|
511
|
+
Retrieves information about the Fusing instance, including its name and the sequence of operator groups.
|
512
|
+
|
513
|
+
Returns:
|
514
|
+
A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value,
|
515
|
+
or just the sequence of operator groups if no name is set.
|
516
|
+
"""
|
517
|
+
if self.name is not None:
|
518
|
+
return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])}
|
519
|
+
return ' -> '.join([x.name for x in self.operator_groups_list])
|
520
|
+
|
521
|
+
|
522
|
+
class TargetPlatformModel(ImmutableClass):
|
523
|
+
"""
|
524
|
+
Represents the hardware configuration used for quantized model inference.
|
525
|
+
|
526
|
+
This model defines:
|
527
|
+
- The operators and their associated quantization configurations.
|
528
|
+
- Fusing patterns, enabling multiple operators to be combined into a single operator
|
529
|
+
for optimization during inference.
|
530
|
+
- Versioning support through minor and patch versions for backward compatibility.
|
531
|
+
|
532
|
+
Attributes:
|
533
|
+
SCHEMA_VERSION (int): The schema version of the target platform model.
|
534
|
+
"""
|
535
|
+
SCHEMA_VERSION = 1
|
536
|
+
def __init__(self,
|
537
|
+
default_qco: QuantizationConfigOptions,
|
538
|
+
tpc_minor_version: Optional[int],
|
539
|
+
tpc_patch_version: Optional[int],
|
540
|
+
tpc_platform_type: Optional[str],
|
541
|
+
add_metadata: bool = True,
|
542
|
+
name="default_tp_model"):
|
543
|
+
"""
|
544
|
+
|
545
|
+
Args:
|
546
|
+
default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators that their QuantizationConfigOptions are not defined in the model.
|
547
|
+
tpc_minor_version (Optional[int]): The minor version of the target platform capabilities.
|
548
|
+
tpc_patch_version (Optional[int]): The patch version of the target platform capabilities.
|
549
|
+
tpc_platform_type (Optional[str]): The platform type of the target platform capabilities.
|
550
|
+
add_metadata (bool): Whether to add metadata to the model or not.
|
551
|
+
name (str): Name of the model.
|
552
|
+
|
553
|
+
Raises:
|
554
|
+
AssertionError: If the provided `default_qco` does not contain exactly one quantization configuration.
|
555
|
+
"""
|
556
|
+
|
557
|
+
super().__init__()
|
558
|
+
self.tpc_minor_version = tpc_minor_version
|
559
|
+
self.tpc_patch_version = tpc_patch_version
|
560
|
+
self.tpc_platform_type = tpc_platform_type
|
561
|
+
self.add_metadata = add_metadata
|
562
|
+
self.name = name
|
563
|
+
self.operator_set = []
|
564
|
+
assert isinstance(default_qco, QuantizationConfigOptions), \
|
565
|
+
"default_qco must be an instance of QuantizationConfigOptions"
|
566
|
+
assert len(default_qco.quantization_config_list) == 1, \
|
567
|
+
"Default QuantizationConfigOptions must contain exactly one option."
|
568
|
+
|
569
|
+
self.default_qco = default_qco
|
570
|
+
self.fusing_patterns = []
|
571
|
+
self.is_simd_padding = False
|
572
|
+
|
573
|
+
def get_config_options_by_operators_set(self,
|
574
|
+
operators_set_name: str) -> QuantizationConfigOptions:
|
575
|
+
"""
|
576
|
+
Get the QuantizationConfigOptions of a OperatorsSet by the OperatorsSet name.
|
577
|
+
If the name is not in the model, the default QuantizationConfigOptions is returned.
|
578
|
+
|
579
|
+
Args:
|
580
|
+
operators_set_name: Name of OperatorsSet to get.
|
581
|
+
|
582
|
+
Returns:
|
583
|
+
QuantizationConfigOptions to use for ops in OperatorsSet named operators_set_name.
|
584
|
+
"""
|
585
|
+
for op_set in self.operator_set:
|
586
|
+
if operators_set_name == op_set.name:
|
587
|
+
return op_set.qc_options
|
588
|
+
return self.default_qco
|
589
|
+
|
590
|
+
def get_default_op_quantization_config(self) -> OpQuantizationConfig:
|
591
|
+
"""
|
592
|
+
|
593
|
+
Returns: The default OpQuantizationConfig of the TargetPlatformModel.
|
594
|
+
|
595
|
+
"""
|
596
|
+
assert len(self.default_qco.quantization_config_list) == 1, \
|
597
|
+
f'Default quantization configuration options must contain only one option,' \
|
598
|
+
f' but found {len(get_current_tp_model().default_qco.quantization_config_list)} configurations.'
|
599
|
+
return self.default_qco.quantization_config_list[0]
|
600
|
+
|
601
|
+
def is_opset_in_model(self,
|
602
|
+
opset_name: str) -> bool:
|
603
|
+
"""
|
604
|
+
Check whether an operators set is defined in the model or not.
|
605
|
+
|
606
|
+
Args:
|
607
|
+
opset_name: Operators set name to check.
|
608
|
+
|
609
|
+
Returns:
|
610
|
+
Whether an operators set is defined in the model or not.
|
611
|
+
"""
|
612
|
+
return opset_name in [x.name for x in self.operator_set]
|
613
|
+
|
614
|
+
def get_opset_by_name(self,
|
615
|
+
opset_name: str) -> OperatorsSetBase:
|
616
|
+
"""
|
617
|
+
Get an OperatorsSet object from the model by its name.
|
618
|
+
If name is not in the model - None is returned.
|
619
|
+
|
620
|
+
Args:
|
621
|
+
opset_name: OperatorsSet name to retrieve.
|
622
|
+
|
623
|
+
Returns:
|
624
|
+
OperatorsSet object with the name opset_name, or None if opset_name is not in the model.
|
625
|
+
"""
|
626
|
+
|
627
|
+
opset_list = [x for x in self.operator_set if x.name == opset_name]
|
628
|
+
assert len(opset_list) <= 1, f'Found more than one OperatorsSet in' \
|
629
|
+
f' TargetPlatformModel with the name {opset_name}. ' \
|
630
|
+
f'OperatorsSet name must be unique.'
|
631
|
+
if len(opset_list) == 0: # opset_name is not in the model.
|
632
|
+
return None
|
633
|
+
|
634
|
+
return opset_list[0] # There's one opset with that name
|
635
|
+
|
636
|
+
def append_component(self,
|
637
|
+
tp_model_component: TargetPlatformModelComponent):
|
638
|
+
"""
|
639
|
+
Attach a TargetPlatformModel component to the model. Components can be for example:
|
640
|
+
Fusing, OperatorsSet, etc.
|
641
|
+
|
642
|
+
Args:
|
643
|
+
tp_model_component: Component to attach to the model.
|
644
|
+
|
645
|
+
"""
|
646
|
+
if isinstance(tp_model_component, Fusing):
|
647
|
+
self.fusing_patterns.append(tp_model_component)
|
648
|
+
elif isinstance(tp_model_component, OperatorsSetBase):
|
649
|
+
self.operator_set.append(tp_model_component)
|
650
|
+
else: # pragma: no cover
|
651
|
+
Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.')
|
652
|
+
|
653
|
+
def __enter__(self):
|
654
|
+
"""
|
655
|
+
Start defining the TargetPlatformModel using 'with'.
|
656
|
+
|
657
|
+
Returns: Initialized TargetPlatformModel object.
|
658
|
+
|
659
|
+
"""
|
660
|
+
_current_tp_model.set(self)
|
661
|
+
return self
|
662
|
+
|
663
|
+
def __exit__(self, exc_type, exc_value, tb):
|
664
|
+
"""
|
665
|
+
Finish defining the TargetPlatformModel at the end of the 'with' clause.
|
666
|
+
Returns the final and immutable TargetPlatformModel instance.
|
667
|
+
"""
|
668
|
+
|
669
|
+
if exc_value is not None:
|
670
|
+
print(exc_value, exc_value.args)
|
671
|
+
raise exc_value
|
672
|
+
self.__validate_model() # Assert that model is valid.
|
673
|
+
_current_tp_model.reset()
|
674
|
+
self.initialized_done() # Make model immutable.
|
675
|
+
return self
|
676
|
+
|
677
|
+
def __validate_model(self):
|
678
|
+
"""
|
679
|
+
|
680
|
+
Assert model is valid.
|
681
|
+
Model is invalid if, for example, it contains multiple operator sets with the same name,
|
682
|
+
as their names should be unique.
|
683
|
+
|
684
|
+
"""
|
685
|
+
opsets_names = [op.name for op in self.operator_set]
|
686
|
+
if len(set(opsets_names)) != len(opsets_names):
|
687
|
+
Logger.critical(f'Operator Sets must have unique names.')
|
688
|
+
|
689
|
+
def get_default_config(self) -> OpQuantizationConfig:
|
690
|
+
"""
|
691
|
+
|
692
|
+
Returns:
|
693
|
+
|
694
|
+
"""
|
695
|
+
assert len(self.default_qco.quantization_config_list) == 1, \
|
696
|
+
f'Default quantization configuration options must contain only one option,' \
|
697
|
+
f' but found {len(self.default_qco.quantization_config_list)} configurations.'
|
698
|
+
return self.default_qco.quantization_config_list[0]
|
699
|
+
|
700
|
+
def get_info(self) -> Dict[str, Any]:
|
701
|
+
"""
|
702
|
+
|
703
|
+
Returns: Dictionary that summarizes the TargetPlatformModel properties (for display purposes).
|
704
|
+
|
705
|
+
"""
|
706
|
+
return {"Model name": self.name,
|
707
|
+
"Default quantization config": self.get_default_config().get_info(),
|
708
|
+
"Operators sets": [o.get_info() for o in self.operator_set],
|
709
|
+
"Fusing patterns": [f.get_info() for f in self.fusing_patterns]
|
710
|
+
}
|
711
|
+
|
712
|
+
def show(self):
|
713
|
+
"""
|
714
|
+
|
715
|
+
Display the TargetPlatformModel.
|
716
|
+
|
717
|
+
"""
|
718
|
+
pprint.pprint(self.get_info(), sort_dicts=False)
|
719
|
+
|
720
|
+
def set_simd_padding(self,
|
721
|
+
is_simd_padding: bool):
|
722
|
+
"""
|
723
|
+
Set flag is_simd_padding to indicate whether this TP model defines
|
724
|
+
that padding due to SIMD constrains occurs.
|
725
|
+
|
726
|
+
Args:
|
727
|
+
is_simd_padding: Whether this TP model defines that padding due to SIMD constrains occurs.
|
728
|
+
|
729
|
+
"""
|
730
|
+
self.is_simd_padding = is_simd_padding
|
731
|
+
|
@@ -13,13 +13,11 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
|
17
16
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
|
18
17
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options
|
20
|
-
from model_compression_toolkit.target_platform_capabilities.
|
21
|
-
|
22
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSet, OperatorSetConcat
|
18
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options
|
19
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSet, \
|
20
|
+
OperatorSetConcat, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing
|
23
21
|
|
24
22
|
from mct_quantizers import QuantizationMethod
|
25
23
|
|
model_compression_toolkit/target_platform_capabilities/target_platform/target_platform_model.py
CHANGED
@@ -13,19 +13,8 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
import
|
17
|
-
from
|
18
|
-
|
19
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import _current_tp_model, \
|
20
|
-
get_current_tp_model
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.fusing import Fusing
|
22
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import \
|
23
|
-
TargetPlatformModelComponent
|
24
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.op_quantization_config import OpQuantizationConfig, \
|
25
|
-
QuantizationConfigOptions
|
26
|
-
from model_compression_toolkit.target_platform_capabilities.target_platform.operators import OperatorsSetBase
|
27
|
-
from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
|
28
|
-
from model_compression_toolkit.logger import Logger
|
16
|
+
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import get_current_tp_model
|
17
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions
|
29
18
|
|
30
19
|
|
31
20
|
def get_default_quantization_config_options() -> QuantizationConfigOptions:
|
@@ -39,204 +28,3 @@ def get_default_quantization_config_options() -> QuantizationConfigOptions:
|
|
39
28
|
return get_current_tp_model().default_qco
|
40
29
|
|
41
30
|
|
42
|
-
def get_default_quantization_config():
|
43
|
-
"""
|
44
|
-
|
45
|
-
Returns: The default OpQuantizationConfig of the model. This is the OpQuantizationConfig
|
46
|
-
to use when a layer's options is queried and it wasn't specified in the TargetPlatformCapabilities.
|
47
|
-
This OpQuantizationConfig is the single option in the default QuantizationConfigOptions.
|
48
|
-
|
49
|
-
"""
|
50
|
-
|
51
|
-
return get_current_tp_model().get_default_op_quantization_config()
|
52
|
-
|
53
|
-
|
54
|
-
class TargetPlatformModel(ImmutableClass):
|
55
|
-
"""
|
56
|
-
Modeling of the hardware the quantized model will use during inference.
|
57
|
-
The model contains definition of operators, quantization configurations of them, and
|
58
|
-
fusing patterns so that multiple operators will be combined into a single operator.
|
59
|
-
"""
|
60
|
-
|
61
|
-
def __init__(self,
|
62
|
-
default_qco: QuantizationConfigOptions,
|
63
|
-
add_metadata: bool = False,
|
64
|
-
name="default_tp_model"):
|
65
|
-
"""
|
66
|
-
|
67
|
-
Args:
|
68
|
-
default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators that their QuantizationConfigOptions are not defined in the model.
|
69
|
-
add_metadata (bool): Whether to add metadata to the model or not.
|
70
|
-
name (str): Name of the model.
|
71
|
-
"""
|
72
|
-
|
73
|
-
super().__init__()
|
74
|
-
self.add_metadata = add_metadata
|
75
|
-
self.name = name
|
76
|
-
self.operator_set = []
|
77
|
-
assert isinstance(default_qco, QuantizationConfigOptions)
|
78
|
-
assert len(default_qco.quantization_config_list) == 1, \
|
79
|
-
f'Default QuantizationConfigOptions must contain only one option'
|
80
|
-
self.default_qco = default_qco
|
81
|
-
self.fusing_patterns = []
|
82
|
-
self.is_simd_padding = False
|
83
|
-
|
84
|
-
def get_config_options_by_operators_set(self,
|
85
|
-
operators_set_name: str) -> QuantizationConfigOptions:
|
86
|
-
"""
|
87
|
-
Get the QuantizationConfigOptions of a OperatorsSet by the OperatorsSet name.
|
88
|
-
If the name is not in the model, the default QuantizationConfigOptions is returned.
|
89
|
-
|
90
|
-
Args:
|
91
|
-
operators_set_name: Name of OperatorsSet to get.
|
92
|
-
|
93
|
-
Returns:
|
94
|
-
QuantizationConfigOptions to use for ops in OperatorsSet named operators_set_name.
|
95
|
-
"""
|
96
|
-
for op_set in self.operator_set:
|
97
|
-
if operators_set_name == op_set.name:
|
98
|
-
return op_set.qc_options
|
99
|
-
return self.default_qco
|
100
|
-
|
101
|
-
def get_default_op_quantization_config(self) -> OpQuantizationConfig:
|
102
|
-
"""
|
103
|
-
|
104
|
-
Returns: The default OpQuantizationConfig of the TargetPlatformModel.
|
105
|
-
|
106
|
-
"""
|
107
|
-
assert len(self.default_qco.quantization_config_list) == 1, \
|
108
|
-
f'Default quantization configuration options must contain only one option,' \
|
109
|
-
f' but found {len(get_current_tp_model().default_qco.quantization_config_list)} configurations.'
|
110
|
-
return self.default_qco.quantization_config_list[0]
|
111
|
-
|
112
|
-
def is_opset_in_model(self,
|
113
|
-
opset_name: str) -> bool:
|
114
|
-
"""
|
115
|
-
Check whether an operators set is defined in the model or not.
|
116
|
-
|
117
|
-
Args:
|
118
|
-
opset_name: Operators set name to check.
|
119
|
-
|
120
|
-
Returns:
|
121
|
-
Whether an operators set is defined in the model or not.
|
122
|
-
"""
|
123
|
-
return opset_name in [x.name for x in self.operator_set]
|
124
|
-
|
125
|
-
def get_opset_by_name(self,
|
126
|
-
opset_name: str) -> OperatorsSetBase:
|
127
|
-
"""
|
128
|
-
Get an OperatorsSet object from the model by its name.
|
129
|
-
If name is not in the model - None is returned.
|
130
|
-
|
131
|
-
Args:
|
132
|
-
opset_name: OperatorsSet name to retrieve.
|
133
|
-
|
134
|
-
Returns:
|
135
|
-
OperatorsSet object with the name opset_name, or None if opset_name is not in the model.
|
136
|
-
"""
|
137
|
-
|
138
|
-
opset_list = [x for x in self.operator_set if x.name == opset_name]
|
139
|
-
assert len(opset_list) <= 1, f'Found more than one OperatorsSet in' \
|
140
|
-
f' TargetPlatformModel with the name {opset_name}. ' \
|
141
|
-
f'OperatorsSet name must be unique.'
|
142
|
-
if len(opset_list) == 0: # opset_name is not in the model.
|
143
|
-
return None
|
144
|
-
|
145
|
-
return opset_list[0] # There's one opset with that name
|
146
|
-
|
147
|
-
def append_component(self,
|
148
|
-
tp_model_component: TargetPlatformModelComponent):
|
149
|
-
"""
|
150
|
-
Attach a TargetPlatformModel component to the model. Components can be for example:
|
151
|
-
Fusing, OperatorsSet, etc.
|
152
|
-
|
153
|
-
Args:
|
154
|
-
tp_model_component: Component to attach to the model.
|
155
|
-
|
156
|
-
"""
|
157
|
-
if isinstance(tp_model_component, Fusing):
|
158
|
-
self.fusing_patterns.append(tp_model_component)
|
159
|
-
elif isinstance(tp_model_component, OperatorsSetBase):
|
160
|
-
self.operator_set.append(tp_model_component)
|
161
|
-
else: # pragma: no cover
|
162
|
-
Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.')
|
163
|
-
|
164
|
-
def __enter__(self):
|
165
|
-
"""
|
166
|
-
Start defining the TargetPlatformModel using 'with'.
|
167
|
-
|
168
|
-
Returns: Initialized TargetPlatformModel object.
|
169
|
-
|
170
|
-
"""
|
171
|
-
_current_tp_model.set(self)
|
172
|
-
return self
|
173
|
-
|
174
|
-
def __exit__(self, exc_type, exc_value, tb):
|
175
|
-
"""
|
176
|
-
Finish defining the TargetPlatformModel at the end of the 'with' clause.
|
177
|
-
Returns the final and immutable TargetPlatformModel instance.
|
178
|
-
"""
|
179
|
-
|
180
|
-
if exc_value is not None:
|
181
|
-
print(exc_value, exc_value.args)
|
182
|
-
raise exc_value
|
183
|
-
self.__validate_model() # Assert that model is valid.
|
184
|
-
_current_tp_model.reset()
|
185
|
-
self.initialized_done() # Make model immutable.
|
186
|
-
return self
|
187
|
-
|
188
|
-
def __validate_model(self):
|
189
|
-
"""
|
190
|
-
|
191
|
-
Assert model is valid.
|
192
|
-
Model is invalid if, for example, it contains multiple operator sets with the same name,
|
193
|
-
as their names should be unique.
|
194
|
-
|
195
|
-
"""
|
196
|
-
opsets_names = [op.name for op in self.operator_set]
|
197
|
-
if len(set(opsets_names)) != len(opsets_names):
|
198
|
-
Logger.critical(f'Operator Sets must have unique names.')
|
199
|
-
|
200
|
-
def get_default_config(self) -> OpQuantizationConfig:
|
201
|
-
"""
|
202
|
-
|
203
|
-
Returns:
|
204
|
-
|
205
|
-
"""
|
206
|
-
assert len(self.default_qco.quantization_config_list) == 1, \
|
207
|
-
f'Default quantization configuration options must contain only one option,' \
|
208
|
-
f' but found {len(self.default_qco.quantization_config_list)} configurations.'
|
209
|
-
return self.default_qco.quantization_config_list[0]
|
210
|
-
|
211
|
-
def get_info(self) -> Dict[str, Any]:
|
212
|
-
"""
|
213
|
-
|
214
|
-
Returns: Dictionary that summarizes the TargetPlatformModel properties (for display purposes).
|
215
|
-
|
216
|
-
"""
|
217
|
-
return {"Model name": self.name,
|
218
|
-
"Default quantization config": self.get_default_config().get_info(),
|
219
|
-
"Operators sets": [o.get_info() for o in self.operator_set],
|
220
|
-
"Fusing patterns": [f.get_info() for f in self.fusing_patterns]
|
221
|
-
}
|
222
|
-
|
223
|
-
def show(self):
|
224
|
-
"""
|
225
|
-
|
226
|
-
Display the TargetPlatformModel.
|
227
|
-
|
228
|
-
"""
|
229
|
-
pprint.pprint(self.get_info(), sort_dicts=False)
|
230
|
-
|
231
|
-
def set_simd_padding(self,
|
232
|
-
is_simd_padding: bool):
|
233
|
-
"""
|
234
|
-
Set flag is_simd_padding to indicate whether this TP model defines
|
235
|
-
that padding due to SIMD constrains occurs.
|
236
|
-
|
237
|
-
Args:
|
238
|
-
is_simd_padding: Whether this TP model defines that padding due to SIMD constrains occurs.
|
239
|
-
|
240
|
-
"""
|
241
|
-
self.is_simd_padding = is_simd_padding
|
242
|
-
|
@@ -18,8 +18,7 @@ from typing import List, Any, Dict
|
|
18
18
|
from model_compression_toolkit.logger import Logger
|
19
19
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
|
20
20
|
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
|
21
|
-
from model_compression_toolkit.target_platform_capabilities.
|
22
|
-
OperatorsSetBase
|
21
|
+
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetConcat
|
23
22
|
from model_compression_toolkit import DefaultDict
|
24
23
|
|
25
24
|
|