mct-nightly 2.3.0.20250414.604__py3-none-any.whl → 2.3.0.20250415.557__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250414.604
3
+ Version: 2.3.0.20250415.557
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: Apache Software License
@@ -21,7 +21,7 @@ Requires-Dist: matplotlib<3.10.0
21
21
  Requires-Dist: scipy
22
22
  Requires-Dist: protobuf
23
23
  Requires-Dist: mct-quantizers-nightly
24
- Requires-Dist: pydantic<2.0
24
+ Requires-Dist: pydantic>=2.0
25
25
  Requires-Dist: sony-custom-layers-dev==0.4.0.dev6
26
26
  Dynamic: classifier
27
27
  Dynamic: description
@@ -63,7 +63,7 @@ ________________________________________________________________________________
63
63
 
64
64
  ## <div align="center">Getting Started</div>
65
65
  ### Quick Installation
66
- Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.12.
66
+ Pip install the model compression toolkit package in a Python>=3.9 environment with PyTorch>=2.1 or Tensorflow>=2.14.
67
67
  ```
68
68
  pip install model-compression-toolkit
69
69
  ```
@@ -170,11 +170,11 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
170
170
  | Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_pytorch25.yml) |
171
171
  | Python 3.12 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) |
172
172
 
173
- | | TensorFlow 2.12 | TensorFlow 2.13 | TensorFlow 2.14 | TensorFlow 2.15 |
173
+ | | TensorFlow 2.14 | TensorFlow 2.15 |
174
174
  |-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
175
- | Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
176
- | Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
177
- | Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras212.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras213.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |
175
+ | Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
176
+ | Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
177
+ | Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |
178
178
 
179
179
  </details>
180
180
 
@@ -1,5 +1,5 @@
1
- mct_nightly-2.3.0.20250414.604.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=JNLJGpeFskBs0w68Qp-Q-USHO-IvOXspcQIaPue9mZM,1557
1
+ mct_nightly-2.3.0.20250415.557.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=0HvKA_YR0Kpy00hpXHDhBtb6kS5hRq64d6L3f-cjcQM,1557
3
3
  model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -354,7 +354,7 @@ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantiz
354
354
  model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
355
355
  model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
356
356
  model_compression_toolkit/gptq/common/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
357
- model_compression_toolkit/gptq/common/gptq_config.py,sha256=QwSEZZlC6OpnpoBQoAFfgXTrdBgewgqlgaCV2hoJEso,6143
357
+ model_compression_toolkit/gptq/common/gptq_config.py,sha256=xVzjy3CyR07rpGvUy2jsSaijXq7-0KStpU_yVu7VLVA,6144
358
358
  model_compression_toolkit/gptq/common/gptq_constants.py,sha256=8HB0yiX75zZ1IKgQUPWpFCM5sS8HAqslws5XrOhxJQ0,750
359
359
  model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
360
360
  model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
@@ -433,13 +433,13 @@ model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py,sha2
433
433
  model_compression_toolkit/target_platform_capabilities/__init__.py,sha256=8RVOriZg-XNjSt53h_4Yum0oRgOe2gp5H45dfG_lZxE,1415
434
434
  model_compression_toolkit/target_platform_capabilities/constants.py,sha256=JRz9DoxLRpkqvu532TFkIvv0595Bfb9NtU4pRp4urDY,1540
435
435
  model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
436
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=-SGcHXGsNkQyasxsK7f4e05doHumJgyeHDJOGLdamE8,4615
436
+ model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=nbmlygR-nc3bzwnUDrRamq3a6KFkC4-cCpbUeF7EEmo,4626
437
437
  model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
438
438
  model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=PvO8eHxnb3A55gyExT5fZGnOUl3ce7BbbT5SPxCEXNo,541
439
439
  model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=TtMPbiibV6Hk53nl5Y_ctfpI6mSbd8VVH9fxnv5j9eM,4430
440
440
  model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
