mct-nightly 2.1.0.20240725.446__py3-none-any.whl → 2.1.0.20240727.431__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/METADATA +1 -1
- {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/RECORD +35 -31
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/pytorch/constants.py +6 -1
- model_compression_toolkit/core/pytorch/utils.py +27 -0
- model_compression_toolkit/data_generation/common/data_generation.py +20 -18
- model_compression_toolkit/data_generation/common/data_generation_config.py +8 -11
- model_compression_toolkit/data_generation/common/enums.py +24 -12
- model_compression_toolkit/data_generation/common/image_pipeline.py +50 -12
- model_compression_toolkit/data_generation/common/model_info_exctractors.py +0 -8
- model_compression_toolkit/data_generation/common/optimization_utils.py +7 -11
- model_compression_toolkit/data_generation/keras/constants.py +5 -2
- model_compression_toolkit/data_generation/keras/image_operations.py +189 -0
- model_compression_toolkit/data_generation/keras/image_pipeline.py +50 -104
- model_compression_toolkit/data_generation/keras/keras_data_generation.py +28 -36
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py +0 -13
- model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py +16 -6
- model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +219 -0
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +39 -13
- model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py +6 -98
- model_compression_toolkit/data_generation/keras/optimization_utils.py +15 -28
- model_compression_toolkit/data_generation/pytorch/constants.py +4 -1
- model_compression_toolkit/data_generation/pytorch/image_operations.py +105 -0
- model_compression_toolkit/data_generation/pytorch/image_pipeline.py +70 -78
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +0 -10
- model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py +17 -6
- model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +2 -2
- model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +219 -0
- model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py +55 -21
- model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py +15 -0
- model_compression_toolkit/data_generation/pytorch/optimization_utils.py +32 -54
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +57 -52
- {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/WHEEL +0 -0
- {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/top_level.txt +0 -0
{mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
1
|
+
model_compression_toolkit/__init__.py,sha256=dNSTIwKIETFrRFbKVtmCp7sMOFv7eHAfRFVw0joUkik,1573
|
2
2
|
model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
@@ -210,13 +210,13 @@ model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9H
|
|
210
210
|
model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
|
211
211
|
model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
|
212
212
|
model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
213
|
-
model_compression_toolkit/core/pytorch/constants.py,sha256=
|
213
|
+
model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
|
214
214
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
|
215
215
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
|
216
216
|
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=XL_RZcfnb_ZY2jdCjOxxz7SbRBzMokbOWsTuYOSjyRU,27569
|
217
217
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
|
218
218
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
|
219
|
-
model_compression_toolkit/core/pytorch/utils.py,sha256=
|
219
|
+
model_compression_toolkit/core/pytorch/utils.py,sha256=GE7T8q93I5C4As0iOias_dk9HpOvXM1N6---dJlyD60,3863
|
220
220
|
model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
|
221
221
|
model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=DwNO8WO3JiMawKGKDhlrwCoCjMSBIw5BMbsFFF7eDS4,2279
|
222
222
|
model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
|
@@ -269,36 +269,40 @@ model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment
|
|
269
269
|
model_compression_toolkit/data_generation/__init__.py,sha256=S8pRUqlRvpM5AFHpFWs3zb0H0rtY5nUwmeCQij01oi4,1507
|
270
270
|
model_compression_toolkit/data_generation/common/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
271
271
|
model_compression_toolkit/data_generation/common/constants.py,sha256=21e3ZX9WVYojexG2acTgklrBk8ZO9DjJnKpP4KHZC44,1018
|
272
|
-
model_compression_toolkit/data_generation/common/data_generation.py,sha256=
|
273
|
-
model_compression_toolkit/data_generation/common/data_generation_config.py,sha256=
|
274
|
-
model_compression_toolkit/data_generation/common/enums.py,sha256=
|
275
|
-
model_compression_toolkit/data_generation/common/image_pipeline.py,sha256=
|
276
|
-
model_compression_toolkit/data_generation/common/model_info_exctractors.py,sha256=
|
277
|
-
model_compression_toolkit/data_generation/common/optimization_utils.py,sha256=
|
272
|
+
model_compression_toolkit/data_generation/common/data_generation.py,sha256=W8PeOcL1fBVB1WgXSCrEw-G7AWa6MNzjTqcFbmMhrGE,6687
|
273
|
+
model_compression_toolkit/data_generation/common/data_generation_config.py,sha256=yKqSDJGdbnc9HEmg94sPqMSXGR2OmAzt5X5MQcy_YX8,4473
|
274
|
+
model_compression_toolkit/data_generation/common/enums.py,sha256=V5qAaqMg2WFhsrJ11rTDcRWBhbsxhEHt3uwRq6cesNo,4249
|
275
|
+
model_compression_toolkit/data_generation/common/image_pipeline.py,sha256=PfunQMxYm6KqJUEUVYhtY7-JTq4J-XTyLc1HOalP15s,4761
|
276
|
+
model_compression_toolkit/data_generation/common/model_info_exctractors.py,sha256=CqruljgQ564SMRQtxgYYDWKM7HYDz18MCShNgrRYQKg,5933
|
277
|
+
model_compression_toolkit/data_generation/common/optimization_utils.py,sha256=aEDSclZ2TvIIqN1x9CLf8MBe2GA3m1aEXtbd5Sgcd8k,19528
|
278
278
|
model_compression_toolkit/data_generation/keras/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
279
|
-
model_compression_toolkit/data_generation/keras/constants.py,sha256=
|
280
|
-
model_compression_toolkit/data_generation/keras/
|
281
|
-
model_compression_toolkit/data_generation/keras/
|
282
|
-
model_compression_toolkit/data_generation/keras/
|
283
|
-
model_compression_toolkit/data_generation/keras/
|
279
|
+
model_compression_toolkit/data_generation/keras/constants.py,sha256=sxhhGHC-INBs1nVXhyokbFi9ob4jPkSRviuc83JRsgQ,1152
|
280
|
+
model_compression_toolkit/data_generation/keras/image_operations.py,sha256=OtJ5Yz8BZVOnGqyTHwlseRe4EmoLDYxz3bblGtw6HnY,6233
|
281
|
+
model_compression_toolkit/data_generation/keras/image_pipeline.py,sha256=E-HVverorhq33xzteuwUPtOrGDIYoEEs4fZJgiqOAzQ,7043
|
282
|
+
model_compression_toolkit/data_generation/keras/keras_data_generation.py,sha256=IMnmUn7fUsMcJ980FZWuX36iUYXAEYxdYk8oXwz-Xd8,21207
|
283
|
+
model_compression_toolkit/data_generation/keras/model_info_exctractors.py,sha256=1E5xbn0P3py4EYjdpPD9JwGr4jlc3qe1ml1py0t40b8,8026
|
284
|
+
model_compression_toolkit/data_generation/keras/optimization_utils.py,sha256=cHv2tl-_9_D14mWqzNYtKFY8q7sJfW_V__dpZqzRvIo,20546
|
284
285
|
model_compression_toolkit/data_generation/keras/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
285
286
|
model_compression_toolkit/data_generation/keras/optimization_functions/batchnorm_alignment_functions.py,sha256=f5M7KoISGnb6S6zR7SyQ9dYmQctW9iYRi0Bda1BLq70,1983
|
286
|
-
model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py,sha256=
|
287
|
+
model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py,sha256=xQWTeP-Im6xEUupF-VEjZq-UsRNzpoW0LuMHFR2cX9Q,3390
|
287
288
|
model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py,sha256=sjSPLLFLjJ6d0DDSaxnCE0ydIT1zhL8H73QTXEuUfgw,4119
|
288
|
-
model_compression_toolkit/data_generation/keras/optimization_functions/
|
289
|
-
model_compression_toolkit/data_generation/keras/optimization_functions/
|
289
|
+
model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py,sha256=xCc7GlmW-jpf27P8mI2APpAL8LC6zmD1BMbN7Q6wzEE,8647
|
290
|
+
model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py,sha256=vr_H1dbFINS7LBX_SfW59g0C8ie9grAyOIpCKuPoI1w,6384
|
291
|
+
model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py,sha256=9RhNWtw_cdDlGqEGEdn1JWwvfA8V-Z6ioZn1ppdHFmA,1695
|
290
292
|
model_compression_toolkit/data_generation/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
291
|
-
model_compression_toolkit/data_generation/pytorch/constants.py,sha256=
|
292
|
-
model_compression_toolkit/data_generation/pytorch/
|
293
|
-
model_compression_toolkit/data_generation/pytorch/
|
294
|
-
model_compression_toolkit/data_generation/pytorch/
|
295
|
-
model_compression_toolkit/data_generation/pytorch/
|
293
|
+
model_compression_toolkit/data_generation/pytorch/constants.py,sha256=ZiyweWj2Bnk6duhcV4zowsPvqLdON-AlLhkAuLmCqxg,1256
|
294
|
+
model_compression_toolkit/data_generation/pytorch/image_operations.py,sha256=KUQKOj5G4UPGX9f9PSiLRlBo4e3rRRPec88wkozNgqw,3900
|
295
|
+
model_compression_toolkit/data_generation/pytorch/image_pipeline.py,sha256=dcQr-67u9-ggGuS39YAvR7z-Y0NOdJintcVQ5vy1bM8,7478
|
296
|
+
model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py,sha256=y6vMed6lQQj67-BXZKrAcWUNTkH8YjiUhknOV4wSpRA,9399
|
297
|
+
model_compression_toolkit/data_generation/pytorch/optimization_utils.py,sha256=vRMeUEdInPuJisiO-SKo_9miWZV90sz8GCg5MY0AqiU,18098
|
298
|
+
model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=OjdAG0uGdbN0ATMrkRskhEttkUgSXN8KCVd8JXKiwxk,21620
|
296
299
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
297
300
|
model_compression_toolkit/data_generation/pytorch/optimization_functions/batchnorm_alignment_functions.py,sha256=dMc4zz9XfYfAT4Cxns57VgvGZWPAMfaGlWLFyCyl8TA,1968
|
298
|
-
model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py,sha256=
|
299
|
-
model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py,sha256=
|
300
|
-
model_compression_toolkit/data_generation/pytorch/optimization_functions/
|
301
|
-
model_compression_toolkit/data_generation/pytorch/optimization_functions/
|
301
|
+
model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py,sha256=We0fVMQ4oU7Y0IWQ8fKy8KpqkIiLyKoQeF9XKAQ6TH0,3317
|
302
|
+
model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py,sha256=hhWSZ7w45dE5SQ6jM27cBkCSJObWkALs_RpD6afPi68,4753
|
303
|
+
model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py,sha256=NydGxFIclmrfU3HWYUrRbprg4hPt470QP6MTOMLEhRs,9172
|
304
|
+
model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py,sha256=PRVmn8o2hTdwTdbd2ezf__LNbFvcgiVO0c25dsyg3Tg,6549
|
305
|
+
model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py,sha256=zMjY2y4FSHonuY5hddbMTb8qAQtLtohYF7q1wuruDDs,3267
|
302
306
|
model_compression_toolkit/exporter/__init__.py,sha256=Eg3c4EAjW3g6h13A-Utgf9ncHrTMRHAoySNDQGPDZ4E,1301
|
303
307
|
model_compression_toolkit/exporter/model_exporter/__init__.py,sha256=9HIBmj8ROdCA-yvkpA8EcN6RHJe_2vEpLLW_gxOJtak,698
|
304
308
|
model_compression_toolkit/exporter/model_exporter/fw_agonstic/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
@@ -517,8 +521,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
|
|
517
521
|
model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
|
518
522
|
model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
|
519
523
|
model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
|
520
|
-
mct_nightly-2.1.0.
|
521
|
-
mct_nightly-2.1.0.
|
522
|
-
mct_nightly-2.1.0.
|
523
|
-
mct_nightly-2.1.0.
|
524
|
-
mct_nightly-2.1.0.
|
524
|
+
mct_nightly-2.1.0.20240727.431.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
525
|
+
mct_nightly-2.1.0.20240727.431.dist-info/METADATA,sha256=g2Z7XHdZo_A9_vL0R9IJv5kcLIsZg_ONvFBnuOFqFkk,19719
|
526
|
+
mct_nightly-2.1.0.20240727.431.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
527
|
+
mct_nightly-2.1.0.20240727.431.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
528
|
+
mct_nightly-2.1.0.20240727.431.dist-info/RECORD,,
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
27
27
|
from model_compression_toolkit import pruning
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
29
29
|
|
30
|
-
__version__ = "2.1.0.
|
30
|
+
__version__ = "2.1.0.20240727.000431"
|
@@ -12,7 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
15
|
+
import torch
|
16
16
|
|
17
17
|
# # Layer type constants:
|
18
18
|
PLACEHOLDER = 'placeholder'
|
@@ -94,3 +94,8 @@ BIAS_V = 'bias_v'
|
|
94
94
|
# # Batch size value for 'reshape' and 'view' operators,
|
95
95
|
# # the value is -1 so the batch size is inferred from the length of the array and remaining dimensions.
|
96
96
|
BATCH_DIM_VALUE = -1
|
97
|
+
|
98
|
+
|
99
|
+
# The maximum and minimum representable values for float16
|
100
|
+
MAX_FLOAT16 = torch.finfo(torch.float16).max - 1
|
101
|
+
MIN_FLOAT16 = torch.finfo(torch.float16).min - 1
|
@@ -13,8 +13,11 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import torch
|
16
|
+
from torch import Tensor
|
16
17
|
import numpy as np
|
17
18
|
from typing import Union
|
19
|
+
|
20
|
+
from model_compression_toolkit.core.pytorch.constants import MAX_FLOAT16, MIN_FLOAT16
|
18
21
|
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
|
19
22
|
from model_compression_toolkit.logger import Logger
|
20
23
|
|
@@ -83,3 +86,27 @@ def torch_tensor_to_numpy(tensor: Union[torch.Tensor, list, tuple]) -> Union[np.
|
|
83
86
|
return tensor.cpu().detach().contiguous().numpy()
|
84
87
|
else:
|
85
88
|
Logger.critical(f'Unsupported type for conversion to Numpy array: {type(tensor)}.')
|
89
|
+
|
90
|
+
|
91
|
+
def clip_inf_values_float16(tensor: Tensor) -> Tensor:
|
92
|
+
"""
|
93
|
+
Clips +inf and -inf values in a float16 tensor to the maximum and minimum representable values.
|
94
|
+
|
95
|
+
Parameters:
|
96
|
+
tensor (Tensor): Input PyTorch tensor of dtype float16.
|
97
|
+
|
98
|
+
Returns:
|
99
|
+
Tensor: A tensor with +inf values replaced by the maximum float16 value,
|
100
|
+
and -inf values replaced by the minimum float16 value.
|
101
|
+
"""
|
102
|
+
# Check if the tensor is of dtype float16
|
103
|
+
if tensor.dtype != torch.float16:
|
104
|
+
return tensor
|
105
|
+
|
106
|
+
# Create a mask for inf values (both positive and negative)
|
107
|
+
inf_mask = torch.isinf(tensor)
|
108
|
+
|
109
|
+
# Replace inf values with max float16 value
|
110
|
+
tensor[inf_mask] = MAX_FLOAT16 * torch.sign(tensor[inf_mask])
|
111
|
+
|
112
|
+
return tensor
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
# Import required modules and classes
|
16
|
-
from typing import Any, Tuple, Dict, Callable, List
|
16
|
+
from typing import Any, Tuple, Dict, Callable, List, Union
|
17
17
|
|
18
18
|
from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
|
19
19
|
from model_compression_toolkit.data_generation.common.enums import ImagePipelineType, ImageNormalizationType, \
|
@@ -24,7 +24,7 @@ from model_compression_toolkit.logger import Logger
|
|
24
24
|
|
25
25
|
def get_data_generation_classes(
|
26
26
|
data_generation_config: DataGenerationConfig,
|
27
|
-
output_image_size: Tuple,
|
27
|
+
output_image_size: Union[int, Tuple[int, int]],
|
28
28
|
n_images: int,
|
29
29
|
image_pipeline_dict: Dict,
|
30
30
|
image_normalization_dict: Dict,
|
@@ -38,7 +38,7 @@ def get_data_generation_classes(
|
|
38
38
|
|
39
39
|
Args:
|
40
40
|
data_generation_config (DataGenerationConfig): Configuration for data generation.
|
41
|
-
output_image_size (Tuple): The desired output image size.
|
41
|
+
output_image_size (Union[int, Tuple[int, int]]): The desired output image size.
|
42
42
|
n_images (int): The number of random samples.
|
43
43
|
image_pipeline_dict (Dict): Dictionary mapping ImagePipelineType to corresponding image pipeline classes.
|
44
44
|
image_normalization_dict (Dict): Dictionary mapping ImageNormalizationType to corresponding
|
@@ -56,26 +56,28 @@ def get_data_generation_classes(
|
|
56
56
|
output_loss_fn (Callable): Function to compute output loss.
|
57
57
|
init_dataset (Any): The initial dataset used for image generation.
|
58
58
|
"""
|
59
|
+
# Get the normalization values corresponding to the specified type
|
60
|
+
normalization = image_normalization_dict.get(data_generation_config.image_normalization_type)
|
61
|
+
|
62
|
+
# Check if the image normalization type is valid
|
63
|
+
if normalization is None:
|
64
|
+
Logger.critical(
|
65
|
+
f'Invalid image_normalization_type {data_generation_config.image_normalization_type}. '
|
66
|
+
f'Please select one from {ImageNormalizationType.get_values()}.') # pragma: no cover
|
67
|
+
|
59
68
|
# Get the image pipeline class corresponding to the specified type
|
60
69
|
image_pipeline = (
|
61
70
|
image_pipeline_dict.get(data_generation_config.image_pipeline_type)(
|
62
71
|
output_image_size=output_image_size,
|
63
|
-
extra_pixels=data_generation_config.extra_pixels
|
72
|
+
extra_pixels=data_generation_config.extra_pixels,
|
73
|
+
image_clipping=data_generation_config.image_clipping,
|
74
|
+
normalization=normalization))
|
64
75
|
|
65
76
|
# Check if the image pipeline type is valid
|
66
77
|
if image_pipeline is None:
|
67
78
|
Logger.critical(
|
68
79
|
f'Invalid image_pipeline_type {data_generation_config.image_pipeline_type}. '
|
69
|
-
f'Please select one from {ImagePipelineType.get_values()}.')
|
70
|
-
|
71
|
-
# Get the normalization values corresponding to the specified type
|
72
|
-
normalization = image_normalization_dict.get(data_generation_config.image_normalization_type)
|
73
|
-
|
74
|
-
# Check if the image normalization type is valid
|
75
|
-
if normalization is None:
|
76
|
-
Logger.critical(
|
77
|
-
f'Invalid image_normalization_type {data_generation_config.image_normalization_type}. '
|
78
|
-
f'Please select one from {ImageNormalizationType.get_values()}.')
|
80
|
+
f'Please select one from {ImagePipelineType.get_values()}.') # pragma: no cover
|
79
81
|
|
80
82
|
# Get the layer weighting function corresponding to the specified type
|
81
83
|
bn_layer_weighting_fn = bn_layer_weighting_function_dict.get(data_generation_config.layer_weighting_type)
|
@@ -83,7 +85,7 @@ def get_data_generation_classes(
|
|
83
85
|
if bn_layer_weighting_fn is None:
|
84
86
|
Logger.critical(
|
85
87
|
f'Invalid layer_weighting_type {data_generation_config.layer_weighting_type}. '
|
86
|
-
f'Please select one from {BNLayerWeightingType.get_values()}.')
|
88
|
+
f'Please select one from {BNLayerWeightingType.get_values()}.') # pragma: no cover
|
87
89
|
|
88
90
|
# Get the image initialization function corresponding to the specified type
|
89
91
|
image_initialization_fn = image_initialization_function_dict.get(data_generation_config.data_init_type)
|
@@ -92,7 +94,7 @@ def get_data_generation_classes(
|
|
92
94
|
if image_initialization_fn is None:
|
93
95
|
Logger.critical(
|
94
96
|
f'Invalid data_init_type {data_generation_config.data_init_type}. '
|
95
|
-
f'Please select one from {DataInitType.get_values()}.')
|
97
|
+
f'Please select one from {DataInitType.get_values()}.') # pragma: no cover
|
96
98
|
|
97
99
|
# Get the BatchNorm alignment loss function corresponding to the specified type
|
98
100
|
bn_alignment_loss_fn = bn_alignment_loss_function_dict.get(data_generation_config.bn_alignment_loss_type)
|
@@ -101,7 +103,7 @@ def get_data_generation_classes(
|
|
101
103
|
if bn_alignment_loss_fn is None:
|
102
104
|
Logger.critical(
|
103
105
|
f'Invalid bn_alignment_loss_type {data_generation_config.bn_alignment_loss_type}. '
|
104
|
-
f'Please select one from {BatchNormAlignemntLossType.get_values()}.')
|
106
|
+
f'Please select one from {BatchNormAlignemntLossType.get_values()}.') # pragma: no cover
|
105
107
|
|
106
108
|
# Get the output loss function corresponding to the specified type
|
107
109
|
output_loss_fn = output_loss_function_dict.get(data_generation_config.output_loss_type)
|
@@ -110,7 +112,7 @@ def get_data_generation_classes(
|
|
110
112
|
if output_loss_fn is None:
|
111
113
|
Logger.critical(
|
112
114
|
f'Invalid output_loss_type {data_generation_config.output_loss_type}. '
|
113
|
-
f'Please select one from {OutputLossType.get_values()}.')
|
115
|
+
f'Please select one from {OutputLossType.get_values()}.') # pragma: no cover
|
114
116
|
|
115
117
|
# Initialize the dataset for data generation
|
116
118
|
init_dataset = image_initialization_fn(
|
@@ -12,12 +12,11 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
from typing import
|
15
|
+
from typing import Any, List, Tuple, Union
|
16
16
|
|
17
17
|
from model_compression_toolkit.data_generation.common.enums import SchedulerType, BatchNormAlignemntLossType, \
|
18
18
|
DataInitType, BNLayerWeightingType, ImageGranularity, ImagePipelineType, ImageNormalizationType, OutputLossType
|
19
19
|
|
20
|
-
|
21
20
|
class DataGenerationConfig:
|
22
21
|
"""
|
23
22
|
Configuration class for data generation.
|
@@ -36,11 +35,10 @@ class DataGenerationConfig:
|
|
36
35
|
layer_weighting_type: BNLayerWeightingType = None,
|
37
36
|
image_pipeline_type: ImagePipelineType = None,
|
38
37
|
image_normalization_type: ImageNormalizationType = None,
|
39
|
-
extra_pixels: int = 0,
|
38
|
+
extra_pixels: Union[int, Tuple[int, int]] = 0,
|
40
39
|
bn_layer_types: List = [],
|
41
40
|
last_layer_types: List = [],
|
42
|
-
|
43
|
-
reflection: bool = True,
|
41
|
+
image_clipping: bool = True,
|
44
42
|
):
|
45
43
|
"""
|
46
44
|
Initialize the DataGenerationConfig.
|
@@ -59,17 +57,15 @@ class DataGenerationConfig:
|
|
59
57
|
layer_weighting_type (BNLayerWeightingType): Type of layer weighting. Defaults to None.
|
60
58
|
image_pipeline_type (ImagePipelineType): Type of image pipeline. Defaults to None.
|
61
59
|
image_normalization_type (ImageNormalizationType): Type of image normalization. Defaults to None.
|
62
|
-
extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
|
60
|
+
extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
|
63
61
|
bn_layer_types (List): List of BatchNorm layer types. Defaults to [].
|
64
62
|
last_layer_types (List): List of layer types. Defaults to [].
|
65
|
-
|
66
|
-
reflection (bool): Flag to enable reflection. Defaults to True.
|
63
|
+
image_clipping (bool): Flag to enable image clipping. Defaults to True.
|
67
64
|
"""
|
68
65
|
self.n_iter = n_iter
|
69
66
|
self.optimizer = optimizer
|
70
67
|
self.data_gen_batch_size = data_gen_batch_size
|
71
68
|
self.initial_lr = initial_lr
|
72
|
-
self.output_loss_multiplier = output_loss_multiplier
|
73
69
|
self.image_granularity = image_granularity
|
74
70
|
self.scheduler_type = scheduler_type
|
75
71
|
self.bn_alignment_loss_type = bn_alignment_loss_type
|
@@ -81,6 +77,7 @@ class DataGenerationConfig:
|
|
81
77
|
self.layer_weighting_type = layer_weighting_type
|
82
78
|
self.bn_layer_types = bn_layer_types
|
83
79
|
self.last_layer_types = last_layer_types
|
84
|
-
self.
|
85
|
-
self.
|
80
|
+
self.image_clipping = image_clipping
|
81
|
+
self.output_loss_multiplier = output_loss_multiplier
|
82
|
+
|
86
83
|
|
@@ -16,7 +16,6 @@ from enum import Enum
|
|
16
16
|
|
17
17
|
|
18
18
|
class EnumBaseClass(Enum):
|
19
|
-
@classmethod
|
20
19
|
def get_values(cls):
|
21
20
|
"""
|
22
21
|
Get the list of values corresponding to the enum members.
|
@@ -24,7 +23,23 @@ class EnumBaseClass(Enum):
|
|
24
23
|
Returns:
|
25
24
|
List of values.
|
26
25
|
"""
|
27
|
-
return
|
26
|
+
return list(cls.__members__.values())
|
27
|
+
|
28
|
+
@classmethod
|
29
|
+
def get_enum_by_value(cls, target_value):
|
30
|
+
"""
|
31
|
+
Function to get the key corresponding to a given enum value.
|
32
|
+
|
33
|
+
Parameters:
|
34
|
+
target_value: The value to find the key for.
|
35
|
+
|
36
|
+
Returns:
|
37
|
+
The key corresponding to the given value if found, else None.
|
38
|
+
"""
|
39
|
+
for value in cls.__members__.values():
|
40
|
+
if value.value == target_value:
|
41
|
+
return value
|
42
|
+
return None
|
28
43
|
|
29
44
|
|
30
45
|
class ImageGranularity(EnumBaseClass):
|
@@ -61,15 +76,12 @@ class ImagePipelineType(EnumBaseClass):
|
|
61
76
|
"""
|
62
77
|
An enum for choosing the image pipeline type for image manipulation:
|
63
78
|
|
64
|
-
|
65
|
-
|
66
|
-
RANDOM_CROP_FLIP - Crop and flip the images.
|
79
|
+
SMOOTHING_AND_AUGMENTATION - Apply a smoothing filter, then crop and flip the images.
|
67
80
|
|
68
81
|
IDENTITY - Do not apply any manipulation (identity transformation).
|
69
82
|
|
70
83
|
"""
|
71
|
-
|
72
|
-
RANDOM_CROP_FLIP = 'random_crop_flip'
|
84
|
+
SMOOTHING_AND_AUGMENTATION = 'smoothing_and_augmentation'
|
73
85
|
IDENTITY = 'identity'
|
74
86
|
|
75
87
|
|
@@ -118,16 +130,15 @@ class BatchNormAlignemntLossType(EnumBaseClass):
|
|
118
130
|
class OutputLossType(EnumBaseClass):
|
119
131
|
"""
|
120
132
|
An enum for choosing the output loss type:
|
121
|
-
|
122
133
|
NONE - No output loss is applied.
|
123
|
-
|
124
|
-
|
125
|
-
|
134
|
+
NEGATIVE_MIN_MAX_DIFF - Use the mean of the negative min-max difference as the output loss.
|
135
|
+
INVERSE_MIN_MAX_DIFF - Use mean of the 1/(min-max) difference as the output loss.
|
126
136
|
REGULARIZED_MIN_MAX_DIFF - Use regularized min-max difference as the output loss.
|
127
137
|
|
128
138
|
"""
|
129
139
|
NONE = 'none'
|
130
|
-
|
140
|
+
NEGATIVE_MIN_MAX_DIFF = 'negative_min_max_diff'
|
141
|
+
INVERSE_MIN_MAX_DIFF = 'inverse_min_max_diff'
|
131
142
|
REGULARIZED_MIN_MAX_DIFF = 'regularized_min_max_diff'
|
132
143
|
|
133
144
|
|
@@ -141,4 +152,5 @@ class SchedulerType(EnumBaseClass):
|
|
141
152
|
|
142
153
|
"""
|
143
154
|
REDUCE_ON_PLATEAU = 'reduce_on_plateau'
|
155
|
+
REDUCE_ON_PLATEAU_WITH_RESET = 'reduce_on_plateau_with_reset'
|
144
156
|
STEP = 'step'
|
@@ -13,31 +13,61 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
from abc import ABC, abstractmethod
|
16
|
-
from typing import Any
|
16
|
+
from typing import Any, Tuple, Union, Dict, List
|
17
|
+
|
18
|
+
from model_compression_toolkit.data_generation import ImageNormalizationType
|
19
|
+
from model_compression_toolkit.logger import Logger
|
17
20
|
|
18
21
|
|
19
22
|
class BaseImagePipeline(ABC):
|
20
23
|
def __init__(self,
|
21
|
-
output_image_size: int,
|
22
|
-
extra_pixels: int = 0
|
24
|
+
output_image_size: Union[int, Tuple[int, int]],
|
25
|
+
extra_pixels: Union[int, Tuple[int, int]] = 0,
|
26
|
+
image_clipping: bool = False,
|
27
|
+
normalization: List[List[int]] = [[0, 0, 0], [1, 1, 1]]):
|
23
28
|
"""
|
24
29
|
Base class for image pipeline.
|
25
30
|
|
26
31
|
Args:
|
27
|
-
output_image_size (int): The desired output image size.
|
28
|
-
extra_pixels (int,
|
32
|
+
output_image_size (Union[int, Tuple[int, int]]): The desired output image size.
|
33
|
+
extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
|
34
|
+
image_clipping (bool): Whether to clip images during optimization.
|
35
|
+
normalization (List[List[float]]): The image normalization values for processing images during optimization.
|
29
36
|
"""
|
30
|
-
|
31
|
-
|
37
|
+
if isinstance(output_image_size, int):
|
38
|
+
self.output_image_size = (output_image_size, output_image_size)
|
39
|
+
elif isinstance(output_image_size, tuple) and len(output_image_size) == 1:
|
40
|
+
self.output_image_size = output_image_size + output_image_size # concatenate two tuples
|
41
|
+
elif isinstance(output_image_size, tuple) and len(output_image_size) == 2:
|
42
|
+
self.output_image_size = output_image_size
|
43
|
+
elif isinstance(output_image_size, tuple):
|
44
|
+
Logger.critical(f"'output_image_size' should a tuple of length 1 or 2. Got tuple of length {len(output_image_size)}") # pragma: no cover
|
45
|
+
else:
|
46
|
+
Logger.critical(f"'output_image_size' should be an int or tuple but type {type(output_image_size)} was received.") # pragma: no cover
|
47
|
+
|
48
|
+
if isinstance(extra_pixels, int):
|
49
|
+
self.extra_pixels = (extra_pixels, extra_pixels)
|
50
|
+
elif isinstance(extra_pixels, tuple) and len(extra_pixels) == 1:
|
51
|
+
self.extra_pixels = extra_pixels + extra_pixels # concatenate two tuples
|
52
|
+
elif isinstance(extra_pixels, tuple) and len(extra_pixels) == 2:
|
53
|
+
self.extra_pixels = extra_pixels
|
54
|
+
elif isinstance(extra_pixels, tuple):
|
55
|
+
Logger.critical(f"'extra_pixels' should a tuple of length 1 or 2. Got tuple of length {len(extra_pixels)}") # pragma: no cover
|
56
|
+
else:
|
57
|
+
Logger.critical(f"'extra_pixels' should be an int or tuple but type {type(extra_pixels)} was received.") # pragma: no cover
|
58
|
+
|
59
|
+
self.image_clipping = image_clipping
|
60
|
+
self.normalization = normalization
|
61
|
+
|
32
62
|
@abstractmethod
|
33
|
-
def get_image_input_size(self) -> int:
|
63
|
+
def get_image_input_size(self) -> Tuple[int, int]:
|
34
64
|
"""
|
35
65
|
Get the size of the input image for the image pipeline.
|
36
66
|
|
37
67
|
Returns:
|
38
|
-
int: The input image size.
|
68
|
+
Tuple[int, int]: The input image size.
|
39
69
|
"""
|
40
|
-
raise NotImplemented
|
70
|
+
raise NotImplemented # pragma: no cover
|
41
71
|
|
42
72
|
@abstractmethod
|
43
73
|
def image_input_manipulation(self,
|
@@ -51,7 +81,7 @@ class BaseImagePipeline(ABC):
|
|
51
81
|
Returns:
|
52
82
|
Any: Manipulated images.
|
53
83
|
"""
|
54
|
-
raise NotImplemented
|
84
|
+
raise NotImplemented # pragma: no cover
|
55
85
|
|
56
86
|
@abstractmethod
|
57
87
|
def image_output_finalize(self,
|
@@ -65,4 +95,12 @@ class BaseImagePipeline(ABC):
|
|
65
95
|
Returns:
|
66
96
|
Any: Finalized images.
|
67
97
|
"""
|
68
|
-
raise NotImplemented
|
98
|
+
raise NotImplemented # pragma: no cover
|
99
|
+
|
100
|
+
|
101
|
+
# Dictionary mapping ImageNormalizationType to corresponding normalization values
|
102
|
+
image_normalization_dict: Dict[ImageNormalizationType, List[List[float]]] = {
|
103
|
+
ImageNormalizationType.TORCHVISION: [[0.485 * 255, 0.456 * 255, 0.406 * 255], [0.229 * 255, 0.224 * 255, 0.225 * 255]],
|
104
|
+
ImageNormalizationType.KERAS_APPLICATIONS: [[127.5, 127.5, 127.5], [127.5, 127.5, 127.5]],
|
105
|
+
ImageNormalizationType.NO_NORMALIZATION: [[0, 0, 0], [1, 1, 1]]
|
106
|
+
}
|
@@ -157,14 +157,6 @@ class ActivationExtractor:
|
|
157
157
|
"""
|
158
158
|
raise NotImplemented # pragma: no cover
|
159
159
|
|
160
|
-
def get_num_extractor_layers(self) -> int:
|
161
|
-
"""
|
162
|
-
Get the number of layers for which to extract input activations.
|
163
|
-
|
164
|
-
Returns:
|
165
|
-
int: Number of layers for which to extract input activations.
|
166
|
-
"""
|
167
|
-
return self.num_layers
|
168
160
|
|
169
161
|
@abstractmethod
|
170
162
|
def get_extractor_layer_names(self) -> List:
|
@@ -48,8 +48,6 @@ class ImagesOptimizationHandler:
|
|
48
48
|
scheduler: Any,
|
49
49
|
normalization_mean: List[float],
|
50
50
|
normalization_std: List[float],
|
51
|
-
clip_images: bool,
|
52
|
-
reflection: bool,
|
53
51
|
initial_lr: float,
|
54
52
|
eps: float = 1e-6):
|
55
53
|
"""
|
@@ -67,8 +65,6 @@ class ImagesOptimizationHandler:
|
|
67
65
|
scheduler (Any): The scheduler responsible for adjusting the learning rate of the optimizer over time.
|
68
66
|
normalization_mean (List[float]): Mean values used for image normalization.
|
69
67
|
normalization_std (List[float]): Standard deviation values used for image normalization.
|
70
|
-
clip_images (bool): Flag indicating whether to clip generated images during optimization.
|
71
|
-
reflection (bool): Flag indicating whether to use reflection during image generation.
|
72
68
|
initial_lr (float): The initial learning rate used by the optimizer.
|
73
69
|
eps (float, optional): A small value added for numerical stability. Defaults to 1e-6.
|
74
70
|
|
@@ -79,8 +75,6 @@ class ImagesOptimizationHandler:
|
|
79
75
|
self.scheduler = scheduler
|
80
76
|
self.scheduler_step_fn = scheduler_step_fn
|
81
77
|
self.image_granularity = image_granularity
|
82
|
-
self.clip_images = clip_images
|
83
|
-
self.reflection = reflection
|
84
78
|
self.eps = eps
|
85
79
|
self.targets = []
|
86
80
|
self.initial_lr = initial_lr
|
@@ -209,9 +203,11 @@ class ImagesOptimizationHandler:
|
|
209
203
|
imgs_layer_mean, imgs_layer_second_moment, imgs_layer_std = self.all_imgs_stats_holder.get_stats(
|
210
204
|
batch_index, layer_name)
|
211
205
|
|
212
|
-
|
213
|
-
|
214
|
-
|
206
|
+
if imgs_layer_mean is not None and imgs_layer_std is not None:
|
207
|
+
bn_alignment_loss = bn_alignment_loss_fn(bn_layer_mean, imgs_layer_mean, bn_layer_std,
|
208
|
+
imgs_layer_std)
|
209
|
+
# Accumulate the batchnorm alignment weighted by the layer weight
|
210
|
+
total_bn_loss += bn_layer_weight * bn_alignment_loss
|
215
211
|
|
216
212
|
return total_bn_loss
|
217
213
|
|
@@ -418,7 +414,7 @@ class BatchStatsHolder:
|
|
418
414
|
Returns:
|
419
415
|
Any: the mean for the specified layer.
|
420
416
|
"""
|
421
|
-
return self.bn_mean
|
417
|
+
return self.bn_mean.get(bn_layer_name)
|
422
418
|
|
423
419
|
def get_second_moment(self, bn_layer_name: str) -> Any:
|
424
420
|
"""
|
@@ -430,7 +426,7 @@ class BatchStatsHolder:
|
|
430
426
|
Returns:
|
431
427
|
Any: the second moment for the specified layer.
|
432
428
|
"""
|
433
|
-
return self.bn_second_moment
|
429
|
+
return self.bn_second_moment.get(bn_layer_name)
|
434
430
|
|
435
431
|
def get_var(self, bn_layer_name: str) -> Any:
|
436
432
|
"""
|
@@ -18,8 +18,11 @@ BATCH_AXIS, H_AXIS, W_AXIS, CHANNEL_AXIS = 0, 1, 2, 3
|
|
18
18
|
# Default initial learning rate constant for Keras.
|
19
19
|
DEFAULT_KERAS_INITIAL_LR = 1
|
20
20
|
|
21
|
-
# Default
|
22
|
-
|
21
|
+
# Default extra pixels for image padding.
|
22
|
+
DEFAULT_KERAS_EXTRA_PIXELS = 32
|
23
|
+
|
24
|
+
# Default output loss multiplier.
|
25
|
+
DEFAULT_KERAS_OUTPUT_LOSS_MULTIPLIER = 1e-3
|
23
26
|
|
24
27
|
# Minimum value for image pixel intensity.
|
25
28
|
IMAGE_MIN_VAL = 0.0
|