mct-nightly 2.0.0.20240409.404__py3-none-any.whl → 2.0.0.20240411.406__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240411.406.dist-info}/METADATA +1 -1
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240411.406.dist-info}/RECORD +19 -17
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/quantization/node_quantization_config.py +2 -1
- model_compression_toolkit/core/common/quantization/quantization_config.py +3 -1
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +2 -2
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py +66 -0
- model_compression_toolkit/core/keras/keras_implementation.py +6 -3
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py +69 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +4 -0
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py +1 -1
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +1 -1
- model_compression_toolkit/gptq/keras/gptq_training.py +17 -15
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py +2 -1
- model_compression_toolkit/gptq/pytorch/gptq_training.py +18 -16
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py +2 -1
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240411.406.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240411.406.dist-info}/WHEEL +0 -0
- {mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240411.406.dist-info}/top_level.txt +0 -0
{mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240411.406.dist-info}/RECORD
RENAMED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
model_compression_toolkit/__init__.py,sha256=
|
|
1
|
+
model_compression_toolkit/__init__.py,sha256=Py1f8nJnEfhzHK091eeZjxPHNqF_ZXrOa97rXbJWdw0,1573
|
|
2
2
|
model_compression_toolkit/constants.py,sha256=KW_HUEPmQEYqCvWGyORqkYxpvO7w5LViB5J5D-pm_6o,3648
|
|
3
3
|
model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
|
|
4
4
|
model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
|
|
@@ -100,8 +100,8 @@ model_compression_toolkit/core/common/quantization/candidate_node_quantization_c
|
|
|
100
100
|
model_compression_toolkit/core/common/quantization/core_config.py,sha256=KYdyfSmjSL4ye24nKlC_c4_AxYb14qoqaeMnZj4-8kE,2257
|
|
101
101
|
model_compression_toolkit/core/common/quantization/debug_config.py,sha256=HtkMmneN-EmAzgZK4Vp4M8Sqm5QKdrvNyyZMpaVqYzY,1482
|
|
102
102
|
model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py,sha256=fwF4VILaX-u3ZaFd81xjbJuhg8Ef-JX_KfMXW0TPV-I,7136
|
|
103
|
-
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=
|
|
104
|
-
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=
|
|
103
|
+
model_compression_toolkit/core/common/quantization/node_quantization_config.py,sha256=TCgpvtfyzFUedv4sZ6sKzsTyikaVl2ixLj_aHPSC2r0,27014
|
|
104
|
+
model_compression_toolkit/core/common/quantization/quantization_config.py,sha256=BieZDv9oc-Mc78S_LRMGo-s_2acbqiLE0ewaSE1v2VY,6818
|
|
105
105
|
model_compression_toolkit/core/common/quantization/quantization_fn_selection.py,sha256=T1nVWdRJfBQ_iuMQYQSIkjfkR-2n3lAOKGAz_rUZZN0,2190
|
|
106
106
|
model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py,sha256=MwIOBZ4BlZSTIOG75PDvlI3JmZ6t8YjPc1VP9Adei60,3847
|
|
107
107
|
model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha256=N005MSvx8UypVpa7XrxNrB2G732n2wHj3RmLyjTgd3I,2728
|
|
@@ -113,7 +113,7 @@ model_compression_toolkit/core/common/quantization/quantization_params_generatio
|
|
|
113
113
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
|
|
114
114
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=BiwDqt5CeU6CW0Qusy3LwWhFtf2J9BvSuGMsTsG6rSw,8538
|
|
115
115
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=noEdvGiyyW7acgQ2OFWLedCODibTGYJifC9qo8YIU5U,4558
|
|
116
|
-
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=
|
|
116
|
+
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=H2D9rdChIviL_j0mF6zy8Qeu_ZXKRu-hLqckSAT1MR8,4352
|
|
117
117
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=7kt0JB8PQE0SW9kg8fCwZ5mBkHNgiRrn0of4ZQYQN2A,41524
|
|
118
118
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=nug6XgsywxYf57XF_Tnt2xwdf0zLLsajiZKEblo4lFc,3882
|
|
119
119
|
model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=QtSAtdAb7sTgtoe9L6DnMFO7rjkOtpzE9kD9xmG7eYM,9743
|
|
@@ -148,7 +148,7 @@ model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7V
|
|
|
148
148
|
model_compression_toolkit/core/keras/constants.py,sha256=Uv3c0UdW55pIVQNW_1HQlgl-dHXREkltOLyzp8G1mTQ,3163
|
|
149
149
|
model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
|
|
150
150
|
model_compression_toolkit/core/keras/default_framework_info.py,sha256=Ha4HTHuiw_KTS5Po1Xnv6GyK9eprpDhYWf-eooS62Ys,4961
|
|
151
|
-
model_compression_toolkit/core/keras/keras_implementation.py,sha256=
|
|
151
|
+
model_compression_toolkit/core/keras/keras_implementation.py,sha256=RS2UEtZ_anZeDxz7Zv6sNv7v9tFVct6d9KVrUlxTGpo,29309
|
|
152
152
|
model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
|
|
153
153
|
model_compression_toolkit/core/keras/keras_node_prior_info.py,sha256=Aqh31wOPaiZcJIOm-uJwzev0eTMdJyXaOk97rs4z7BU,3879
|
|
154
154
|
model_compression_toolkit/core/keras/resource_utilization_data_facade.py,sha256=Xmk2ZL5CaYdb7iG62HdtZ1F64vap7ffnrsuR3e3G5hc,4851
|
|
@@ -166,6 +166,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/activatio
|
|
|
166
166
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py,sha256=9YCNPiK5BD7tLs1meabPhzfb2VsyPxrZM17zMFsW_Fo,8158
|
|
167
167
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=GR1a3mCZpNUu4WxixJXF_aSm57phAdxaRoHecNx3hxw,3168
|
|
168
168
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
|
|
169
|
+
model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
|
|
169
170
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
|
|
170
171
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=V6hp67CkS_A3WqdsjLjs0ETtdZAOo4P9mhy4aT7W5FE,5940
|
|
171
172
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=dyhqZrxSTclXyarT2JYnI5WPX0OvWR_CQiwddIr632U,8143
|
|
@@ -209,7 +210,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
|
|
|
209
210
|
model_compression_toolkit/core/pytorch/constants.py,sha256=NI-J7REuxn06oEIHsmJ4GqtNC3TbV8xlkJjt5Ar-c4U,2626
|
|
210
211
|
model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=r1XyzUFvrjGcJHQM5ETLsMZIG2yHCr9HMjqf0ti9inw,4175
|
|
211
212
|
model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=IoMvTch5awAEPvB6Tg6ANhFGXvfSgv7JLsUBlxpMwk4,4330
|
|
212
|
-
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
|
|
213
|
+
model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=mT4jd8E1saCpAgrsClufQbnVJ0eYn1xaTQ3teALu4jk,27117
|
|
213
214
|
model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=n_B4a6FMwM9D2w8kzy3oenBWZgXNZuIZgTJC6JEuTy0,3250
|
|
214
215
|
model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=E6ifk1HdO60k4IRH2EFBzAYWtwUlrGqJoQ66nknpHoQ,4983
|
|
215
216
|
model_compression_toolkit/core/pytorch/utils.py,sha256=dRPiteBg2dBNsHwZyYzXiCIAjnelSoeZZsDXlsTw5JQ,2880
|
|
@@ -228,6 +229,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/__init_
|
|
|
228
229
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_folding.py,sha256=j3q5DzbH3ys5MPFfSOVnAXdD7-g4XEKj2ADrdihVr30,8292
|
|
229
230
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=B7aC2TZNrQJ2oQVGBFhKAVqdUU5lYVJSMmwKhjxOHWk,2822
|
|
230
231
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=JDWOaNwYrZG0zTwd3HwoZUM3tKu7zPbzLOrqNQsu8xA,2162
|
|
232
|
+
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py,sha256=SBrR24ZAnWPftLinv4FuIqdBGjfYtfXbYQJN5mgy5V4,2861
|
|
231
233
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py,sha256=dYGyb5ebnoeFBF0EaHPQU7CkXvoARdznEEe0laM47LA,3919
|
|
232
234
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_batch_norm.py,sha256=iX8bLHtw2osP42-peNLTRmbpX3cUxdGsAbEfw7NLpx0,3935
|
|
233
235
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/functional_layer_norm.py,sha256=zKSgtVw_P9fUvdq4e7P9yaLDPG_vZ0cecM9sVPtm1ns,3799
|
|
@@ -318,12 +320,12 @@ model_compression_toolkit/exporter/model_wrapper/fw_agnostic/get_inferable_quant
|
|
|
318
320
|
model_compression_toolkit/exporter/model_wrapper/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
319
321
|
model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py,sha256=YffgbVYJG5LKeIsW84Pi7NqzQcvJMeQRnAKQCCmIL6c,3776
|
|
320
322
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
321
|
-
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=
|
|
323
|
+
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=k3UrGAw6vKTmZ-oO1lv0VqK3IpAiet9jlIHyEIoL2u0,5132
|
|
322
324
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py,sha256=uL6tJWC4s2IWUy8GJVwtMWpwZZioRRztfKyPJHo14xI,9442
|
|
323
325
|
model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
|
|
324
326
|
model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=uTQcnzvP44CgPO0twsUdiMmTBE_Td6ZdQtz5U0GZuPI,3464
|
|
325
327
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
326
|
-
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=
|
|
328
|
+
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=tbXDDPEeWHRS_5DL8e9tTtG6nJ5UohfkLVjI2EIhQeo,4917
|
|
327
329
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=4sN5z-6BXrTE5Dp2FX_jKO9ty5iZ2r4RM7XvXtDVLSI,9348
|
|
328
330
|
model_compression_toolkit/gptq/__init__.py,sha256=YKg-tMj9D4Yd0xW9VRD5EN1J5JrmlRbNEF2fOSgodqA,1228
|
|
329
331
|
model_compression_toolkit/gptq/runner.py,sha256=MIg-oBtR1nbHkexySdCJD_XfjRoHSknLotmGBMuD5qM,5924
|
|
@@ -336,14 +338,14 @@ model_compression_toolkit/gptq/common/gptq_training.py,sha256=rLA1xlOO-6gWfmc2dL
|
|
|
336
338
|
model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
337
339
|
model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
|
|
338
340
|
model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
|
|
339
|
-
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=
|
|
341
|
+
model_compression_toolkit/gptq/keras/gptq_training.py,sha256=zyVcEQzdnNsrIz32U1pqqoi08hzxRdJ2CumaPFGwbDM,19123
|
|
340
342
|
model_compression_toolkit/gptq/keras/graph_info.py,sha256=5IvgGlJlgOmQYmldjdCBv7tuzAoY0HazatG5Pedrg0Q,4639
|
|
341
343
|
model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=zAkzWpWP9_aobWgMo_BlUm7-4fR5dHvoGx0sDqs2rZg,14299
|
|
342
344
|
model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
|
|
343
345
|
model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
|
|
344
346
|
model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
|
|
345
347
|
model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=FmK5cPwgLAzrDjHTWf_vbRO5s70S7iwpnjnlqEQTuGE,4408
|
|
346
|
-
model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py,sha256=
|
|
348
|
+
model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py,sha256=guf7ygnLsZeWnTDz4yJdE2iTkd1oE0uQAZwKnGV3OAk,1957
|
|
347
349
|
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
348
350
|
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=qUuMKysUpjWYjNbchFuyb_UFwzV1HL7R3Y7o0Z5rf60,4016
|
|
349
351
|
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=BBSDWLmeywjSM5N6oJkMgcuo7zrXTesB4zLwRGG8QB0,12159
|
|
@@ -353,14 +355,14 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
|
|
|
353
355
|
model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
|
|
354
356
|
model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
|
|
355
357
|
model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
|
|
356
|
-
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=
|
|
358
|
+
model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=xkDa62AdIRwv8dEshffALW9Ri66eseEpyUF9taMUKns,16509
|
|
357
359
|
model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=yXJzDd24zfGs2_vfMovxD1WSh1RxXoPxN4GztOf3P5c,3967
|
|
358
360
|
model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=-4USg-tep6EQSArcTxBowhMeAuExrBTNLOWgHFpsIy4,12699
|
|
359
361
|
model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
|
|
360
362
|
model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
|
|
361
363
|
model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
|
|
362
364
|
model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=uT9N_aBj965hvQfKd67fS1B0SXGnOLVcqa3wW4b2iZE,4566
|
|
363
|
-
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256
|
|
365
|
+
model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=mDWZERLwtDzqWeJUwHMVyGdlS8wPLjJ3NvZiKBP6BNA,1959
|
|
364
366
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
|
|
365
367
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=oO7WgsAHMnWoXNm_gTKAAe-Nd79mGL_m677ai-ui424,4132
|
|
366
368
|
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
|
|
@@ -469,8 +471,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
|
|
|
469
471
|
model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
|
|
470
472
|
model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
|
|
471
473
|
model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
|
|
472
|
-
mct_nightly-2.0.0.
|
|
473
|
-
mct_nightly-2.0.0.
|
|
474
|
-
mct_nightly-2.0.0.
|
|
475
|
-
mct_nightly-2.0.0.
|
|
476
|
-
mct_nightly-2.0.0.
|
|
474
|
+
mct_nightly-2.0.0.20240411.406.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
|
|
475
|
+
mct_nightly-2.0.0.20240411.406.dist-info/METADATA,sha256=IbtNTzo6qu2zeJ6yTF4uKQCQlaWuTHvIURKZwP1akx0,18795
|
|
476
|
+
mct_nightly-2.0.0.20240411.406.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
477
|
+
mct_nightly-2.0.0.20240411.406.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
|
|
478
|
+
mct_nightly-2.0.0.20240411.406.dist-info/RECORD,,
|
|
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
|
|
|
27
27
|
from model_compression_toolkit import pruning
|
|
28
28
|
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
|
|
29
29
|
|
|
30
|
-
__version__ = "2.0.0.
|
|
30
|
+
__version__ = "2.0.0.20240411.000406"
|
|
@@ -106,6 +106,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
106
106
|
self.z_threshold = qc.z_threshold
|
|
107
107
|
self.shift_negative_ratio = qc.shift_negative_ratio
|
|
108
108
|
self.shift_negative_threshold_recalculation = qc.shift_negative_threshold_recalculation
|
|
109
|
+
self.concat_threshold_update = qc.concat_threshold_update
|
|
109
110
|
|
|
110
111
|
def quantize_node_output(self,
|
|
111
112
|
tensors: Any) -> Any:
|
|
@@ -219,7 +220,7 @@ class NodeActivationQuantizationConfig(BaseNodeQuantizationConfig):
|
|
|
219
220
|
self.shift_negative_activation_correction == other.shift_negative_activation_correction and \
|
|
220
221
|
self.z_threshold == other.z_threshold and \
|
|
221
222
|
self.shift_negative_ratio == other.shift_negative_ratio and \
|
|
222
|
-
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
|
223
|
+
self.shift_negative_threshold_recalculation == other.shift_negative_threshold_recalculation
|
|
223
224
|
|
|
224
225
|
def __hash__(self):
|
|
225
226
|
return hash((self.activation_quantization_fn,
|
|
@@ -62,7 +62,8 @@ class QuantizationConfig:
|
|
|
62
62
|
residual_collapsing: bool = True,
|
|
63
63
|
shift_negative_ratio: float = 0.05,
|
|
64
64
|
shift_negative_threshold_recalculation: bool = False,
|
|
65
|
-
shift_negative_params_search: bool = False
|
|
65
|
+
shift_negative_params_search: bool = False,
|
|
66
|
+
concat_threshold_update: bool = False):
|
|
66
67
|
"""
|
|
67
68
|
Class to wrap all different parameters the library quantize the input model according to.
|
|
68
69
|
|
|
@@ -117,6 +118,7 @@ class QuantizationConfig:
|
|
|
117
118
|
self.shift_negative_ratio = shift_negative_ratio
|
|
118
119
|
self.shift_negative_threshold_recalculation = shift_negative_threshold_recalculation
|
|
119
120
|
self.shift_negative_params_search = shift_negative_params_search
|
|
121
|
+
self.concat_threshold_update = concat_threshold_update
|
|
120
122
|
|
|
121
123
|
def __repr__(self):
|
|
122
124
|
return str(self.__dict__)
|
|
@@ -42,14 +42,14 @@ def calculate_quantization_params(graph: Graph,
|
|
|
42
42
|
|
|
43
43
|
"""
|
|
44
44
|
|
|
45
|
-
Logger.info(f"
|
|
45
|
+
Logger.info(f"\nRunning quantization parameters search. "
|
|
46
46
|
f"This process might take some time, "
|
|
47
47
|
f"depending on the model size and the selected quantization methods.\n")
|
|
48
48
|
|
|
49
49
|
# Create a list of nodes to compute their thresholds
|
|
50
50
|
nodes_list: List[BaseNode] = nodes if specific_nodes else graph.nodes()
|
|
51
51
|
|
|
52
|
-
for n in tqdm(nodes_list, "Calculating quantization
|
|
52
|
+
for n in tqdm(nodes_list, "Calculating quantization parameters"): # iterate only nodes that we should compute their thresholds
|
|
53
53
|
for candidate_qc in n.candidates_quantization_cfg:
|
|
54
54
|
for attr in n.get_node_weights_attributes():
|
|
55
55
|
if n.is_weights_quantization_enabled(attr):
|
model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from tensorflow.keras.layers import Concatenate
|
|
18
|
+
import tensorflow as tf
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core import common
|
|
21
|
+
from model_compression_toolkit.core.common import Graph, BaseNode
|
|
22
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
23
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ConcatThresholdUpdate(common.BaseSubstitution):
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
"""
|
|
31
|
+
Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
def __init__(self):
|
|
35
|
+
"""
|
|
36
|
+
Initialize a threshold_updater object.
|
|
37
|
+
"""
|
|
38
|
+
concatination_node = NodeOperationMatcher(Concatenate) | \
|
|
39
|
+
NodeOperationMatcher(tf.concat)
|
|
40
|
+
super().__init__(matcher_instance=concatination_node)
|
|
41
|
+
|
|
42
|
+
def substitute(self,
|
|
43
|
+
graph: Graph,
|
|
44
|
+
node: BaseNode) -> Graph:
|
|
45
|
+
"""
|
|
46
|
+
Update previous layers thresholds to match concatinations quantization thresholds. No change if
|
|
47
|
+
previous layer outputs to multiple layers. No change in case of uniform quantization.
|
|
48
|
+
No change in case of multiple quantization candidates (mixed precision).
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
graph: Graph we apply the substitution on.
|
|
53
|
+
node: Node refference to edit previous nodes thresholds.
|
|
54
|
+
|
|
55
|
+
Returns:
|
|
56
|
+
Graph after applying the substitution.
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
if len(node.candidates_quantization_cfg) == 1 and THRESHOLD in node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params:
|
|
60
|
+
concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
|
|
61
|
+
prev_nodes = graph.get_prev_nodes(node)
|
|
62
|
+
for prev_node in prev_nodes:
|
|
63
|
+
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != Concatenate and prev_node.type != tf.concat:
|
|
64
|
+
prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold
|
|
65
|
+
|
|
66
|
+
return graph
|
|
@@ -80,7 +80,8 @@ from model_compression_toolkit.core.keras.graph_substitutions.substitutions.line
|
|
|
80
80
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.residual_collapsing import \
|
|
81
81
|
keras_residual_collapsing
|
|
82
82
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.input_scaling import InputScaling, \
|
|
83
|
-
InputScalingWithPad
|
|
83
|
+
InputScalingWithPad
|
|
84
|
+
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.concat_threshold_update import ConcatThresholdUpdate
|
|
84
85
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
|
|
85
86
|
ReLUBoundToPowerOfTwo
|
|
86
87
|
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.multi_head_attention_decomposition import \
|
|
@@ -300,8 +301,8 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
300
301
|
"""
|
|
301
302
|
return keras_op2d_add_const_collapsing()
|
|
302
303
|
|
|
303
|
-
def get_substitutions_post_statistics_collection(self,
|
|
304
|
-
|
|
304
|
+
def get_substitutions_post_statistics_collection(self,
|
|
305
|
+
quant_config: QuantizationConfig) -> List[common.BaseSubstitution]:
|
|
305
306
|
"""
|
|
306
307
|
Return a list of the framework substitutions used after we collect statistics.
|
|
307
308
|
|
|
@@ -317,6 +318,8 @@ class KerasImplementation(FrameworkImplementation):
|
|
|
317
318
|
if quant_config.input_scaling:
|
|
318
319
|
substitutions_list.append(InputScaling())
|
|
319
320
|
substitutions_list.append(InputScalingWithPad())
|
|
321
|
+
if quant_config.concat_threshold_update:
|
|
322
|
+
substitutions_list.append(ConcatThresholdUpdate())
|
|
320
323
|
return substitutions_list
|
|
321
324
|
|
|
322
325
|
|
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/concat_threshold_update.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import List
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
|
|
20
|
+
from model_compression_toolkit.core import common
|
|
21
|
+
from model_compression_toolkit.core.common.graph.base_graph import Graph
|
|
22
|
+
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
|
|
23
|
+
from model_compression_toolkit.core.common.graph.base_node import BaseNode
|
|
24
|
+
from model_compression_toolkit.constants import THRESHOLD
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ConcatThresholdUpdate(common.BaseSubstitution):
|
|
28
|
+
"""
|
|
29
|
+
Find concat layers and match their prior layers thresholds unless prior layer outputs to multiple layers.
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def __init__(self):
|
|
34
|
+
"""
|
|
35
|
+
Initialize a threshold_updater object.
|
|
36
|
+
"""
|
|
37
|
+
concatination_node = NodeOperationMatcher(torch.cat) | \
|
|
38
|
+
NodeOperationMatcher(torch.concat)
|
|
39
|
+
super().__init__(matcher_instance=concatination_node)
|
|
40
|
+
|
|
41
|
+
def substitute(self,
|
|
42
|
+
graph: Graph,
|
|
43
|
+
node: BaseNode) -> Graph:
|
|
44
|
+
"""
|
|
45
|
+
Update previous layers thresholds to match concatinations quantization thresholds. No change if
|
|
46
|
+
previous layer outputs to multiple layers. No change in case of uniform quantization.
|
|
47
|
+
No change in case of multiple quantization candidates (mixed precision).
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
graph: Graph we apply the substitution on.
|
|
52
|
+
node: Node refference to edit previous nodes thresholds.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
Graph after applying the substitution.
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
if len(node.candidates_quantization_cfg) == 1 and THRESHOLD in node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params:
|
|
59
|
+
concat_threshold = node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD]
|
|
60
|
+
prev_nodes = graph.get_prev_nodes(node)
|
|
61
|
+
for prev_node in prev_nodes:
|
|
62
|
+
if len(graph.get_next_nodes(prev_node))==1 and prev_node.type != torch.cat and prev_node.type != torch.concat:
|
|
63
|
+
prev_node.candidates_quantization_cfg[0].activation_quantization_cfg.activation_quantization_params[THRESHOLD] = concat_threshold
|
|
64
|
+
|
|
65
|
+
return graph
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
|
|
@@ -73,6 +73,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.vi
|
|
|
73
73
|
VirtualActivationWeightsComposition
|
|
74
74
|
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.weights_activation_split import \
|
|
75
75
|
WeightsActivationSplit
|
|
76
|
+
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.concat_threshold_update import \
|
|
77
|
+
ConcatThresholdUpdate
|
|
76
78
|
from model_compression_toolkit.core.pytorch.hessian.activation_trace_hessian_calculator_pytorch import \
|
|
77
79
|
ActivationTraceHessianCalculatorPytorch
|
|
78
80
|
from model_compression_toolkit.core.pytorch.hessian.weights_trace_hessian_calculator_pytorch import \
|
|
@@ -302,6 +304,8 @@ class PytorchImplementation(FrameworkImplementation):
|
|
|
302
304
|
substitutions_list.append(pytorch_softmax_shift())
|
|
303
305
|
if quant_config.input_scaling:
|
|
304
306
|
Logger.critical('Input scaling is currently not supported for Pytorch.')
|
|
307
|
+
if quant_config.concat_threshold_update:
|
|
308
|
+
substitutions_list.append(ConcatThresholdUpdate())
|
|
305
309
|
return substitutions_list
|
|
306
310
|
|
|
307
311
|
|
model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py
CHANGED
|
@@ -90,7 +90,7 @@ if FOUND_TF:
|
|
|
90
90
|
fw_impl=C.keras.keras_implementation.KerasImplementation())).build_model()
|
|
91
91
|
exportable_model.trainable = False
|
|
92
92
|
|
|
93
|
-
Logger.info("
|
|
93
|
+
Logger.info("\nPlease run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
|
|
94
94
|
"Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
|
|
95
95
|
"FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md\n"
|
|
96
96
|
"Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md")
|
model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py
CHANGED
|
@@ -82,7 +82,7 @@ if FOUND_TORCH:
|
|
|
82
82
|
get_activation_quantizer_holder(n,
|
|
83
83
|
fw_impl=C.pytorch.pytorch_implementation.PytorchImplementation())).build_model()
|
|
84
84
|
|
|
85
|
-
Logger.info("
|
|
85
|
+
Logger.info("\nPlease run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
|
|
86
86
|
"Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
|
|
87
87
|
"FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md\n"
|
|
88
88
|
"Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md")
|
|
@@ -301,21 +301,23 @@ class KerasGPTQTrainer(GPTQTrainer):
|
|
|
301
301
|
Returns: None
|
|
302
302
|
|
|
303
303
|
"""
|
|
304
|
-
|
|
305
|
-
for
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
304
|
+
with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
|
|
305
|
+
for _ in epochs_pbar:
|
|
306
|
+
with tqdm(data_function(), position=1, leave=False) as data_pbar:
|
|
307
|
+
for data in data_pbar:
|
|
308
|
+
input_data = [d * self.input_scale for d in data]
|
|
309
|
+
|
|
310
|
+
loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients,
|
|
311
|
+
in_optimizer_with_param, is_training)
|
|
312
|
+
# Run one step of gradient descent by updating
|
|
313
|
+
# the value of the variables to minimize the loss.
|
|
314
|
+
for i, (o, p) in enumerate(in_optimizer_with_param):
|
|
315
|
+
o.apply_gradients(zip(grads[i], p))
|
|
316
|
+
if self.gptq_config.log_function is not None:
|
|
317
|
+
self.gptq_config.log_function(loss_value_step, grads[0], in_optimizer_with_param[0][-1],
|
|
318
|
+
self.compare_points)
|
|
319
|
+
self.loss_list.append(loss_value_step.numpy())
|
|
320
|
+
Logger.debug(f'last loss value: {self.loss_list[-1]}')
|
|
319
321
|
|
|
320
322
|
def update_graph(self):
|
|
321
323
|
"""
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
from tqdm import tqdm
|
|
15
16
|
from typing import Callable
|
|
16
17
|
|
|
17
18
|
from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
|
|
@@ -35,7 +36,7 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
|
|
|
35
36
|
if gptq_config.rounding_type == RoundingType.SoftQuantizer:
|
|
36
37
|
# dry run on the representative dataset to count number of batches
|
|
37
38
|
num_batches = 0
|
|
38
|
-
for _ in representative_data_gen():
|
|
39
|
+
for _ in tqdm(representative_data_gen(), "GPTQ initialization"):
|
|
39
40
|
num_batches += 1
|
|
40
41
|
|
|
41
42
|
return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
|
|
@@ -248,22 +248,24 @@ class PytorchGPTQTrainer(GPTQTrainer):
|
|
|
248
248
|
data_function: A callable function that give a batch of samples.
|
|
249
249
|
n_epochs: Number of update iterations of representative dataset.
|
|
250
250
|
"""
|
|
251
|
-
|
|
252
|
-
for
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
251
|
+
with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
|
|
252
|
+
for _ in epochs_pbar:
|
|
253
|
+
with tqdm(data_function(), position=1, leave=False) as data_pbar:
|
|
254
|
+
for data in data_pbar:
|
|
255
|
+
input_data = [d * self.input_scale for d in data]
|
|
256
|
+
input_tensor = to_torch_tensor(input_data)
|
|
257
|
+
y_float = self.float_model(input_tensor) # running float model
|
|
258
|
+
loss_value, grads = self.compute_gradients(y_float, input_tensor)
|
|
259
|
+
# Run one step of gradient descent by updating the value of the variables to minimize the loss.
|
|
260
|
+
for (optimizer, _) in self.optimizer_with_param:
|
|
261
|
+
optimizer.step()
|
|
262
|
+
optimizer.zero_grad()
|
|
263
|
+
if self.gptq_config.log_function is not None:
|
|
264
|
+
self.gptq_config.log_function(loss_value.item(),
|
|
265
|
+
torch_tensor_to_numpy(grads),
|
|
266
|
+
torch_tensor_to_numpy(self.optimizer_with_param[0][-1]))
|
|
267
|
+
self.loss_list.append(loss_value.item())
|
|
268
|
+
Logger.debug(f'last loss value: {self.loss_list[-1]}')
|
|
267
269
|
|
|
268
270
|
def update_graph(self) -> Graph:
|
|
269
271
|
"""
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ==============================================================================
|
|
15
|
+
from tqdm import tqdm
|
|
15
16
|
from typing import Callable
|
|
16
17
|
|
|
17
18
|
from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
|
|
@@ -35,7 +36,7 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
|
|
|
35
36
|
if gptq_config.rounding_type == RoundingType.SoftQuantizer:
|
|
36
37
|
# dry run on the representative dataset to count number of batches
|
|
37
38
|
num_batches = 0
|
|
38
|
-
for _ in representative_data_gen():
|
|
39
|
+
for _ in tqdm(representative_data_gen(), "GPTQ initialization"):
|
|
39
40
|
num_batches += 1
|
|
40
41
|
|
|
41
42
|
return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
|
{mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240411.406.dist-info}/LICENSE.md
RENAMED
|
File without changes
|
|
File without changes
|
{mct_nightly-2.0.0.20240409.404.dist-info → mct_nightly-2.0.0.20240411.406.dist-info}/top_level.txt
RENAMED
|
File without changes
|