441
- model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=4CGpWENuOyjwaIMaGrFI0Act7jsSeT7m94pjrv91dxE,27516
442
- model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=cDNtJW1bVemStg07sDyVqmhHTUiQx66diq3-fqGPNek,4557
441
+ model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=oWKNQnnz04kmijmdWtRyXgVXbJ6BG_V_bUBz_MfUM94,27116
442
+ model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=0lYfOs6Dcoi8pLciCPXgiJDeo9LHP0quccWKe7-ZR2Y,4571
443
443
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
444
444
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
445
445
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=Ehwpd_sL6zxmJFpJugOdN9uNxNX05nijvOCilNfHnFs,7162
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
528
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
529
529
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
530
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.3.0.20250414.604.dist-info/METADATA,sha256=GJ0DyWZGjeT5LJ0HqXJRTYaCQcOfibsKmT9ruboJTeU,27148
532
- mct_nightly-2.3.0.20250414.604.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
533
- mct_nightly-2.3.0.20250414.604.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
- mct_nightly-2.3.0.20250414.604.dist-info/RECORD,,
531
+ mct_nightly-2.3.0.20250415.557.dist-info/METADATA,sha256=xAD7AYKh0CWSUycRyXvOiV0F5B8CkdsFa1S5Tq0k63I,25413
532
+ mct_nightly-2.3.0.20250415.557.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
533
+ mct_nightly-2.3.0.20250415.557.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
+ mct_nightly-2.3.0.20250415.557.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250414.000604"
30
+ __version__ = "2.3.0.20250415.000557"
@@ -84,7 +84,7 @@ class QFractionLinearAnnealingConfig:
84
84
  raise ValueError(f'Expected start_step >= 0. received {self.start_step}.')
85
85
  if self.end_step is not None and self.end_step <= self.start_step:
86
86
  raise ValueError('Expected start_step < end_step, '
87
- 'received end_step {self.end_step} and start_step {self.start_stap}.')
87
+ f'received end_step {self.end_step} and start_step {self.start_step}.')
88
88
 
89
89
 
90
90
  @dataclass
@@ -16,7 +16,8 @@ import pprint
16
16
  from enum import Enum
17
17
  from typing import Dict, Any, Union, Tuple, List, Optional, Literal, Annotated
18
18
 
19
- from pydantic import BaseModel, Field, root_validator, validator, PositiveInt
19
+ from pydantic import BaseModel, Field, root_validator, validator, PositiveInt, ConfigDict, field_validator, \
20
+ model_validator
20
21
 
21
22
  from mct_quantizers import QuantizationMethod
22
23
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
@@ -118,9 +119,7 @@ class AttributeQuantizationConfig(BaseModel):
118
119
  enable_weights_quantization: bool = False
119
120
  lut_values_bitwidth: Optional[int] = None
120
121
 
121
- class Config:
122
- # Makes the model immutable (frozen)
123
- frozen = True
122
+ model_config = ConfigDict(frozen=True)
124
123
 
125
124
  @property
126
125
  def field_names(self) -> list:
@@ -137,7 +136,7 @@ class AttributeQuantizationConfig(BaseModel):
137
136
  Returns:
138
137
  AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes.
139
138
  """
140
- return self.copy(update=kwargs)
139
+ return self.model_copy(update=kwargs)
141
140
 
142
141
 
143
142
  class OpQuantizationConfig(BaseModel):
@@ -164,15 +163,14 @@ class OpQuantizationConfig(BaseModel):
164
163
  supported_input_activation_n_bits: Union[int, Tuple[int, ...]]
165
164
  enable_activation_quantization: bool
166
165
  quantization_preserving: bool
167
- fixed_scale: Optional[float]
168
- fixed_zero_point: Optional[int]
169
- simd_size: Optional[int]
170
166
  signedness: Signedness
167
+ fixed_scale: Optional[float] = None
168
+ fixed_zero_point: Optional[int] = None
169
+ simd_size: Optional[int] = None
171
170
 
172
- class Config:
173
- frozen = True
171
+ model_config = ConfigDict(frozen=True)
174
172
 
175
- @validator('supported_input_activation_n_bits', pre=True, allow_reuse=True)
173
+ @field_validator('supported_input_activation_n_bits', mode='before')
176
174
  def validate_supported_input_activation_n_bits(cls, v):
177
175
  """
178
176
  Validate and process the supported_input_activation_n_bits field.
@@ -199,9 +197,9 @@ class OpQuantizationConfig(BaseModel):
199
197
  return self.dict() # pragma: no cover
200
198
 
201
199
  def clone_and_edit(
202
- self,
203
- attr_to_edit: Dict[str, Dict[str, Any]] = {},
204
- **kwargs: Any
200
+ self,
201
+ attr_to_edit: Dict[str, Dict[str, Any]] = {},
202
+ **kwargs: Any
205
203
  ) -> 'OpQuantizationConfig':
