mct-nightly 2.3.0.20250414.604__py3-none-any.whl → 2.3.0.20250416.541__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/METADATA +7 -7
- {mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/RECORD +10 -10
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/gptq/common/gptq_config.py +1 -1
- model_compression_toolkit/target_platform_capabilities/schema/v1.py +47 -67
- model_compression_toolkit/target_platform_capabilities/schema/v2.py +75 -13
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +1 -1
- {mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/WHEEL +0 -0
- {mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/licenses/LICENSE.md +0 -0
- {mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/top_level.txt +0 -0
{mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: mct-nightly
|
3
|
-
Version: 2.3.0.20250414.604
|
3
|
+
Version: 2.3.0.20250416.541
|
4
4
|
Summary: A Model Compression Toolkit for neural networks
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
6
6
|
Classifier: License :: OSI Approved :: Apache Software License
|
@@ -21,7 +21,7 @@ Requires-Dist: matplotlib<3.10.0
|
|
21
21
|
Requires-Dist: scipy
|
22
22
|
Requires-Dist: protobuf
|
23
23
|
Requires-Dist: mct-quantizers-nightly
|
24
|
-
Requires-Dist: pydantic
|
24
|
+
Requires-Dist: pydantic>=2.0
|
25
25
|
Requires-Dist: sony-custom-layers-dev==0.4.0.dev6
|
26
26
|
Dynamic: classifier
|
27
27
|
Dynamic: description
|
@@ -63,7 +63,7 @@ ________________________________________________________________________________
|
|
63
63
|
|
64
64
|
## <div align="center">Getting Started</div>
|
65
65
|
### Quick Installation
|
66
|
-
Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.
|
66
|
+
Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.14.
|
67
67
|
```
|
68
68
|
pip install model-compression-toolkit
|
69
69
|
```
|
@@ -170,11 +170,11 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
|
|
170
170
|
| Python 3.11 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml) |
|
171
171
|
| Python 3.12 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) |
|
172
172
|
|
173
|
-
| | TensorFlow 2.
|
173
|
+
| | TensorFlow 2.14 | TensorFlow 2.15 |
|
174
174
|
|-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
175
|
-
| Python 3.9 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
|
176
|
+
| Python 3.10 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
|
177
|
+
| Python 3.11 | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |
|
178
178
|
|
179
179
|
</details>
|
180
180
|
|
{mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
|
|
1
|
-
mct_nightly-2.3.0.
|
2
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
mct_nightly-2.3.0.20250416.541.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
2
|
+
model_compression_toolkit/__init__.py,sha256=AeZ2o5FMPLxX0sepHjLsV8WP2kgUvZWHt78DlPDh7u8,1557
|
3
3
|
model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
|
4
4
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
5
5
|
model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
|
@@ -354,7 +354,7 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
|
|
354
354
|
model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
|
355
355
|
model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
|
356
356
|
model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
357
|
-
model_compression_toolkit/gptq/common/gptq_config.py,sha256=
|
357
|
+
model_compression_toolkit/gptq/common/gptq_config.py,sha256=xVzjy3CyR07rpGvUy2jsSaijXq7-0KStpU_yVu7VLVA,6144
|
358
358
|
model_compression_toolkit/gptq/common/gptq_constants.py,sha256=8HB0yiX75zZ1IKgQUPWpFCM5sS8HAqslws5XrOhxJQ0,750
|
359
359
|
model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
|
360
360
|
model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
|
@@ -433,13 +433,13 @@ model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py,sha2
|
|
433
433
|
model_compression_toolkit/target_platform_capabilities/__init__.py,sha256=8RVOriZg-XNjSt53h_4Yum0oRgOe2gp5H45dfG_lZxE,1415
|
434
434
|
model_compression_toolkit/target_platform_capabilities/constants.py,sha256=JRz9DoxLRpkqvu532TFkIvv0595Bfb9NtU4pRp4urDY,1540
|
435
435
|
model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
|
436
|
-
model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256
|
436
|
+
model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=nbmlygR-nc3bzwnUDrRamq3a6KFkC4-cCpbUeF7EEmo,4626
|
437
437
|
model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
|
438
438
|
model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=PvO8eHxnb3A55gyExT5fZGnOUl3ce7BbbT5SPxCEXNo,541
|
439
439
|
model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=TtMPbiibV6Hk53nl5Y_ctfpI6mSbd8VVH9fxnv5j9eM,4430
|
440
440
|
model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
|
441
|
-
model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=
|
442
|
-
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=
|
441
|
+
model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=oWKNQnnz04kmijmdWtRyXgVXbJ6BG_V_bUBz_MfUM94,27116
|
442
|
+
model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=FiSkRUSuEPnJxvyDuRTwv2gwY4xveSp1hLtWKEFa8zc,6110
|
443
443
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
|
444
444
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
|
445
445
|
model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=Ehwpd_sL6zxmJFpJugOdN9uNxNX05nijvOCilNfHnFs,7162
|
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
528
528
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
|
529
529
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
530
530
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
|
531
|
-
mct_nightly-2.3.0.
|
532
|
-
mct_nightly-2.3.0.
|
533
|
-
mct_nightly-2.3.0.
|
534
|
-
mct_nightly-2.3.0.
|
531
|
+
mct_nightly-2.3.0.20250416.541.dist-info/METADATA,sha256=r1uKB8w4EULCSj-_wL_b-doM7GuOlu4NeTVo11pYUj0,25413
|
532
|
+
mct_nightly-2.3.0.20250416.541.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
533
|
+
mct_nightly-2.3.0.20250416.541.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
534
|
+
mct_nightly-2.3.0.20250416.541.dist-info/RECORD,,
|
model_compression_toolkit/__init__.py
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.3.0.
|
30
|
+
__version__ = "2.3.0.20250416.000541"
|
model_compression_toolkit/gptq/common/gptq_config.py
@@ -84,7 +84,7 @@ class QFractionLinearAnnealingConfig:
|
|
84
84
|
raise ValueError(f'Expected start_step >= 0. received {self.start_step}.')
|
85
85
|
if self.end_step is not None and self.end_step <= self.start_step:
|
86
86
|
raise ValueError('Expected start_step < end_step, '
|
87
|
-
'received end_step {self.end_step} and start_step {self.
|
87
|
+
f'received end_step {self.end_step} and start_step {self.start_step}.')
|
88
88
|
|
89
89
|
|
90
90
|
@dataclass
|
model_compression_toolkit/target_platform_capabilities/schema/v1.py
@@ -16,7 +16,8 @@ import pprint
|
|
16
16
|
from enum import Enum
|
17
17
|
from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated
|
18
18
|
|
19
|
-
from pydantic import BaseModel, Field, root_validator, validator, PositiveInt
|
19
|
+
from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, ConfigDict, field_validator, \
|
20
|
+
model_validator
|
20
21
|
|
21
22
|
from mct_quantizers import QuantizationMethod
|
22
23
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
@@ -118,9 +119,7 @@ class AttributeQuantizationConfig(BaseModel):
|
|
118
119
|
enable_weights_quantization: bool = False
|
119
120
|
lut_values_bitwidth: Optional[int] = None
|
120
121
|
|
121
|
-
|
122
|
-
# Makes the model immutable (frozen)
|
123
|
-
frozen = True
|
122
|
+
model_config = ConfigDict(frozen=True)
|
124
123
|
|
125
124
|
@property
|
126
125
|
def field_names(self) -> list:
|
@@ -137,7 +136,7 @@ class AttributeQuantizationConfig(BaseModel):
|
|
137
136
|
Returns:
|
138
137
|
AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes.
|
139
138
|
"""
|
140
|
-
return self.
|
139
|
+
return self.model_copy(update=kwargs)
|
141
140
|
|
142
141
|
|
143
142
|
class OpQuantizationConfig(BaseModel):
|
@@ -164,15 +163,14 @@ class OpQuantizationConfig(BaseModel):
|
|
164
163
|
supported_input_activation_n_bits: Union[int, Tuple[int, ...]]
|
165
164
|
enable_activation_quantization: bool
|
166
165
|
quantization_preserving: bool
|
167
|
-
fixed_scale: Optional[float]
|
168
|
-
fixed_zero_point: Optional[int]
|
169
|
-
simd_size: Optional[int]
|
170
166
|
signedness: Signedness
|
167
|
+
fixed_scale: Optional[float] = None
|
168
|
+
fixed_zero_point: Optional[int] = None
|
169
|
+
simd_size: Optional[int] = None
|
171
170
|
|
172
|
-
|
173
|
-
frozen = True
|
171
|
+
model_config = ConfigDict(frozen=True)
|
174
172
|
|
175
|
-
@
|
173
|
+
@field_validator('supported_input_activation_n_bits', mode='before')
|
176
174
|
def validate_supported_input_activation_n_bits(cls, v):
|
177
175
|
"""
|
178
176
|
Validate and process the supported_input_activation_n_bits field.
|
@@ -199,9 +197,9 @@ class OpQuantizationConfig(BaseModel):
|
|
199
197
|
return self.dict() # pragma: no cover
|
200
198
|
|
201
199
|
def clone_and_edit(
|
202
|
-
|
203
|
-
|
204
|
-
|
200
|
+
self,
|
201
|
+
attr_to_edit: Dict[str, Dict[str, Any]] = {},
|
202
|
+
**kwargs: Any
|
205
203
|
) -> 'OpQuantizationConfig':
|
206
204
|
"""
|
207
205
|
Clone the quantization config and edit some of its attributes.
|
@@ -215,17 +213,17 @@ class OpQuantizationConfig(BaseModel):
|
|
215
213
|
OpQuantizationConfig: Edited quantization configuration.
|
216
214
|
"""
|
217
215
|
# Clone and update top-level attributes
|
218
|
-
updated_config = self.
|
216
|
+
updated_config = self.model_copy(update=kwargs)
|
219
217
|
|
220
218
|
# Clone and update nested immutable dataclasses in `attr_weights_configs_mapping`
|
221
219
|
updated_attr_mapping = {
|
222
220
|
attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name])
|
223
|
-
|
221
|
+
if attr_name in attr_to_edit else attr_cfg)
|
224
222
|
for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items()
|
225
223
|
}
|
226
224
|
|
227
225
|
# Return a new instance with the updated attribute mapping
|
228
|
-
return updated_config.
|
226
|
+
return updated_config.model_copy(update={'attr_weights_configs_mapping': updated_attr_mapping})
|
229
227
|
|
230
228
|
|
231
229
|
class QuantizationConfigOptions(BaseModel):
|
@@ -239,10 +237,9 @@ class QuantizationConfigOptions(BaseModel):
|
|
239
237
|
quantization_configurations: Tuple[OpQuantizationConfig, ...]
|
240
238
|
base_config: Optional[OpQuantizationConfig] = None
|
241
239
|
|
242
|
-
|
243
|
-
frozen = True
|
240
|
+
model_config = ConfigDict(frozen=True)
|
244
241
|
|
245
|
-
@
|
242
|
+
@model_validator(mode="before")
|
246
243
|
def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
247
244
|
"""
|
248
245
|
Validate and set the base_config based on quantization_configurations.
|
@@ -282,12 +279,6 @@ class QuantizationConfigOptions(BaseModel):
|
|
282
279
|
"'base_config' must be included in the quantization config options."
|
283
280
|
) # pragma: no cover
|
284
281
|
|
285
|
-
# if num_configs == 1:
|
286
|
-
# if base_config != quantization_configurations[0]:
|
287
|
-
# Logger.critical(
|
288
|
-
# "'base_config' should be the same as the sole item in 'quantization_configurations'."
|
289
|
-
# ) # pragma: no cover
|
290
|
-
|
291
282
|
values['base_config'] = base_config
|
292
283
|
|
293
284
|
# When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
|
@@ -312,7 +303,7 @@ class QuantizationConfigOptions(BaseModel):
|
|
312
303
|
# Clone and update all configurations
|
313
304
|
updated_configs = tuple(cfg.clone_and_edit(**kwargs) for cfg in self.quantization_configurations)
|
314
305
|
|
315
|
-
return self.
|
306
|
+
return self.model_copy(update={
|
316
307
|
'base_config': updated_base_config,
|
317
308
|
'quantization_configurations': updated_configs
|
318
309
|
})
|
@@ -360,7 +351,7 @@ class QuantizationConfigOptions(BaseModel):
|
|
360
351
|
updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping)
|
361
352
|
updated_configs.append(updated_cfg)
|
362
353
|
|
363
|
-
return self.
|
354
|
+
return self.model_copy(update={
|
364
355
|
'base_config': updated_base_config,
|
365
356
|
'quantization_configurations': tuple(updated_configs)
|
366
357
|
})
|
@@ -398,7 +389,7 @@ class QuantizationConfigOptions(BaseModel):
|
|
398
389
|
updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping)
|
399
390
|
updated_configs.append(updated_cfg)
|
400
391
|
|
401
|
-
return self.
|
392
|
+
return self.model_copy(update={
|
402
393
|
'base_config': new_base_config,
|
403
394
|
'quantization_configurations': tuple(updated_configs)
|
404
395
|
})
|
@@ -412,12 +403,12 @@ class QuantizationConfigOptions(BaseModel):
|
|
412
403
|
"""
|
413
404
|
return {f'option_{i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
|
414
405
|
|
406
|
+
|
415
407
|
class TargetPlatformModelComponent(BaseModel):
|
416
408
|
"""
|
417
409
|
Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.).
|
418
410
|
"""
|
419
|
-
|
420
|
-
frozen = True
|
411
|
+
model_config = ConfigDict(frozen=True)
|
421
412
|
|
422
413
|
|
423
414
|
class OperatorsSetBase(TargetPlatformModelComponent):
|
@@ -444,8 +435,7 @@ class OperatorsSet(OperatorsSetBase):
|
|
444
435
|
# Define a private attribute _type
|
445
436
|
type: Literal["OperatorsSet"] = "OperatorsSet"
|
446
437
|
|
447
|
-
|
448
|
-
frozen = True
|
438
|
+
model_config = ConfigDict(frozen=True)
|
449
439
|
|
450
440
|
def get_info(self) -> Dict[str, Any]:
|
451
441
|
"""
|
@@ -471,10 +461,9 @@ class OperatorSetGroup(OperatorsSetBase):
|
|
471
461
|
# Define a private attribute _type
|
472
462
|
type: Literal["OperatorSetGroup"] = "OperatorSetGroup"
|
473
463
|
|
474
|
-
|
475
|
-
frozen = True
|
464
|
+
model_config = ConfigDict(frozen=True)
|
476
465
|
|
477
|
-
@
|
466
|
+
@model_validator(mode="before")
|
478
467
|
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
479
468
|
"""
|
480
469
|
Validate the input and set the concatenated name based on the operators_set.
|
@@ -512,6 +501,7 @@ class OperatorSetGroup(OperatorsSetBase):
|
|
512
501
|
"operators_set": [op.get_info() for op in self.operators_set]
|
513
502
|
}
|
514
503
|
|
504
|
+
|
515
505
|
class Fusing(TargetPlatformModelComponent):
|
516
506
|
"""
|
517
507
|
Fusing defines a tuple of operators that should be combined and treated as a single operator,
|
@@ -525,10 +515,9 @@ class Fusing(TargetPlatformModelComponent):
|
|
525
515
|
operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
|
526
516
|
name: Optional[str] = None # Will be set in the validator if not given.
|
527
517
|
|
528
|
-
|
529
|
-
frozen = True
|
518
|
+
model_config = ConfigDict(frozen=True)
|
530
519
|
|
531
|
-
@
|
520
|
+
@model_validator(mode="before")
|
532
521
|
def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
533
522
|
"""
|
534
523
|
Validate the operator_groups and set the name by concatenating operator group names.
|
@@ -555,24 +544,15 @@ class Fusing(TargetPlatformModelComponent):
|
|
555
544
|
|
556
545
|
return values
|
557
546
|
|
558
|
-
@
|
559
|
-
def validate_after_initialization(cls,
|
547
|
+
@model_validator(mode="after")
|
548
|
+
def validate_after_initialization(cls, model: 'Fusing') -> Any:
|
560
549
|
"""
|
561
550
|
Perform validation after the model has been instantiated.
|
562
|
-
|
563
|
-
Args:
|
564
|
-
values (Dict[str, Any]): The instantiated fusing.
|
565
|
-
|
566
|
-
Returns:
|
567
|
-
Dict[str, Any]: The validated values.
|
551
|
+
Ensures that there are at least two operator groups.
|
568
552
|
"""
|
569
|
-
|
570
|
-
|
571
|
-
# Validate that there are at least two operator groups
|
572
|
-
if len(operator_groups) < 2:
|
553
|
+
if len(model.operator_groups) < 2:
|
573
554
|
Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
|
574
|
-
|
575
|
-
return values
|
555
|
+
return model
|
576
556
|
|
577
557
|
def contains(self, other: Any) -> bool:
|
578
558
|
"""
|
@@ -621,6 +601,7 @@ class Fusing(TargetPlatformModelComponent):
|
|
621
601
|
for x in self.operator_groups
|
622
602
|
])
|
623
603
|
|
604
|
+
|
624
605
|
class TargetPlatformCapabilities(BaseModel):
|
625
606
|
"""
|
626
607
|
Represents the hardware configuration used for quantized model inference.
|
@@ -638,38 +619,37 @@ class TargetPlatformCapabilities(BaseModel):
|
|
638
619
|
SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
|
639
620
|
"""
|
640
621
|
default_qco: QuantizationConfigOptions
|
641
|
-
operator_set: Optional[Tuple[OperatorsSet, ...]]
|
642
|
-
fusing_patterns: Optional[Tuple[Fusing, ...]]
|
643
|
-
tpc_minor_version: Optional[int]
|
644
|
-
tpc_patch_version: Optional[int]
|
645
|
-
tpc_platform_type: Optional[str]
|
622
|
+
operator_set: Optional[Tuple[OperatorsSet, ...]] = None
|
623
|
+
fusing_patterns: Optional[Tuple[Fusing, ...]] = None
|
624
|
+
tpc_minor_version: Optional[int] = None
|
625
|
+
tpc_patch_version: Optional[int] = None
|
626
|
+
tpc_platform_type: Optional[str] = None
|
646
627
|
add_metadata: bool = True
|
647
628
|
name: Optional[str] = "default_tpc"
|
648
629
|
is_simd_padding: bool = False
|
649
630
|
|
650
631
|
SCHEMA_VERSION: int = 1
|
651
632
|
|
652
|
-
|
653
|
-
frozen = True
|
633
|
+
model_config = ConfigDict(frozen=True)
|
654
634
|
|
655
|
-
@
|
656
|
-
def validate_after_initialization(cls,
|
635
|
+
@model_validator(mode="after")
|
636
|
+
def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
|
657
637
|
"""
|
658
638
|
Perform validation after the model has been instantiated.
|
659
639
|
|
660
640
|
Args:
|
661
|
-
|
641
|
+
model (TargetPlatformCapabilities): The instantiated target platform model.
|
662
642
|
|
663
643
|
Returns:
|
664
|
-
|
644
|
+
TargetPlatformCapabilities: The validated model.
|
665
645
|
"""
|
666
646
|
# Validate `default_qco`
|
667
|
-
default_qco =
|
647
|
+
default_qco = model.default_qco
|
668
648
|
if len(default_qco.quantization_configurations) != 1:
|
669
649
|
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
|
670
650
|
|
671
651
|
# Validate `operator_set` uniqueness
|
672
|
-
operator_set =
|
652
|
+
operator_set = model.operator_set
|
673
653
|
if operator_set is not None:
|
674
654
|
opsets_names = [
|
675
655
|
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
|
@@ -678,7 +658,7 @@ class TargetPlatformCapabilities(BaseModel):
|
|
678
658
|
if len(set(opsets_names)) != len(opsets_names):
|
679
659
|
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
|
680
660
|
|
681
|
-
return
|
661
|
+
return model
|
682
662
|
|
683
663
|
def get_info(self) -> Dict[str, Any]:
|
684
664
|
"""
|
model_compression_toolkit/target_platform_capabilities/schema/v2.py
@@ -16,7 +16,7 @@ import pprint
|
|
16
16
|
from enum import Enum
|
17
17
|
from typing import Dict, Any, Tuple, Optional
|
18
18
|
|
19
|
-
from pydantic import BaseModel, root_validator
|
19
|
+
from pydantic import BaseModel, root_validator, model_validator, ConfigDict
|
20
20
|
|
21
21
|
from mct_quantizers import QuantizationMethod
|
22
22
|
from model_compression_toolkit.constants import FLOAT_BITWIDTH
|
@@ -30,8 +30,72 @@ from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
|
|
30
30
|
OperatorsSetBase,
|
31
31
|
OperatorsSet,
|
32
32
|
OperatorSetGroup,
|
33
|
-
Fusing
|
34
|
-
|
33
|
+
Fusing)
|
34
|
+
|
35
|
+
|
36
|
+
class OperatorSetNames(str, Enum):
|
37
|
+
CONV = "Conv"
|
38
|
+
DEPTHWISE_CONV = "DepthwiseConv2D"
|
39
|
+
CONV_TRANSPOSE = "ConvTranspose"
|
40
|
+
FULLY_CONNECTED = "FullyConnected"
|
41
|
+
CONCATENATE = "Concatenate"
|
42
|
+
STACK = "Stack"
|
43
|
+
UNSTACK = "Unstack"
|
44
|
+
GATHER = "Gather"
|
45
|
+
EXPAND = "Expend"
|
46
|
+
BATCH_NORM = "BatchNorm"
|
47
|
+
L2NORM = "L2Norm"
|
48
|
+
RELU = "ReLU"
|
49
|
+
RELU6 = "ReLU6"
|
50
|
+
LEAKY_RELU = "LeakyReLU"
|
51
|
+
ELU = "Elu"
|
52
|
+
HARD_TANH = "HardTanh"
|
53
|
+
ADD = "Add"
|
54
|
+
SUB = "Sub"
|
55
|
+
MUL = "Mul"
|
56
|
+
DIV = "Div"
|
57
|
+
MIN = "Min"
|
58
|
+
MAX = "Max"
|
59
|
+
PRELU = "PReLU"
|
60
|
+
ADD_BIAS = "AddBias"
|
61
|
+
SWISH = "Swish"
|
62
|
+
SIGMOID = "Sigmoid"
|
63
|
+
SOFTMAX = "Softmax"
|
64
|
+
LOG_SOFTMAX = "LogSoftmax"
|
65
|
+
TANH = "Tanh"
|
66
|
+
GELU = "Gelu"
|
67
|
+
HARDSIGMOID = "HardSigmoid"
|
68
|
+
HARDSWISH = "HardSwish"
|
69
|
+
FLATTEN = "Flatten"
|
70
|
+
GET_ITEM = "GetItem"
|
71
|
+
RESHAPE = "Reshape"
|
72
|
+
UNSQUEEZE = "Unsqueeze"
|
73
|
+
SQUEEZE = "Squeeze"
|
74
|
+
PERMUTE = "Permute"
|
75
|
+
TRANSPOSE = "Transpose"
|
76
|
+
DROPOUT = "Dropout"
|
77
|
+
SPLIT_CHUNK = "SplitChunk"
|
78
|
+
MAXPOOL = "MaxPool"
|
79
|
+
AVGPOOL = "AvgPool"
|
80
|
+
SIZE = "Size"
|
81
|
+
SHAPE = "Shape"
|
82
|
+
EQUAL = "Equal"
|
83
|
+
ARGMAX = "ArgMax"
|
84
|
+
TOPK = "TopK"
|
85
|
+
FAKE_QUANT = "FakeQuant"
|
86
|
+
COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
|
87
|
+
ZERO_PADDING2D = "ZeroPadding2D"
|
88
|
+
CAST = "Cast"
|
89
|
+
RESIZE = "Resize"
|
90
|
+
PAD = "Pad"
|
91
|
+
FOLD = "Fold"
|
92
|
+
STRIDED_SLICE = "StridedSlice"
|
93
|
+
SSD_POST_PROCESS = "SSDPostProcess"
|
94
|
+
EXP = "Exp"
|
95
|
+
|
96
|
+
@classmethod
|
97
|
+
def get_values(cls):
|
98
|
+
return [v.value for v in cls]
|
35
99
|
|
36
100
|
|
37
101
|
class TargetPlatformCapabilities(BaseModel):
|
@@ -62,27 +126,26 @@ class TargetPlatformCapabilities(BaseModel):
|
|
62
126
|
|
63
127
|
SCHEMA_VERSION: int = 2
|
64
128
|
|
65
|
-
|
66
|
-
frozen = True
|
129
|
+
model_config = ConfigDict(frozen=True)
|
67
130
|
|
68
|
-
@
|
69
|
-
def validate_after_initialization(cls,
|
131
|
+
@model_validator(mode="after")
|
132
|
+
def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
|
70
133
|
"""
|
71
134
|
Perform validation after the model has been instantiated.
|
72
135
|
|
73
136
|
Args:
|
74
|
-
|
137
|
+
model (TargetPlatformCapabilities): The instantiated target platform model.
|
75
138
|
|
76
139
|
Returns:
|
77
|
-
|
140
|
+
TargetPlatformCapabilities: The validated model.
|
78
141
|
"""
|
79
142
|
# Validate `default_qco`
|
80
|
-
default_qco =
|
143
|
+
default_qco = model.default_qco
|
81
144
|
if len(default_qco.quantization_configurations) != 1:
|
82
145
|
Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
|
83
146
|
|
84
147
|
# Validate `operator_set` uniqueness
|
85
|
-
operator_set =
|
148
|
+
operator_set = model.operator_set
|
86
149
|
if operator_set is not None:
|
87
150
|
opsets_names = [
|
88
151
|
op.name.value if isinstance(op.name, OperatorSetNames) else op.name
|
@@ -91,7 +154,7 @@ class TargetPlatformCapabilities(BaseModel):
|
|
91
154
|
if len(set(opsets_names)) != len(opsets_names):
|
92
155
|
Logger.critical("Operator Sets must have unique names.") # pragma: no cover
|
93
156
|
|
94
|
-
return
|
157
|
+
return model
|
95
158
|
|
96
159
|
def get_info(self) -> Dict[str, Any]:
|
97
160
|
"""
|
@@ -111,4 +174,3 @@ class TargetPlatformCapabilities(BaseModel):
|
|
111
174
|
Display the TargetPlatformCapabilities.
|
112
175
|
"""
|
113
176
|
pprint.pprint(self.get_info(), sort_dicts=False)
|
114
|
-
|
model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py
@@ -100,6 +100,6 @@ def export_target_platform_capabilities(model: schema.TargetPlatformCapabilities
|
|
100
100
|
|
101
101
|
# Export the model to JSON and write to the file
|
102
102
|
with path.open('w', encoding='utf-8') as file:
|
103
|
-
file.write(model.
|
103
|
+
file.write(model.model_dump_json(indent=4))
|
104
104
|
except OSError as e:
|
105
105
|
raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e
|
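{mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/WHEEL
RENAMED
File without changes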
|
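{mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/licenses/LICENSE.md
RENAMED
File without changes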
|
{mct_nightly-2.3.0.20250414.604.dist-info → mct_nightly-2.3.0.20250416.541.dist-info}/top_level.txt
RENAMED
File without changes
|