mct-nightly 2.0.0.20240521.151450__py3-none-any.whl → 2.0.0.20240522.172031__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.0.0.20240521.151450.dist-info → mct_nightly-2.0.0.20240522.172031.dist-info}/METADATA +1 -1
- {mct_nightly-2.0.0.20240521.151450.dist-info → mct_nightly-2.0.0.20240522.172031.dist-info}/RECORD +12 -12
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/framework_implementation.py +6 -4
- model_compression_toolkit/core/common/graph/base_node.py +6 -2
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +10 -5
- model_compression_toolkit/core/keras/keras_implementation.py +6 -4
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -2
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +6 -4
- {mct_nightly-2.0.0.20240521.151450.dist-info → mct_nightly-2.0.0.20240522.172031.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.0.0.20240521.151450.dist-info → mct_nightly-2.0.0.20240522.172031.dist-info}/WHEEL +0 -0
- {mct_nightly-2.0.0.20240521.151450.dist-info → mct_nightly-2.0.0.20240522.172031.dist-info}/top_level.txt +0 -0
{mct_nightly-2.0.0.20240521.151450.dist-info → mct_nightly-2.0.0.20240522.172031.dist-info}/RECORD
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
|
1
|
+
model_compression_toolkit/__init__.py,sha256=yildqQE7rOBK1de3gV4i80U90kQLviv92xEz2-zY2XI,1573
|
|
2
2
|
model_compression_toolkit/constants.py,sha256=b63Jk_bC7VXEX3Qn9TZ3wUvrNKD8Mkz8zIuayoyF5eU,3828
|
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
|
@@ -10,7 +10,7 @@ model_compression_toolkit/core/quantization_prep_runner.py,sha256=0ga95vh_ZXO79r
|
|
|
10
10
|
model_compression_toolkit/core/runner.py,sha256=yref5I8eUo2A4hAmc4bOQOj6lUZRDQjLQR_5lJCjXiQ,12696
|
|
11
11
|
model_compression_toolkit/core/common/__init__.py,sha256=Wh127PbXcETZX_d1PQqZ71ETK3J9XO5A-HpadGUbj6o,1447
|
|
12
12
|
model_compression_toolkit/core/common/base_substitutions.py,sha256=xDFSmVVs_iFSZfajytI0cuQaNRNcwHX3uqOoHgVUvxQ,1666
|
|
13
|
-
model_compression_toolkit/core/common/framework_implementation.py,sha256=
|
|
13
|
+
model_compression_toolkit/core/common/framework_implementation.py,sha256=8b6M1GcUR9bDgoxwqyNP8C6KSU9OTQ5hIk20Y74eLPo,20896
|
|
14
14
|
model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
|
|
15
15
|
model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
|
|
16
16
|
model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
|
|
@@ -31,7 +31,7 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
|
|
|
31
31
|
model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=lOubqpc18TslhXZijWUJQAa1c3jIB2S-M-5HK78wJPQ,5548
|
|
32
32
|
model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
|
|
33
33
|
model_compression_toolkit/core/common/graph/base_graph.py,sha256=lmIw0srKiwCvz7KWqfwKTxyQHDy3s6rWMIXzFAa1UMo,38326
|
|
34
|
-
model_compression_toolkit/core/common/graph/base_node.py,sha256=
|
|
34
|
+
model_compression_toolkit/core/common/graph/base_node.py,sha256=exvUkLDChl6YaoaQRHgSrettsgOsd18bfq01tPxXr-4,29722
|
|
35
35
|
model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
|
|
36
36
|
model_compression_toolkit/core/common/graph/functional_node.py,sha256=71_4TrCdqR_r0mtgxmAyqI05iP5YoQQGeSmDgynuzTw,3902
|
|
37
37
|
model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
|
|
@@ -64,7 +64,7 @@ model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha2
|
|
|
64
64
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=DP5tcxPtiVbSWAeoFbEp7iTwpxDBU1g7V5w7ehDG6jI,4573
|
|
65
65
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=JmHopRNpHjxnoyeqXRVO0t-DdqEOm-jOZI06w5aAl9k,7550
|
|
66
66
|
model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=TTTux4YiOnQqt-2h7Y38959XaDwNZc0eufLMx_yws5U,37578
|
|
67
|
-
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=
|
|
67
|
+
model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=DKaxU9MD97J0yYJOCkhtQUrJLD_xrp0TK7mtcZEp1oA,28940
|
|
68
68
|
model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
|
|
69
69
|
model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=KifDMbm7qkSfvSl6pcZzQ82naIXzeKL6aT-VsvWZYyc,7901
|
|
70
70
|
model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
@@ -150,7 +150,7 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
|
|
|
150
150
|
model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
|
|
151
151
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
|
152
152
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=HcHplb7IcnOTyK2p6uhp3OVG4-RV3RDo9C_4evaIzkQ,4981
|
|
153
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
|
153
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=bRH39d4lW7Ngm8xi7v9JQd9gNfGlB_lb-bolbzTYUcc,29881
|
|
154
154
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
|
155
155
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=HUmzEXDQ8LGX7uOYSRiLZ2TNbYxLX9J9IeAa6QYlifg,3927
|
|
156
156
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=Xmk2ZL5CaYdb7iG62HdtZ1F64vap7ffnrsuR3e3G5hc,4851
|
|
@@ -213,7 +213,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
|
|
|
213
213
|
model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
|
|
214
214
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
|
|
215
215
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
|
216
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
|
216
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=Qe0GCbXsq8hqheMwZaZGl5caWK59RY4ldL5aJWcCmQ8,27516
|
|
217
217
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
|
|
218
218
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
|
|
219
219
|
model_compression_toolkit/core/pytorch/utils.py,sha256=OT_mrNEJqPgWLdtQuivKMQVjtJY49cmoIVvbRhANl1w,3004
|
|
@@ -222,7 +222,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
|
|
|
222
222
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
|
223
223
|
model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
|
|
224
224
|
model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=D7lU1r9Uq_7fdNuKk2BMF8ho5GrsY-8gyGN6yYoHaVg,15060
|
|
225
|
-
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=
|
|
225
|
+
model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=iswwKSTVGJKkYDBiVzs5L0sw2zYax11UfInbelkgU1k,18258
|
|
226
226
|
model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
|
|
227
227
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
228
228
|
model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
|
|
@@ -483,8 +483,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
483
483
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
484
484
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
485
485
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=MxylaVFPgN7zBiRBy6WV610EA4scLgRJFbMucKvvNDU,2896
|
|
486
|
-
mct_nightly-2.0.0.
|
|
487
|
-
mct_nightly-2.0.0.
|
|
488
|
-
mct_nightly-2.0.0.
|
|
489
|
-
mct_nightly-2.0.0.
|
|
490
|
-
mct_nightly-2.0.0.
|
|
486
|
+
mct_nightly-2.0.0.20240522.172031.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
487
|
+
mct_nightly-2.0.0.20240522.172031.dist-info/METADATA,sha256=0wVsPZqNQDo_t80clcCNqoXhLcKDeUuhj8TciICuq0s,19724
|
|
488
|
+
mct_nightly-2.0.0.20240522.172031.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
489
|
+
mct_nightly-2.0.0.20240522.172031.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
490
|
+
mct_nightly-2.0.0.20240522.172031.dist-info/RECORD,,
|
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
|
27
27
|
from model_compression_toolkit import pruning
|
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
29
29
|
|
|
30
|
-
__version__ = "2.0.0.
|
|
30
|
+
__version__ = "2.0.0.20240522.172031"
|
|
@@ -348,13 +348,14 @@ class FrameworkImplementation(ABC):
|
|
|
348
348
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
349
349
|
f'framework\'s count_node_for_mixed_precision_interest_points method.') # pragma: no cover
|
|
350
350
|
|
|
351
|
-
def
|
|
351
|
+
def get_mp_node_distance_fn(self, layer_class: type,
|
|
352
352
|
framework_attrs: Dict[str, Any],
|
|
353
353
|
compute_distance_fn: Callable = None,
|
|
354
|
-
axis: int = None
|
|
354
|
+
axis: int = None,
|
|
355
|
+
norm_mse: bool = False) -> Callable:
|
|
355
356
|
"""
|
|
356
357
|
A mapping between layers' types and a distance function for computing the distance between
|
|
357
|
-
two tensors (for loss computation purposes). Returns a specific function if node of specific types is
|
|
358
|
+
two tensors in mixed precision (for loss computation purposes). Returns a specific function if node of specific types is
|
|
358
359
|
given, or a default (normalized MSE) function otherwise.
|
|
359
360
|
|
|
360
361
|
Args:
|
|
@@ -362,12 +363,13 @@ class FrameworkImplementation(ABC):
|
|
|
362
363
|
framework_attrs: Framework attributes the layer had which the graph node holds.
|
|
363
364
|
compute_distance_fn: An optional distance function to use globally for all nodes.
|
|
364
365
|
axis: The axis on which the operation is preformed (if specified).
|
|
366
|
+
norm_mse: whether to normalize mse distance function.
|
|
365
367
|
|
|
366
368
|
Returns: A distance function between two tensors.
|
|
367
369
|
"""
|
|
368
370
|
|
|
369
371
|
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
|
|
370
|
-
f'framework\'s
|
|
372
|
+
f'framework\'s get_mp_node_distance_fn method.') # pragma: no cover
|
|
371
373
|
|
|
372
374
|
|
|
373
375
|
@abstractmethod
|
|
@@ -238,8 +238,12 @@ class BaseNode:
|
|
|
238
238
|
"""
|
|
239
239
|
for pos, weight in sorted((pos, weight) for pos, weight in self.weights.items()
|
|
240
240
|
if isinstance(pos, int)):
|
|
241
|
-
|
|
242
|
-
|
|
241
|
+
if pos > len(input_tensors):
|
|
242
|
+
Logger.critical("The positional weight index cannot exceed the number of input tensors to the node.") # pragma: no cover
|
|
243
|
+
# Insert only positional weights that are not subject to quantization. If the positional weight is
|
|
244
|
+
# subject to quantization, the quantization wrapper inserts the positional weight into the node.
|
|
245
|
+
if not self.is_weights_quantization_enabled(pos):
|
|
246
|
+
input_tensors.insert(pos, weight)
|
|
243
247
|
|
|
244
248
|
return input_tensors
|
|
245
249
|
|
|
@@ -89,10 +89,13 @@ class SensitivityEvaluation:
|
|
|
89
89
|
fw_impl.count_node_for_mixed_precision_interest_points,
|
|
90
90
|
quant_config.num_interest_points_factor)
|
|
91
91
|
|
|
92
|
-
|
|
92
|
+
# We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
|
|
93
|
+
# beacause hessian weights already do normalization.
|
|
94
|
+
use_normalized_mse = self.quant_config.use_hessian_based_scores is False
|
|
95
|
+
self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)
|
|
93
96
|
|
|
94
97
|
self.output_points = get_output_nodes_for_metric(graph)
|
|
95
|
-
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points)
|
|
98
|
+
self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points, use_normalized_mse)
|
|
96
99
|
|
|
97
100
|
# Setting lists with relative position of the interest points
|
|
98
101
|
# and output points in the list of all mp model activation tensors
|
|
@@ -128,7 +131,7 @@ class SensitivityEvaluation:
|
|
|
128
131
|
self.interest_points_hessians = self._compute_hessian_based_scores()
|
|
129
132
|
self.quant_config.distance_weighting_method = lambda d: self.interest_points_hessians
|
|
130
133
|
|
|
131
|
-
def _init_metric_points_lists(self, points: List[BaseNode]) -> Tuple[List[Callable], List[int]]:
|
|
134
|
+
def _init_metric_points_lists(self, points: List[BaseNode], norm_mse: bool = False) -> Tuple[List[Callable], List[int]]:
|
|
132
135
|
"""
|
|
133
136
|
Initiates required lists for future use when computing the sensitivity metric.
|
|
134
137
|
Each point on which the metric is computed uses a dedicated distance function based on its type.
|
|
@@ -136,6 +139,7 @@ class SensitivityEvaluation:
|
|
|
136
139
|
|
|
137
140
|
Args:
|
|
138
141
|
points: The set of nodes in the graph for which we need to initiate the lists.
|
|
142
|
+
norm_mse: whether to normalize mse distance function.
|
|
139
143
|
|
|
140
144
|
Returns: A lists with distance functions and an axis list for each node.
|
|
141
145
|
|
|
@@ -144,11 +148,12 @@ class SensitivityEvaluation:
|
|
|
144
148
|
axis_list = []
|
|
145
149
|
for n in points:
|
|
146
150
|
axis = n.framework_attr.get(AXIS) if not isinstance(n, FunctionalNode) else n.op_call_kwargs.get(AXIS)
|
|
147
|
-
distance_fn = self.fw_impl.
|
|
151
|
+
distance_fn = self.fw_impl.get_mp_node_distance_fn(
|
|
148
152
|
layer_class=n.layer_class,
|
|
149
153
|
framework_attrs=n.framework_attr,
|
|
150
154
|
compute_distance_fn=self.quant_config.compute_distance_fn,
|
|
151
|
-
axis=axis
|
|
155
|
+
axis=axis,
|
|
156
|
+
norm_mse=norm_mse)
|
|
152
157
|
distance_fns_list.append(distance_fn)
|
|
153
158
|
# Axis is needed only for KL Divergence calculation, otherwise we use per-tensor computation
|
|
154
159
|
axis_list.append(axis if distance_fn==compute_kl_divergence else None)
|
|
@@ -421,13 +421,14 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
421
421
|
|
|
422
422
|
return False
|
|
423
423
|
|
|
424
|
-
def
|
|
424
|
+
def get_mp_node_distance_fn(self, layer_class: type,
|
|
425
425
|
framework_attrs: Dict[str, Any],
|
|
426
426
|
compute_distance_fn: Callable = None,
|
|
427
|
-
axis: int = None
|
|
427
|
+
axis: int = None,
|
|
428
|
+
norm_mse: bool = False) -> Callable:
|
|
428
429
|
"""
|
|
429
430
|
A mapping between layers' types and a distance function for computing the distance between
|
|
430
|
-
two tensors (for loss computation purposes). Returns a specific function if node of specific types is
|
|
431
|
+
two tensors in mixed precision (for loss computation purposes). Returns a specific function if node of specific types is
|
|
431
432
|
given, or a default (normalized MSE) function otherwise.
|
|
432
433
|
|
|
433
434
|
Args:
|
|
@@ -435,6 +436,7 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
435
436
|
framework_attrs: Framework attributes the layer had which the graph node holds.
|
|
436
437
|
compute_distance_fn: An optional distance function to use globally for all nodes.
|
|
437
438
|
axis: The axis on which the operation is preformed (if specified).
|
|
439
|
+
norm_mse: whether to normalize mse distance function.
|
|
438
440
|
|
|
439
441
|
Returns: A distance function between two tensors.
|
|
440
442
|
"""
|
|
@@ -456,7 +458,7 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
456
458
|
return compute_cs
|
|
457
459
|
elif layer_class == Dense:
|
|
458
460
|
return compute_cs
|
|
459
|
-
return compute_mse
|
|
461
|
+
return partial(compute_mse, norm=norm_mse)
|
|
460
462
|
|
|
461
463
|
def get_trace_hessian_calculator(self,
|
|
462
464
|
graph: Graph,
|
|
@@ -67,8 +67,7 @@ def _build_input_tensors_list(node: BaseNode,
|
|
|
67
67
|
_input_tensors = node_to_output_tensors_dict[ie.source_node]
|
|
68
68
|
input_tensors.append(_input_tensors)
|
|
69
69
|
input_tensors = [tensor for tensor_list in input_tensors for tensor in tensor_list] # flat list of lists
|
|
70
|
-
|
|
71
|
-
input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
|
|
70
|
+
input_tensors = node.insert_positional_weights_to_input_list(input_tensors)
|
|
72
71
|
# convert inputs from positional weights (numpy arrays) to tensors. Must handle each element in the
|
|
73
72
|
# list separately, because in FX the tensors are FX objects and fail to_torch_tensor
|
|
74
73
|
input_tensors = [to_torch_tensor(t, numpy_type=t.dtype) if isinstance(t, np.ndarray) else t
|
|
@@ -403,13 +403,14 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
403
403
|
return True
|
|
404
404
|
return False
|
|
405
405
|
|
|
406
|
-
def
|
|
406
|
+
def get_mp_node_distance_fn(self, layer_class: type,
|
|
407
407
|
framework_attrs: Dict[str, Any],
|
|
408
408
|
compute_distance_fn: Callable = None,
|
|
409
|
-
axis: int = None
|
|
409
|
+
axis: int = None,
|
|
410
|
+
norm_mse: bool = False) -> Callable:
|
|
410
411
|
"""
|
|
411
412
|
A mapping between layers' types and a distance function for computing the distance between
|
|
412
|
-
two tensors (for loss computation purposes). Returns a specific function if node of specific types is
|
|
413
|
+
two tensors in mixed precision (for loss computation purposes). Returns a specific function if node of specific types is
|
|
413
414
|
given, or a default (normalized MSE) function otherwise.
|
|
414
415
|
|
|
415
416
|
Args:
|
|
@@ -417,6 +418,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
417
418
|
framework_attrs: Framework attributes the layer had which the graph node holds.
|
|
418
419
|
compute_distance_fn: An optional distance function to use globally for all nodes.
|
|
419
420
|
axis: The axis on which the operation is preformed (if specified).
|
|
421
|
+
norm_mse: whether to normalize mse distance function.
|
|
420
422
|
|
|
421
423
|
Returns: A distance function between two tensors.
|
|
422
424
|
"""
|
|
@@ -430,7 +432,7 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
430
432
|
return compute_cs
|
|
431
433
|
elif layer_class == Linear:
|
|
432
434
|
return compute_cs
|
|
433
|
-
return compute_mse
|
|
435
|
+
return partial(compute_mse, norm=norm_mse)
|
|
434
436
|
|
|
435
437
|
def is_output_node_compatible_for_hessian_score_computation(self,
|
|
436
438
|
node: BaseNode) -> bool:
|
|
File without changes
|
{mct_nightly-2.0.0.20240521.151450.dist-info → mct_nightly-2.0.0.20240522.172031.dist-info}/WHEEL
RENAMED
|
File without changes
|
|
File without changes
|