206
204
  """
207
205
  Clone the quantization config and edit some of its attributes.
@@ -215,17 +213,17 @@ class OpQuantizationConfig(BaseModel):
215
213
  OpQuantizationConfig: Edited quantization configuration.
216
214
  """
217
215
  # Clone and update top-level attributes
218
- updated_config = self.copy(update=kwargs)
216
+ updated_config = self.model_copy(update=kwargs)
219
217
 
220
218
  # Clone and update nested immutable dataclasses in `attr_weights_configs_mapping`
221
219
  updated_attr_mapping = {
222
220
  attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name])
223
- if attr_name in attr_to_edit else attr_cfg)
221
+ if attr_name in attr_to_edit else attr_cfg)
224
222
  for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items()
225
223
  }
226
224
 
227
225
  # Return a new instance with the updated attribute mapping
228
- return updated_config.copy(update={'attr_weights_configs_mapping': updated_attr_mapping})
226
+ return updated_config.model_copy(update={'attr_weights_configs_mapping': updated_attr_mapping})
229
227
 
230
228
 
231
229
  class QuantizationConfigOptions(BaseModel):
@@ -239,10 +237,9 @@ class QuantizationConfigOptions(BaseModel):
239
237
  quantization_configurations: Tuple[OpQuantizationConfig, ...]
240
238
  base_config: Optional[OpQuantizationConfig] = None
241
239
 
242
- class Config:
243
- frozen = True
240
+ model_config = ConfigDict(frozen=True)
244
241
 
245
- @root_validator(pre=True, allow_reuse=True)
242
+ @model_validator(mode="before")
246
243
  def validate_and_set_base_config(cls, values: Dict[str, Any]) -> Dict[str, Any]:
247
244
  """
248
245
  Validate and set the base_config based on quantization_configurations.
@@ -282,12 +279,6 @@ class QuantizationConfigOptions(BaseModel):
282
279
  "'base_config' must be included in the quantization config options."
283
280
  ) # pragma: no cover
284
281
 
285
- # if num_configs == 1:
286
- # if base_config != quantization_configurations[0]:
287
- # Logger.critical(
288
- # "'base_config' should be the same as the sole item in 'quantization_configurations'."
289
- # ) # pragma: no cover
290
-
291
282
  values['base_config'] = base_config
292
283
 
293
284
  # When loading from JSON, lists are returned. If the value is a list, convert it to a tuple.
@@ -312,7 +303,7 @@ class QuantizationConfigOptions(BaseModel):
312
303
  # Clone and update all configurations
313
304
  updated_configs = tuple(cfg.clone_and_edit(**kwargs) for cfg in self.quantization_configurations)
314
305
 
315
- return self.copy(update={
306
+ return self.model_copy(update={
316
307
  'base_config': updated_base_config,
317
308
  'quantization_configurations': updated_configs
318
309
  })
@@ -360,7 +351,7 @@ class QuantizationConfigOptions(BaseModel):
360
351
  updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=updated_attr_mapping)
361
352
  updated_configs.append(updated_cfg)
362
353
 
363
- return self.copy(update={
354
+ return self.model_copy(update={
364
355
  'base_config': updated_base_config,
365
356
  'quantization_configurations': tuple(updated_configs)
366
357
  })
@@ -398,7 +389,7 @@ class QuantizationConfigOptions(BaseModel):
398
389
  updated_cfg = qc.clone_and_edit(attr_weights_configs_mapping=new_attr_mapping)
399
390
  updated_configs.append(updated_cfg)
400
391
 
401
- return self.copy(update={
392
+ return self.model_copy(update={
402
393
  'base_config': new_base_config,
403
394
  'quantization_configurations': tuple(updated_configs)
404
395
  })
@@ -412,12 +403,12 @@ class QuantizationConfigOptions(BaseModel):
412
403
  """
413
404
  return {f'option_{i}': cfg.get_info() for i, cfg in enumerate(self.quantization_configurations)}
414
405
 
406
+
415
407
  class TargetPlatformModelComponent(BaseModel):
416
408
  """
417
409
  Component of TargetPlatformCapabilities (Fusing, OperatorsSet, etc.).
418
410
  """
419
- class Config:
420
- frozen = True
411
+ model_config = ConfigDict(frozen=True)
421
412
 
422
413
 
423
414
  class OperatorsSetBase(TargetPlatformModelComponent):
