mct-nightly 2.1.0.20240725.446__py3-none-any.whl → 2.1.0.20240727.431__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/RECORD +35 -31
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/pytorch/constants.py +6 -1
  5. model_compression_toolkit/core/pytorch/utils.py +27 -0
  6. model_compression_toolkit/data_generation/common/data_generation.py +20 -18
  7. model_compression_toolkit/data_generation/common/data_generation_config.py +8 -11
  8. model_compression_toolkit/data_generation/common/enums.py +24 -12
  9. model_compression_toolkit/data_generation/common/image_pipeline.py +50 -12
  10. model_compression_toolkit/data_generation/common/model_info_exctractors.py +0 -8
  11. model_compression_toolkit/data_generation/common/optimization_utils.py +7 -11
  12. model_compression_toolkit/data_generation/keras/constants.py +5 -2
  13. model_compression_toolkit/data_generation/keras/image_operations.py +189 -0
  14. model_compression_toolkit/data_generation/keras/image_pipeline.py +50 -104
  15. model_compression_toolkit/data_generation/keras/keras_data_generation.py +28 -36
  16. model_compression_toolkit/data_generation/keras/model_info_exctractors.py +0 -13
  17. model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py +16 -6
  18. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +219 -0
  19. model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py +39 -13
  20. model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py +6 -98
  21. model_compression_toolkit/data_generation/keras/optimization_utils.py +15 -28
  22. model_compression_toolkit/data_generation/pytorch/constants.py +4 -1
  23. model_compression_toolkit/data_generation/pytorch/image_operations.py +105 -0
  24. model_compression_toolkit/data_generation/pytorch/image_pipeline.py +70 -78
  25. model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py +0 -10
  26. model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py +17 -6
  27. model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py +2 -2
  28. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +219 -0
  29. model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py +55 -21
  30. model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py +15 -0
  31. model_compression_toolkit/data_generation/pytorch/optimization_utils.py +32 -54
  32. model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py +57 -52
  33. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/LICENSE.md +0 -0
  34. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/WHEEL +0 -0
  35. {mct_nightly-2.1.0.20240725.446.dist-info → mct_nightly-2.1.0.20240727.431.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.1.0.20240725.446
3
+ Version: 2.1.0.20240727.431
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=vR8xewBS60bSjPLOXnFBfDT7srZ_io2XFWfqLG-lvkw,1573
1
+ model_compression_toolkit/__init__.py,sha256=dNSTIwKIETFrRFbKVtmCp7sMOFv7eHAfRFVw0joUkik,1573
2
2
  model_compression_toolkit/constants.py,sha256=9pVleMwnhlM4QwIL2HcEq42I1uF4rlSw63RUjkxOF4w,3923
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -210,13 +210,13 @@ model_compression_toolkit/core/keras/statistics_correction/__init__.py,sha256=9H
210
210
  model_compression_toolkit/core/keras/statistics_correction/apply_second_moment_correction.py,sha256=XNCtT9klMcsO1v5KA3MmCq_WgXOIT5QSzbfTOa9T-04,3060
211
211
  model_compression_toolkit/core/keras/visualization/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
212
212
  model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
213
- model_compression_toolkit/core/pytorch/constants.py,sha256=AguUnAsNlj41gwuKIP_7nos3FcJHsIAjewLXSQdrDQM,2624
213
+ model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
214
214
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
215
215
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
216
216
  model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=XL_RZcfnb_ZY2jdCjOxxz7SbRBzMokbOWsTuYOSjyRU,27569
217
217
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
218
218
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
219
- model_compression_toolkit/core/pytorch/utils.py,sha256=OT_mrNEJqPgWLdtQuivKMQVjtJY49cmoIVvbRhANl1w,3004
219
+ model_compression_toolkit/core/pytorch/utils.py,sha256=GE7T8q93I5C4As0iOias_dk9HpOvXM1N6---dJlyD60,3863
220
220
  model_compression_toolkit/core/pytorch/back2framework/__init__.py,sha256=H_WixgN0elVWf3exgGYsi58imPoYDj5eYPeh6x4yfug,813
221
221
  model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,sha256=DwNO8WO3JiMawKGKDhlrwCoCjMSBIw5BMbsFFF7eDS4,2279
222
222
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
@@ -269,36 +269,40 @@ model_compression_toolkit/core/pytorch/statistics_correction/apply_second_moment
269
269
  model_compression_toolkit/data_generation/__init__.py,sha256=S8pRUqlRvpM5AFHpFWs3zb0H0rtY5nUwmeCQij01oi4,1507
270
270
  model_compression_toolkit/data_generation/common/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
271
271
  model_compression_toolkit/data_generation/common/constants.py,sha256=21e3ZX9WVYojexG2acTgklrBk8ZO9DjJnKpP4KHZC44,1018
272
- model_compression_toolkit/data_generation/common/data_generation.py,sha256=fccGG6cTMScZwjnJDQKMugOLdgm9dKg5rRfcBD4EFYQ,6415
273
- model_compression_toolkit/data_generation/common/data_generation_config.py,sha256=ynyNaT2x2d23bYSrO2sRItM2ZsjGD0K0fM71FlibiJQ,4564
274
- model_compression_toolkit/data_generation/common/enums.py,sha256=OGnvtEGFbP5l4V3-1l32zzVQwTb1vGJhTVF0kOkYZK4,3584
275
- model_compression_toolkit/data_generation/common/image_pipeline.py,sha256=WwyeoIvgmcxKnuOX-_Hl_0APET4M26f5x-prhUB3qvU,2149
276
- model_compression_toolkit/data_generation/common/model_info_exctractors.py,sha256=kz3w4h4fO4R2N5IgLvSkqDUwjhH4S-I3n3_pK2hQ_uo,6200
277
- model_compression_toolkit/data_generation/common/optimization_utils.py,sha256=Q_yG8T8HQxfTKjVmN7bLm4M4y-5WrMeGQ_h5RnurSmg,19761
272
+ model_compression_toolkit/data_generation/common/data_generation.py,sha256=W8PeOcL1fBVB1WgXSCrEw-G7AWa6MNzjTqcFbmMhrGE,6687
273
+ model_compression_toolkit/data_generation/common/data_generation_config.py,sha256=yKqSDJGdbnc9HEmg94sPqMSXGR2OmAzt5X5MQcy_YX8,4473
274
+ model_compression_toolkit/data_generation/common/enums.py,sha256=V5qAaqMg2WFhsrJ11rTDcRWBhbsxhEHt3uwRq6cesNo,4249
275
+ model_compression_toolkit/data_generation/common/image_pipeline.py,sha256=PfunQMxYm6KqJUEUVYhtY7-JTq4J-XTyLc1HOalP15s,4761
276
+ model_compression_toolkit/data_generation/common/model_info_exctractors.py,sha256=CqruljgQ564SMRQtxgYYDWKM7HYDz18MCShNgrRYQKg,5933
277
+ model_compression_toolkit/data_generation/common/optimization_utils.py,sha256=aEDSclZ2TvIIqN1x9CLf8MBe2GA3m1aEXtbd5Sgcd8k,19528
278
278
  model_compression_toolkit/data_generation/keras/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
279
- model_compression_toolkit/data_generation/keras/constants.py,sha256=uy3eU24ykygIrjIvwOMj3j5euBeN2PwWiEFPOkJJ7ss,1088
280
- model_compression_toolkit/data_generation/keras/image_pipeline.py,sha256=_Qezq67huKmmNsxdFBBrTY-VaGR-paFzDH80dDuRnug,7623
281
- model_compression_toolkit/data_generation/keras/keras_data_generation.py,sha256=lWNxHIews7fREVGT99b2EPV6VdpuLjwDINLApatlPjo,21539
282
- model_compression_toolkit/data_generation/keras/model_info_exctractors.py,sha256=d_SOttfnOk97_nJ7Vjv5p-8tcQqAXJ_hzc1khZGU92s,8610
283
- model_compression_toolkit/data_generation/keras/optimization_utils.py,sha256=uQAJpJPpnLDTTLDQGyTS0ZYp2T38TTZLOOElcJPBKHA,21146
279
+ model_compression_toolkit/data_generation/keras/constants.py,sha256=sxhhGHC-INBs1nVXhyokbFi9ob4jPkSRviuc83JRsgQ,1152
280
+ model_compression_toolkit/data_generation/keras/image_operations.py,sha256=OtJ5Yz8BZVOnGqyTHwlseRe4EmoLDYxz3bblGtw6HnY,6233
281
+ model_compression_toolkit/data_generation/keras/image_pipeline.py,sha256=E-HVverorhq33xzteuwUPtOrGDIYoEEs4fZJgiqOAzQ,7043
282
+ model_compression_toolkit/data_generation/keras/keras_data_generation.py,sha256=IMnmUn7fUsMcJ980FZWuX36iUYXAEYxdYk8oXwz-Xd8,21207
283
+ model_compression_toolkit/data_generation/keras/model_info_exctractors.py,sha256=1E5xbn0P3py4EYjdpPD9JwGr4jlc3qe1ml1py0t40b8,8026
284
+ model_compression_toolkit/data_generation/keras/optimization_utils.py,sha256=cHv2tl-_9_D14mWqzNYtKFY8q7sJfW_V__dpZqzRvIo,20546
284
285
  model_compression_toolkit/data_generation/keras/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
285
286
  model_compression_toolkit/data_generation/keras/optimization_functions/batchnorm_alignment_functions.py,sha256=f5M7KoISGnb6S6zR7SyQ9dYmQctW9iYRi0Bda1BLq70,1983
286
- model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py,sha256=T6P7VEhRRcJAvpPhMOp3izo8Gn13_nMb7GNNm1w0WhA,2699
287
+ model_compression_toolkit/data_generation/keras/optimization_functions/bn_layer_weighting_functions.py,sha256=xQWTeP-Im6xEUupF-VEjZq-UsRNzpoW0LuMHFR2cX9Q,3390
287
288
  model_compression_toolkit/data_generation/keras/optimization_functions/image_initilization.py,sha256=sjSPLLFLjJ6d0DDSaxnCE0ydIT1zhL8H73QTXEuUfgw,4119
288
- model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py,sha256=bVb6pnVzGPzC7_Hjj1X3WUUNSfndYUG8_PL_-nQo2UU,5325
289
- model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py,sha256=1HK5-Z0WSk_L62RB2RQGlrwKpCQyoR2TKI5qjXod6cc,4913
289
+ model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py,sha256=xCc7GlmW-jpf27P8mI2APpAL8LC6zmD1BMbN7Q6wzEE,8647
290
+ model_compression_toolkit/data_generation/keras/optimization_functions/output_loss_functions.py,sha256=vr_H1dbFINS7LBX_SfW59g0C8ie9grAyOIpCKuPoI1w,6384
291
+ model_compression_toolkit/data_generation/keras/optimization_functions/scheduler_step_functions.py,sha256=9RhNWtw_cdDlGqEGEdn1JWwvfA8V-Z6ioZn1ppdHFmA,1695
290
292
  model_compression_toolkit/data_generation/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
291
- model_compression_toolkit/data_generation/pytorch/constants.py,sha256=QWyreMImcfKzW1RBqf-4yl2kAr4sAABNNgQmPUkqPEo,1179
292
- model_compression_toolkit/data_generation/pytorch/image_pipeline.py,sha256=6g7OpOuO3cU4TIuelaRjBKpCPgiMbe1a3iy9bZtdZUo,6617
293
- model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py,sha256=sO9tA03nIaeYnzOL4Egec5sVcSGU8H8k9-nNjhaLEbk,9690
294
- model_compression_toolkit/data_generation/pytorch/optimization_utils.py,sha256=AjYsO-lm06JOUMoKkS6VbyF4O_l_ffWXrgamqJm1ofE,19085
295
- model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=Jymevochf1A6sz_bShQiJVj5IUtxKbfW5s5Bh7EZhUo,21238
293
+ model_compression_toolkit/data_generation/pytorch/constants.py,sha256=ZiyweWj2Bnk6duhcV4zowsPvqLdON-AlLhkAuLmCqxg,1256
294
+ model_compression_toolkit/data_generation/pytorch/image_operations.py,sha256=KUQKOj5G4UPGX9f9PSiLRlBo4e3rRRPec88wkozNgqw,3900
295
+ model_compression_toolkit/data_generation/pytorch/image_pipeline.py,sha256=dcQr-67u9-ggGuS39YAvR7z-Y0NOdJintcVQ5vy1bM8,7478
296
+ model_compression_toolkit/data_generation/pytorch/model_info_exctractors.py,sha256=y6vMed6lQQj67-BXZKrAcWUNTkH8YjiUhknOV4wSpRA,9399
297
+ model_compression_toolkit/data_generation/pytorch/optimization_utils.py,sha256=vRMeUEdInPuJisiO-SKo_9miWZV90sz8GCg5MY0AqiU,18098
298
+ model_compression_toolkit/data_generation/pytorch/pytorch_data_generation.py,sha256=OjdAG0uGdbN0ATMrkRskhEttkUgSXN8KCVd8JXKiwxk,21620
296
299
  model_compression_toolkit/data_generation/pytorch/optimization_functions/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
297
300
  model_compression_toolkit/data_generation/pytorch/optimization_functions/batchnorm_alignment_functions.py,sha256=dMc4zz9XfYfAT4Cxns57VgvGZWPAMfaGlWLFyCyl8TA,1968
298
- model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py,sha256=i3ePEI8xDE3xZEtmzT5lCkLn9wpObUi_OgqnVDf7nj8,2597
299
- model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py,sha256=gQDZvuT2yt3V1K5TPvBwAdvdTU-ztsQgGAL7ErKm1PQ,4719
300
- model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py,sha256=75_Dl7clEa_SDYRVF3Ub27urYMar8KydZyTUEvZT3ks,5015
301
- model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py,sha256=bxbNlQY7AKX8lEJaJReEn0lGBGb6e4PiDEqyW__nmig,2587
301
+ model_compression_toolkit/data_generation/pytorch/optimization_functions/bn_layer_weighting_functions.py,sha256=We0fVMQ4oU7Y0IWQ8fKy8KpqkIiLyKoQeF9XKAQ6TH0,3317
302
+ model_compression_toolkit/data_generation/pytorch/optimization_functions/image_initilization.py,sha256=hhWSZ7w45dE5SQ6jM27cBkCSJObWkALs_RpD6afPi68,4753
303
+ model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py,sha256=NydGxFIclmrfU3HWYUrRbprg4hPt470QP6MTOMLEhRs,9172
304
+ model_compression_toolkit/data_generation/pytorch/optimization_functions/output_loss_functions.py,sha256=PRVmn8o2hTdwTdbd2ezf__LNbFvcgiVO0c25dsyg3Tg,6549
305
+ model_compression_toolkit/data_generation/pytorch/optimization_functions/scheduler_step_functions.py,sha256=zMjY2y4FSHonuY5hddbMTb8qAQtLtohYF7q1wuruDDs,3267
302
306
  model_compression_toolkit/exporter/__init__.py,sha256=Eg3c4EAjW3g6h13A-Utgf9ncHrTMRHAoySNDQGPDZ4E,1301
303
307
  model_compression_toolkit/exporter/model_exporter/__init__.py,sha256=9HIBmj8ROdCA-yvkpA8EcN6RHJe_2vEpLLW_gxOJtak,698
304
308
  model_compression_toolkit/exporter/model_exporter/fw_agonstic/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
@@ -517,8 +521,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
517
521
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=yrZNVRm2IRU7r7R-hjS2lOQ6wvEEvbeunvf2jKoWjXk,3277
518
522
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
519
523
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=eyMoXt5o5EnMr6d-rpCwQdX5mAiYiymvbgKv4tf7-a0,4576
520
- mct_nightly-2.1.0.20240725.446.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
521
- mct_nightly-2.1.0.20240725.446.dist-info/METADATA,sha256=m0FhYyaViVkXqL-az7oYXCs-wZpStJIJO_Tbf0wISUM,19719
522
- mct_nightly-2.1.0.20240725.446.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
523
- mct_nightly-2.1.0.20240725.446.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
524
- mct_nightly-2.1.0.20240725.446.dist-info/RECORD,,
524
+ mct_nightly-2.1.0.20240727.431.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
525
+ mct_nightly-2.1.0.20240727.431.dist-info/METADATA,sha256=g2Z7XHdZo_A9_vL0R9IJv5kcLIsZg_ONvFBnuOFqFkk,19719
526
+ mct_nightly-2.1.0.20240727.431.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
527
+ mct_nightly-2.1.0.20240727.431.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
528
+ mct_nightly-2.1.0.20240727.431.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.1.0.20240725.000446"
30
+ __version__ = "2.1.0.20240727.000431"
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
15
+ import torch
16
16
 
17
17
  # # Layer type constants:
18
18
  PLACEHOLDER = 'placeholder'
@@ -94,3 +94,8 @@ BIAS_V = 'bias_v'
94
94
  # # Batch size value for 'reshape' and 'view' operators,
95
95
  # # the value is -1 so the batch size is inferred from the length of the array and remaining dimensions.
96
96
  BATCH_DIM_VALUE = -1
97
+
98
+
99
+ # The maximum and minimum representable values for float16
100
+ MAX_FLOAT16 = torch.finfo(torch.float16).max - 1
101
+ MIN_FLOAT16 = torch.finfo(torch.float16).min - 1
@@ -13,8 +13,11 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  import torch
16
+ from torch import Tensor
16
17
  import numpy as np
17
18
  from typing import Union
19
+
20
+ from model_compression_toolkit.core.pytorch.constants import MAX_FLOAT16, MIN_FLOAT16
18
21
  from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
19
22
  from model_compression_toolkit.logger import Logger
20
23
 
@@ -83,3 +86,27 @@ def torch_tensor_to_numpy(tensor: Union[torch.Tensor, list, tuple]) -> Union[np.
83
86
  return tensor.cpu().detach().contiguous().numpy()
84
87
  else:
85
88
  Logger.critical(f'Unsupported type for conversion to Numpy array: {type(tensor)}.')
89
+
90
+
91
+ def clip_inf_values_float16(tensor: Tensor) -> Tensor:
92
+ """
93
+ Clips +inf and -inf values in a float16 tensor to the maximum and minimum representable values.
94
+
95
+ Parameters:
96
+ tensor (Tensor): Input PyTorch tensor of dtype float16.
97
+
98
+ Returns:
99
+ Tensor: A tensor with +inf values replaced by the maximum float16 value,
100
+ and -inf values replaced by the minimum float16 value.
101
+ """
102
+ # Check if the tensor is of dtype float16
103
+ if tensor.dtype != torch.float16:
104
+ return tensor
105
+
106
+ # Create a mask for inf values (both positive and negative)
107
+ inf_mask = torch.isinf(tensor)
108
+
109
+ # Replace inf values with max float16 value
110
+ tensor[inf_mask] = MAX_FLOAT16 * torch.sign(tensor[inf_mask])
111
+
112
+ return tensor
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  # Import required modules and classes
16
- from typing import Any, Tuple, Dict, Callable, List
16
+ from typing import Any, Tuple, Dict, Callable, List, Union
17
17
 
18
18
  from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
19
19
  from model_compression_toolkit.data_generation.common.enums import ImagePipelineType, ImageNormalizationType, \
@@ -24,7 +24,7 @@ from model_compression_toolkit.logger import Logger
24
24
 
25
25
  def get_data_generation_classes(
26
26
  data_generation_config: DataGenerationConfig,
27
- output_image_size: Tuple,
27
+ output_image_size: Union[int, Tuple[int, int]],
28
28
  n_images: int,
29
29
  image_pipeline_dict: Dict,
30
30
  image_normalization_dict: Dict,
@@ -38,7 +38,7 @@ def get_data_generation_classes(
38
38
 
39
39
  Args:
40
40
  data_generation_config (DataGenerationConfig): Configuration for data generation.
41
- output_image_size (Tuple): The desired output image size.
41
+ output_image_size (Union[int, Tuple[int, int]]): The desired output image size.
42
42
  n_images (int): The number of random samples.
43
43
  image_pipeline_dict (Dict): Dictionary mapping ImagePipelineType to corresponding image pipeline classes.
44
44
  image_normalization_dict (Dict): Dictionary mapping ImageNormalizationType to corresponding
@@ -56,26 +56,28 @@ def get_data_generation_classes(
56
56
  output_loss_fn (Callable): Function to compute output loss.
57
57
  init_dataset (Any): The initial dataset used for image generation.
58
58
  """
59
+ # Get the normalization values corresponding to the specified type
60
+ normalization = image_normalization_dict.get(data_generation_config.image_normalization_type)
61
+
62
+ # Check if the image normalization type is valid
63
+ if normalization is None:
64
+ Logger.critical(
65
+ f'Invalid image_normalization_type {data_generation_config.image_normalization_type}. '
66
+ f'Please select one from {ImageNormalizationType.get_values()}.') # pragma: no cover
67
+
59
68
  # Get the image pipeline class corresponding to the specified type
60
69
  image_pipeline = (
61
70
  image_pipeline_dict.get(data_generation_config.image_pipeline_type)(
62
71
  output_image_size=output_image_size,
63
- extra_pixels=data_generation_config.extra_pixels))
72
+ extra_pixels=data_generation_config.extra_pixels,
73
+ image_clipping=data_generation_config.image_clipping,
74
+ normalization=normalization))
64
75
 
65
76
  # Check if the image pipeline type is valid
66
77
  if image_pipeline is None:
67
78
  Logger.critical(
68
79
  f'Invalid image_pipeline_type {data_generation_config.image_pipeline_type}. '
69
- f'Please select one from {ImagePipelineType.get_values()}.')
70
-
71
- # Get the normalization values corresponding to the specified type
72
- normalization = image_normalization_dict.get(data_generation_config.image_normalization_type)
73
-
74
- # Check if the image normalization type is valid
75
- if normalization is None:
76
- Logger.critical(
77
- f'Invalid image_normalization_type {data_generation_config.image_normalization_type}. '
78
- f'Please select one from {ImageNormalizationType.get_values()}.')
80
+ f'Please select one from {ImagePipelineType.get_values()}.') # pragma: no cover
79
81
 
80
82
  # Get the layer weighting function corresponding to the specified type
81
83
  bn_layer_weighting_fn = bn_layer_weighting_function_dict.get(data_generation_config.layer_weighting_type)
@@ -83,7 +85,7 @@ def get_data_generation_classes(
83
85
  if bn_layer_weighting_fn is None:
84
86
  Logger.critical(
85
87
  f'Invalid layer_weighting_type {data_generation_config.layer_weighting_type}. '
86
- f'Please select one from {BNLayerWeightingType.get_values()}.')
88
+ f'Please select one from {BNLayerWeightingType.get_values()}.') # pragma: no cover
87
89
 
88
90
  # Get the image initialization function corresponding to the specified type
89
91
  image_initialization_fn = image_initialization_function_dict.get(data_generation_config.data_init_type)
@@ -92,7 +94,7 @@ def get_data_generation_classes(
92
94
  if image_initialization_fn is None:
93
95
  Logger.critical(
94
96
  f'Invalid data_init_type {data_generation_config.data_init_type}. '
95
- f'Please select one from {DataInitType.get_values()}.')
97
+ f'Please select one from {DataInitType.get_values()}.') # pragma: no cover
96
98
 
97
99
  # Get the BatchNorm alignment loss function corresponding to the specified type
98
100
  bn_alignment_loss_fn = bn_alignment_loss_function_dict.get(data_generation_config.bn_alignment_loss_type)
@@ -101,7 +103,7 @@ def get_data_generation_classes(
101
103
  if bn_alignment_loss_fn is None:
102
104
  Logger.critical(
103
105
  f'Invalid bn_alignment_loss_type {data_generation_config.bn_alignment_loss_type}. '
104
- f'Please select one from {BatchNormAlignemntLossType.get_values()}.')
106
+ f'Please select one from {BatchNormAlignemntLossType.get_values()}.') # pragma: no cover
105
107
 
106
108
  # Get the output loss function corresponding to the specified type
107
109
  output_loss_fn = output_loss_function_dict.get(data_generation_config.output_loss_type)
@@ -110,7 +112,7 @@ def get_data_generation_classes(
110
112
  if output_loss_fn is None:
111
113
  Logger.critical(
112
114
  f'Invalid output_loss_type {data_generation_config.output_loss_type}. '
113
- f'Please select one from {OutputLossType.get_values()}.')
115
+ f'Please select one from {OutputLossType.get_values()}.') # pragma: no cover
114
116
 
115
117
  # Initialize the dataset for data generation
116
118
  init_dataset = image_initialization_fn(
@@ -12,12 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Callable, Any, List
15
+ from typing import Any, List, Tuple, Union
16
16
 
17
17
  from model_compression_toolkit.data_generation.common.enums import SchedulerType, BatchNormAlignemntLossType, \
18
18
  DataInitType, BNLayerWeightingType, ImageGranularity, ImagePipelineType, ImageNormalizationType, OutputLossType
19
19
 
20
-
21
20
  class DataGenerationConfig:
22
21
  """
23
22
  Configuration class for data generation.
@@ -36,11 +35,10 @@ class DataGenerationConfig:
36
35
  layer_weighting_type: BNLayerWeightingType = None,
37
36
  image_pipeline_type: ImagePipelineType = None,
38
37
  image_normalization_type: ImageNormalizationType = None,
39
- extra_pixels: int = 0,
38
+ extra_pixels: Union[int, Tuple[int, int]] = 0,
40
39
  bn_layer_types: List = [],
41
40
  last_layer_types: List = [],
42
- clip_images: bool = True,
43
- reflection: bool = True,
41
+ image_clipping: bool = True,
44
42
  ):
45
43
  """
46
44
  Initialize the DataGenerationConfig.
@@ -59,17 +57,15 @@ class DataGenerationConfig:
59
57
  layer_weighting_type (BNLayerWeightingType): Type of layer weighting. Defaults to None.
60
58
  image_pipeline_type (ImagePipelineType): Type of image pipeline. Defaults to None.
61
59
  image_normalization_type (ImageNormalizationType): Type of image normalization. Defaults to None.
62
- extra_pixels (int): Extra pixels to add to the input image size. Defaults to 0.
60
+ extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
63
61
  bn_layer_types (List): List of BatchNorm layer types. Defaults to [].
64
62
  last_layer_types (List): List of layer types. Defaults to [].
65
- clip_images (bool): Flag to enable image clipping. Defaults to True.
66
- reflection (bool): Flag to enable reflection. Defaults to True.
63
+ image_clipping (bool): Flag to enable image clipping. Defaults to True.
67
64
  """
68
65
  self.n_iter = n_iter
69
66
  self.optimizer = optimizer
70
67
  self.data_gen_batch_size = data_gen_batch_size
71
68
  self.initial_lr = initial_lr
72
- self.output_loss_multiplier = output_loss_multiplier
73
69
  self.image_granularity = image_granularity
74
70
  self.scheduler_type = scheduler_type
75
71
  self.bn_alignment_loss_type = bn_alignment_loss_type
@@ -81,6 +77,7 @@ class DataGenerationConfig:
81
77
  self.layer_weighting_type = layer_weighting_type
82
78
  self.bn_layer_types = bn_layer_types
83
79
  self.last_layer_types = last_layer_types
84
- self.clip_images = clip_images
85
- self.reflection = reflection
80
+ self.image_clipping = image_clipping
81
+ self.output_loss_multiplier = output_loss_multiplier
82
+
86
83
 
@@ -16,7 +16,6 @@ from enum import Enum
16
16
 
17
17
 
18
18
  class EnumBaseClass(Enum):
19
- @classmethod
20
19
  def get_values(cls):
21
20
  """
22
21
  Get the list of values corresponding to the enum members.
@@ -24,7 +23,23 @@ class EnumBaseClass(Enum):
24
23
  Returns:
25
24
  List of values.
26
25
  """
27
- return [value.value for value in cls.__members__.values()]
26
+ return list(cls.__members__.values())
27
+
28
+ @classmethod
29
+ def get_enum_by_value(cls, target_value):
30
+ """
31
+ Function to get the key corresponding to a given enum value.
32
+
33
+ Parameters:
34
+ target_value: The value to find the key for.
35
+
36
+ Returns:
37
+ The key corresponding to the given value if found, else None.
38
+ """
39
+ for value in cls.__members__.values():
40
+ if value.value == target_value:
41
+ return value
42
+ return None
28
43
 
29
44
 
30
45
  class ImageGranularity(EnumBaseClass):
@@ -61,15 +76,12 @@ class ImagePipelineType(EnumBaseClass):
61
76
  """
62
77
  An enum for choosing the image pipeline type for image manipulation:
63
78
 
64
- RANDOM_CROP - Crop the images.
65
-
66
- RANDOM_CROP_FLIP - Crop and flip the images.
79
+ SMOOTHING_AND_AUGMENTATION - Apply a smoothing filter, then crop and flip the images.
67
80
 
68
81
  IDENTITY - Do not apply any manipulation (identity transformation).
69
82
 
70
83
  """
71
- RANDOM_CROP = 'random_crop'
72
- RANDOM_CROP_FLIP = 'random_crop_flip'
84
+ SMOOTHING_AND_AUGMENTATION = 'smoothing_and_augmentation'
73
85
  IDENTITY = 'identity'
74
86
 
75
87
 
@@ -118,16 +130,15 @@ class BatchNormAlignemntLossType(EnumBaseClass):
118
130
  class OutputLossType(EnumBaseClass):
119
131
  """
120
132
  An enum for choosing the output loss type:
121
-
122
133
  NONE - No output loss is applied.
123
-
124
- MIN_MAX_DIFF - Use min-max difference as the output loss.
125
-
134
+ NEGATIVE_MIN_MAX_DIFF - Use the mean of the negative min-max difference as the output loss.
135
+ INVERSE_MIN_MAX_DIFF - Use mean of the 1/(min-max) difference as the output loss.
126
136
  REGULARIZED_MIN_MAX_DIFF - Use regularized min-max difference as the output loss.
127
137
 
128
138
  """
129
139
  NONE = 'none'
130
- MIN_MAX_DIFF = 'min_max_diff'
140
+ NEGATIVE_MIN_MAX_DIFF = 'negative_min_max_diff'
141
+ INVERSE_MIN_MAX_DIFF = 'inverse_min_max_diff'
131
142
  REGULARIZED_MIN_MAX_DIFF = 'regularized_min_max_diff'
132
143
 
133
144
 
@@ -141,4 +152,5 @@ class SchedulerType(EnumBaseClass):
141
152
 
142
153
  """
143
154
  REDUCE_ON_PLATEAU = 'reduce_on_plateau'
155
+ REDUCE_ON_PLATEAU_WITH_RESET = 'reduce_on_plateau_with_reset'
144
156
  STEP = 'step'
@@ -13,31 +13,61 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  from abc import ABC, abstractmethod
16
- from typing import Any
16
+ from typing import Any, Tuple, Union, Dict, List
17
+
18
+ from model_compression_toolkit.data_generation import ImageNormalizationType
19
+ from model_compression_toolkit.logger import Logger
17
20
 
18
21
 
19
22
  class BaseImagePipeline(ABC):
20
23
  def __init__(self,
21
- output_image_size: int,
22
- extra_pixels: int = 0):
24
+ output_image_size: Union[int, Tuple[int, int]],
25
+ extra_pixels: Union[int, Tuple[int, int]] = 0,
26
+ image_clipping: bool = False,
27
+ normalization: List[List[int]] = [[0, 0, 0], [1, 1, 1]]):
23
28
  """
24
29
  Base class for image pipeline.
25
30
 
26
31
  Args:
27
- output_image_size (int): The desired output image size.
28
- extra_pixels (int, optional): Extra pixels to add to the input image size. Defaults to 0.
32
+ output_image_size (Union[int, Tuple[int, int]]): The desired output image size.
33
+ extra_pixels (Union[int, Tuple[int, int]]): Extra pixels to add to the input image size. Defaults to 0.
34
+ image_clipping (bool): Whether to clip images during optimization.
35
+ normalization (List[List[float]]): The image normalization values for processing images during optimization.
29
36
  """
30
- self.output_image_size = output_image_size
31
- self.extra_pixels = extra_pixels
37
+ if isinstance(output_image_size, int):
38
+ self.output_image_size = (output_image_size, output_image_size)
39
+ elif isinstance(output_image_size, tuple) and len(output_image_size) == 1:
40
+ self.output_image_size = output_image_size + output_image_size # concatenate two tuples
41
+ elif isinstance(output_image_size, tuple) and len(output_image_size) == 2:
42
+ self.output_image_size = output_image_size
43
+ elif isinstance(output_image_size, tuple):
44
+ Logger.critical(f"'output_image_size' should a tuple of length 1 or 2. Got tuple of length {len(output_image_size)}") # pragma: no cover
45
+ else:
46
+ Logger.critical(f"'output_image_size' should be an int or tuple but type {type(output_image_size)} was received.") # pragma: no cover
47
+
48
+ if isinstance(extra_pixels, int):
49
+ self.extra_pixels = (extra_pixels, extra_pixels)
50
+ elif isinstance(extra_pixels, tuple) and len(extra_pixels) == 1:
51
+ self.extra_pixels = extra_pixels + extra_pixels # concatenate two tuples
52
+ elif isinstance(extra_pixels, tuple) and len(extra_pixels) == 2:
53
+ self.extra_pixels = extra_pixels
54
+ elif isinstance(extra_pixels, tuple):
55
+ Logger.critical(f"'extra_pixels' should a tuple of length 1 or 2. Got tuple of length {len(extra_pixels)}") # pragma: no cover
56
+ else:
57
+ Logger.critical(f"'extra_pixels' should be an int or tuple but type {type(extra_pixels)} was received.") # pragma: no cover
58
+
59
+ self.image_clipping = image_clipping
60
+ self.normalization = normalization
61
+
32
62
  @abstractmethod
33
- def get_image_input_size(self) -> int:
63
+ def get_image_input_size(self) -> Tuple[int, int]:
34
64
  """
35
65
  Get the size of the input image for the image pipeline.
36
66
 
37
67
  Returns:
38
- int: The input image size.
68
+ Tuple[int, int]: The input image size.
39
69
  """
40
- raise NotImplemented
70
+ raise NotImplemented # pragma: no cover
41
71
 
42
72
  @abstractmethod
43
73
  def image_input_manipulation(self,
@@ -51,7 +81,7 @@ class BaseImagePipeline(ABC):
51
81
  Returns:
52
82
  Any: Manipulated images.
53
83
  """
54
- raise NotImplemented
84
+ raise NotImplemented # pragma: no cover
55
85
 
56
86
  @abstractmethod
57
87
  def image_output_finalize(self,
@@ -65,4 +95,12 @@ class BaseImagePipeline(ABC):
65
95
  Returns:
66
96
  Any: Finalized images.
67
97
  """
68
- raise NotImplemented
98
+ raise NotImplemented # pragma: no cover
99
+
100
+
101
+ # Dictionary mapping ImageNormalizationType to corresponding normalization values
102
+ image_normalization_dict: Dict[ImageNormalizationType, List[List[float]]] = {
103
+ ImageNormalizationType.TORCHVISION: [[0.485 * 255, 0.456 * 255, 0.406 * 255], [0.229 * 255, 0.224 * 255, 0.225 * 255]],
104
+ ImageNormalizationType.KERAS_APPLICATIONS: [[127.5, 127.5, 127.5], [127.5, 127.5, 127.5]],
105
+ ImageNormalizationType.NO_NORMALIZATION: [[0, 0, 0], [1, 1, 1]]
106
+ }
@@ -157,14 +157,6 @@ class ActivationExtractor:
157
157
  """
158
158
  raise NotImplemented # pragma: no cover
159
159
 
160
- def get_num_extractor_layers(self) -> int:
161
- """
162
- Get the number of layers for which to extract input activations.
163
-
164
- Returns:
165
- int: Number of layers for which to extract input activations.
166
- """
167
- return self.num_layers
168
160
 
169
161
  @abstractmethod
170
162
  def get_extractor_layer_names(self) -> List:
@@ -48,8 +48,6 @@ class ImagesOptimizationHandler:
48
48
  scheduler: Any,
49
49
  normalization_mean: List[float],
50
50
  normalization_std: List[float],
51
- clip_images: bool,
52
- reflection: bool,
53
51
  initial_lr: float,
54
52
  eps: float = 1e-6):
55
53
  """
@@ -67,8 +65,6 @@ class ImagesOptimizationHandler:
67
65
  scheduler (Any): The scheduler responsible for adjusting the learning rate of the optimizer over time.
68
66
  normalization_mean (List[float]): Mean values used for image normalization.
69
67
  normalization_std (List[float]): Standard deviation values used for image normalization.
70
- clip_images (bool): Flag indicating whether to clip generated images during optimization.
71
- reflection (bool): Flag indicating whether to use reflection during image generation.
72
68
  initial_lr (float): The initial learning rate used by the optimizer.
73
69
  eps (float, optional): A small value added for numerical stability. Defaults to 1e-6.
74
70
 
@@ -79,8 +75,6 @@ class ImagesOptimizationHandler:
79
75
  self.scheduler = scheduler
80
76
  self.scheduler_step_fn = scheduler_step_fn
81
77
  self.image_granularity = image_granularity
82
- self.clip_images = clip_images
83
- self.reflection = reflection
84
78
  self.eps = eps
85
79
  self.targets = []
86
80
  self.initial_lr = initial_lr
@@ -209,9 +203,11 @@ class ImagesOptimizationHandler:
209
203
  imgs_layer_mean, imgs_layer_second_moment, imgs_layer_std = self.all_imgs_stats_holder.get_stats(
210
204
  batch_index, layer_name)
211
205
 
212
- # Accumulate the batchnorm alignment weighted by the layer weight
213
- total_bn_loss += bn_layer_weight * bn_alignment_loss_fn(bn_layer_mean, imgs_layer_mean, bn_layer_std,
214
- imgs_layer_std)
206
+ if imgs_layer_mean is not None and imgs_layer_std is not None:
207
+ bn_alignment_loss = bn_alignment_loss_fn(bn_layer_mean, imgs_layer_mean, bn_layer_std,
208
+ imgs_layer_std)
209
+ # Accumulate the batchnorm alignment weighted by the layer weight
210
+ total_bn_loss += bn_layer_weight * bn_alignment_loss
215
211
 
216
212
  return total_bn_loss
217
213
 
@@ -418,7 +414,7 @@ class BatchStatsHolder:
418
414
  Returns:
419
415
  Any: the mean for the specified layer.
420
416
  """
421
- return self.bn_mean[bn_layer_name]
417
+ return self.bn_mean.get(bn_layer_name)
422
418
 
423
419
  def get_second_moment(self, bn_layer_name: str) -> Any:
424
420
  """
@@ -430,7 +426,7 @@ class BatchStatsHolder:
430
426
  Returns:
431
427
  Any: the second moment for the specified layer.
432
428
  """
433
- return self.bn_second_moment[bn_layer_name]
429
+ return self.bn_second_moment.get(bn_layer_name)
434
430
 
435
431
  def get_var(self, bn_layer_name: str) -> Any:
436
432
  """
@@ -18,8 +18,11 @@ BATCH_AXIS, H_AXIS, W_AXIS, CHANNEL_AXIS = 0, 1, 2, 3
18
18
  # Default initial learning rate constant for Keras.
19
19
  DEFAULT_KERAS_INITIAL_LR = 1
20
20
 
21
- # Default output loss multiplier for Keras.
22
- DEFAULT_KERAS_OUTPUT_LOSS_MULTIPLIER = 0.001
21
+ # Default extra pixels for image padding.
22
+ DEFAULT_KERAS_EXTRA_PIXELS = 32
23
+
24
+ # Default output loss multiplier.
25
+ DEFAULT_KERAS_OUTPUT_LOSS_MULTIPLIER = 1e-3
23
26
 
24
27
  # Minimum value for image pixel intensity.
25
28
  IMAGE_MIN_VAL = 0.0