@@ -444,8 +435,7 @@ class OperatorsSet(OperatorsSetBase):
444
435
  # Define a private attribute _type
445
436
  type: Literal["OperatorsSet"] = "OperatorsSet"
446
437
 
447
- class Config:
448
- frozen = True
438
+ model_config = ConfigDict(frozen=True)
449
439
 
450
440
  def get_info(self) -> Dict[str, Any]:
451
441
  """
@@ -471,10 +461,9 @@ class OperatorSetGroup(OperatorsSetBase):
471
461
  # Define a private attribute _type
472
462
  type: Literal["OperatorSetGroup"] = "OperatorSetGroup"
473
463
 
474
- class Config:
475
- frozen = True
464
+ model_config = ConfigDict(frozen=True)
476
465
 
477
- @root_validator(pre=True, allow_reuse=True)
466
+ @model_validator(mode="before")
478
467
  def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
479
468
  """
480
469
  Validate the input and set the concatenated name based on the operators_set.
@@ -512,6 +501,7 @@ class OperatorSetGroup(OperatorsSetBase):
512
501
  "operators_set": [op.get_info() for op in self.operators_set]
513
502
  }
514
503
 
504
+
515
505
  class Fusing(TargetPlatformModelComponent):
516
506
  """
517
507
  Fusing defines a tuple of operators that should be combined and treated as a single operator,
@@ -525,10 +515,9 @@ class Fusing(TargetPlatformModelComponent):
525
515
  operator_groups: Tuple[Annotated[Union[OperatorsSet, OperatorSetGroup], Field(discriminator='type')], ...]
526
516
  name: Optional[str] = None # Will be set in the validator if not given.
527
517
 
528
- class Config:
529
- frozen = True
518
+ model_config = ConfigDict(frozen=True)
530
519
 
531
- @root_validator(pre=True, allow_reuse=True)
520
+ @model_validator(mode="before")
532
521
  def validate_and_set_name(cls, values: Dict[str, Any]) -> Dict[str, Any]:
533
522
  """
534
523
  Validate the operator_groups and set the name by concatenating operator group names.
@@ -555,24 +544,15 @@ class Fusing(TargetPlatformModelComponent):
555
544
 
556
545
  return values
557
546
 
558
- @root_validator(allow_reuse=True)
559
- def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
547
+ @model_validator(mode="after")
548
+ def validate_after_initialization(cls, model: 'Fusing') -> Any:
560
549
  """
561
550
  Perform validation after the model has been instantiated.
562
-
563
- Args:
564
- values (Dict[str, Any]): The instantiated fusing.
565
-
566
- Returns:
567
- Dict[str, Any]: The validated values.
551
+ Ensures that there are at least two operator groups.
568
552
  """
569
- operator_groups = values.get('operator_groups')
570
-
571
- # Validate that there are at least two operator groups
572
- if len(operator_groups) < 2:
553
+ if len(model.operator_groups) < 2:
573
554
  Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover
574
-
575
- return values
555
+ return model
576
556
 
577
557
  def contains(self, other: Any) -> bool:
578
558
  """
@@ -621,6 +601,7 @@ class Fusing(TargetPlatformModelComponent):
621
601
  for x in self.operator_groups
622
602
  ])
623
603
 
604
+
624
605
  class TargetPlatformCapabilities(BaseModel):
625
606
  """
626
607
  Represents the hardware configuration used for quantized model inference.
@@ -638,38 +619,37 @@ class TargetPlatformCapabilities(BaseModel):
638
619
  SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
639
620
  """
640
621
  default_qco: QuantizationConfigOptions
641
- operator_set: Optional[Tuple[OperatorsSet, ...]]
642
- fusing_patterns: Optional[Tuple[Fusing, ...]]
643
- tpc_minor_version: Optional[int]
644
- tpc_patch_version: Optional[int]
645
- tpc_platform_type: Optional[str]
622
+ operator_set: Optional[Tuple[OperatorsSet, ...]] = None
623
+ fusing_patterns: Optional[Tuple[Fusing, ...]] = None
624
+ tpc_minor_version: Optional[int] = None
625
+ tpc_patch_version: Optional[int] = None
626
+ tpc_platform_type: Optional[str] = None
646
627
  add_metadata: bool = True
647
628
  name: Optional[str] = "default_tpc"
648
629
  is_simd_padding: bool = False
649
630
 
650
631
  SCHEMA_VERSION: int = 1
651
632
 
652
- class Config:
653
- frozen = True
633
+ model_config = ConfigDict(frozen=True)
654
634
 
655
- @root_validator(allow_reuse=True)
656
- def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
635
+ @model_validator(mode="after")
636
+ def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
657
637
  """
658
638
  Perform validation after the model has been instantiated.
659
639
 
660
640
  Args:
661
- values (Dict[str, Any]): The instantiated target platform model.
641
+ model (TargetPlatformCapabilities): The instantiated target platform model.
662
642
 
663
643
  Returns:
664
- Dict[str, Any]: The validated values.
644
+ TargetPlatformCapabilities: The validated model.
665
645
  """
666
646
  # Validate `default_qco`
667
- default_qco = values.get('default_qco')
647
+ default_qco = model.default_qco
668
648
  if len(default_qco.quantization_configurations) != 1:
669
649
  Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
670
650
 
671
651
  # Validate `operator_set` uniqueness
672
- operator_set = values.get('operator_set')
652
+ operator_set = model.operator_set
673
653
  if operator_set is not None:
674
654
  opsets_names = [
675
655
  op.name.value if isinstance(op.name, OperatorSetNames) else op.name
@@ -678,7 +658,7 @@ class TargetPlatformCapabilities(BaseModel):
678
658
  if len(set(opsets_names)) != len(opsets_names):
679
659
  Logger.critical("Operator Sets must have unique names.") # pragma: no cover
680
660
 
681
- return values
661
+ return model
682
662
 
683
663
  def get_info(self) -> Dict[str, Any]:
684
664
  """
@@ -13,10 +13,9 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import pprint
16
- from enum import Enum
17
16
  from typing import Dict, Any, Tuple, Optional
18
17
 
19
- from pydantic import BaseModel, root_validator
18
+ from pydantic import BaseModel, root_validator, model_validator, ConfigDict
20
19
 
21
20
  from mct_quantizers import QuantizationMethod
22
21
  from model_compression_toolkit.constants import FLOAT_BITWIDTH
@@ -62,27 +61,26 @@ class TargetPlatformCapabilities(BaseModel):
62
61
 
63
62
  SCHEMA_VERSION: int = 2
64
63
 
65
- class Config:
66
- frozen = True
64
+ model_config = ConfigDict(frozen=True)
67
65
 
68
- @root_validator(allow_reuse=True)
69
- def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
66
+ @model_validator(mode="after")
67
+ def validate_after_initialization(cls, model: 'TargetPlatformCapabilities') -> Any:
70
68
  """
71
69
  Perform validation after the model has been instantiated.
72
70
 
73
71
  Args:
74
- values (Dict[str, Any]): The instantiated target platform model.
72
+ model (TargetPlatformCapabilities): The instantiated target platform model.
75
73
 
76
74
  Returns:
77
- Dict[str, Any]: The validated values.
75
+ TargetPlatformCapabilities: The validated model.
78
76
  """
79
77
  # Validate `default_qco`
80
- default_qco = values.get('default_qco')
78
+ default_qco = model.default_qco
81
79
  if len(default_qco.quantization_configurations) != 1:
82
80
  Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover
83
81
 
84
82
  # Validate `operator_set` uniqueness
85
- operator_set = values.get('operator_set')
83
+ operator_set = model.operator_set
86
84
  if operator_set is not None:
87
85
  opsets_names = [
88
86
  op.name.value if isinstance(op.name, OperatorSetNames) else op.name
@@ -91,7 +89,7 @@ class TargetPlatformCapabilities(BaseModel):
91
89
  if len(set(opsets_names)) != len(opsets_names):
92
90
  Logger.critical("Operator Sets must have unique names.") # pragma: no cover
93
91
 
94
- return values
92
+ return model
95
93
 
96
94
  def get_info(self) -> Dict[str, Any]:
97
95
  """
@@ -100,6 +100,6 @@ def export_target_platform_capabilities(model: schema.TargetPlatformCapabilities
100
100
 
101
101
  # Export the model to JSON and write to the file
102
102
  with path.open('w', encoding='utf-8') as file:
103
- file.write(model.json(indent=4))
103
+ file.write(model.model_dump_json(indent=4))
104
104
  except OSError as e:
105
105
  raise OSError(f"Failed to write to file '{export_path}': {e.strerror}") from e