mct-nightly 2.0.0.20240413.406__py3-none-any.whl → 2.0.0.20240415.5018__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 2.0.0.20240413.406
+ Version: 2.0.0.20240415.5018
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -1,4 +1,4 @@
- model_compression_toolkit/__init__.py,sha256=aBKAsetxgA8LyN7yib2REk7UPRyfK2jBb-YqBpIpCbE,1573
+ model_compression_toolkit/__init__.py,sha256=pU9oIc4ZlkMr0MR9kraXEOboqcZF2lShgkyhaDHxzn0,1573
  model_compression_toolkit/constants.py,sha256=f9at1H_-vb5nvdHRmAHUco4ja4_QermK6yu0N9qbRGE,3723
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -109,7 +109,7 @@ model_compression_toolkit/core/common/quantization/quantize_graph_weights.py,sha
  model_compression_toolkit/core/common/quantization/quantize_node.py,sha256=cdzGNWfT4MRogIU8ehs0tr3lVjnzAI-jeoS9b4TwVBo,2854
  model_compression_toolkit/core/common/quantization/set_node_quantization_config.py,sha256=9BEv2l0z2trDEsr40VB8tO3ToBA_b2sd_jH9uqZ5Wo8,11503
  model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py,sha256=eCDGwsWYLU6z7qbEVb4TozMW_nd5VEP_iCJ6PcvyEPw,1486
- model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=eH3nSXPFn94ATF3dZn2HxNAGVJUWotirN6o8wwDfkLg,18165
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py,sha256=TUJuSpX8pcsIPbJ6z_YGWgD_uafqlKRJcpsTIFpjMKU,19936
  model_compression_toolkit/core/common/quantization/quantization_params_generation/lut_kmeans_params.py,sha256=HSbAlDKXZMn8BtQQGL8TnlXvO2f_2oTLXAK1khraX7g,7410
  model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
  model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=BiwDqt5CeU6CW0Qusy3LwWhFtf2J9BvSuGMsTsG6rSw,8538
@@ -480,8 +480,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
- mct_nightly-2.0.0.20240413.406.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.0.0.20240413.406.dist-info/METADATA,sha256=KPkoZIsVNAhmDShzs6X5LUBpcV_hvmCO9elwhFuTduw,18795
- mct_nightly-2.0.0.20240413.406.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- mct_nightly-2.0.0.20240413.406.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.0.0.20240413.406.dist-info/RECORD,,
+ mct_nightly-2.0.0.20240415.5018.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.0.0.20240415.5018.dist-info/METADATA,sha256=NaQHHj_S3oEuCRqkeaaOSlRIUc6HthT3C9IjzX6p7FQ,18796
+ mct_nightly-2.0.0.20240415.5018.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ mct_nightly-2.0.0.20240415.5018.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.0.0.20240415.5018.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

- __version__ = "2.0.0.20240413.000406"
+ __version__ = "2.0.0.20240415.005018"
@@ -89,8 +89,8 @@ def _lp_error_histogram(q_bins: np.ndarray,


  def _kl_error_function(x: np.ndarray,
- range_min: float,
- range_max: float,
+ range_min: np.ndarray,
+ range_max: np.ndarray,
  n_bins: int = 2048,
  n_bits: int = 8) -> np.float32:
  """
@@ -148,7 +148,8 @@ def _kl_error_function_wrapper(x: np.ndarray,
  range_min: np.ndarray,
  range_max: np.ndarray,
  n_bins: int = 2048,
- n_bits: int = 8) -> np.ndarray:
+ n_bits: int = 8,
+ per_channel: int = False) -> np.ndarray:
  """
  Computes the error function between a tensor and its quantized version for each channel.
  The error is based on the KL-divergence between the distributions.
@@ -161,6 +162,7 @@ def _kl_error_function_wrapper(x: np.ndarray,
  range_max: Array specifying the maximum bound of the quantization range for each channel.
  n_bins: Number of bins for the float histogram.
  n_bits: Number of bits used for quantization.
+ per_channel: Whether quantization is done per-channel.

  Returns:
  An array containing the KL-divergence between the float and quantized histograms of the tensor for each channel.
@@ -168,8 +170,11 @@ def _kl_error_function_wrapper(x: np.ndarray,
  """

  error_list = []
- for j in range(x.shape[0]): # iterate all channels of the tensor.
- error_list.append(_kl_error_function(x[j], range_min[j], range_max[j], n_bins=n_bins, n_bits=n_bits))
+ if per_channel:
+ for j in range(x.shape[0]): # iterate all channels of the tensor.
+ error_list.append(_kl_error_function(x[j], range_min[j], range_max[j], n_bins=n_bins, n_bits=n_bits))
+ else:
+ error_list.append(_kl_error_function(x, range_min, range_max, n_bins=n_bins, n_bits=n_bits))
  return np.asarray(error_list)


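The hunk above adds a per_channel flag to _kl_error_function_wrapper: when it is set, the KL-based error is computed channel by channel with per-channel range bounds, otherwise a single error is computed over the whole tensor. A minimal, self-contained sketch of that dispatch pattern, using a stand-in error function rather than the toolkit's internal _kl_error_function, could look like:

import numpy as np

def _toy_error(x, range_min, range_max):
    # Stand-in for the KL-based error: a clipped-MSE placeholder, for illustration only.
    x_clipped = np.clip(x, range_min, range_max)
    return np.mean((x - x_clipped) ** 2)

def error_wrapper(x, range_min, range_max, per_channel=False):
    # Mirrors the new branching: iterate channels only when per_channel is True.
    error_list = []
    if per_channel:
        for j in range(x.shape[0]):  # iterate all channels of the tensor
            error_list.append(_toy_error(x[j], range_min[j], range_max[j]))
    else:
        error_list.append(_toy_error(x, range_min, range_max))
    return np.asarray(error_list)

x = np.random.randn(4, 16)
print(error_wrapper(x, x.min(axis=1), x.max(axis=1), per_channel=True))  # shape (4,)
print(error_wrapper(x, x.min(), x.max(), per_channel=False))             # shape (1,)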
@@ -177,8 +182,8 @@ def _kl_error_histogram(q_bins: np.ndarray,
  q_count: np.ndarray,
  bins: np.ndarray,
  counts: np.ndarray,
- range_min: float,
- range_max: float) -> np.float32:
+ range_min: np.ndarray,
+ range_max: np.ndarray) -> np.float32:
  """
  Compute the error function between a histogram to its quantized version.
  The error is computed based on the KL-divergence the distributions have.
@@ -241,8 +246,8 @@ def _kl_error_histogram(q_bins: np.ndarray,


  def _get_bins_indices_from_range(bins: np.ndarray,
- range_min: float,
- range_max: float) -> Tuple[int, int]:
+ range_min: np.ndarray,
+ range_max: np.ndarray) -> Tuple[int, int]:
  """
  For bins and a threshold, compute the first and last bins in between the threshold
  ranges.
@@ -262,7 +267,7 @@ def _get_bins_indices_from_range(bins: np.ndarray,
  return first_bin_idx, last_bin_idx


- def _is_range_valid(bins: np.ndarray, range_min: float, range_max: float) -> bool:
+ def _is_range_valid(bins: np.ndarray, range_min: np.ndarray, range_max: np.ndarray) -> bool:
  """
  Check whether there are some bins from a numpy array of bins that are in between
  a threshold range or not.
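The float → np.ndarray annotation changes above reflect that the range bounds may now arrive as NumPy values rather than Python floats. Purely as an illustration (this diff does not show the body of _is_range_valid, so the check below is an assumption based on its docstring), a range-validity test could be written as:

import numpy as np

def is_range_valid(bins: np.ndarray, range_min, range_max) -> bool:
    # Hypothetical sketch: True if at least one bin edge lies inside [range_min, range_max].
    return bool(np.any((bins >= range_min) & (bins <= range_max)))

bins = np.linspace(-1.0, 1.0, 11)          # edges at -1.0, -0.8, ..., 1.0
print(is_range_valid(bins, 0.05, 0.25))    # True: the 0.2 edge falls inside the range
print(is_range_valid(bins, 1.5, 2.0))      # False: no edges above 1.0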
@@ -387,15 +392,36 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat

  Returns: a Callable method that calculates the error between a tensor and a quantized tensor.
  """
+ if quant_error_method == qc.QuantizationErrorMethod.KL:
+ if axis is None:
+ # per-tensor
+ if quantization_method == QuantizationMethod.UNIFORM:
+ return lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=threshold[0],
+ range_max=threshold[1],
+ n_bits=n_bits,
+ per_channel=False)
+ else:
+ return lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=0 if not signed else -threshold,
+ range_max=threshold,
+ n_bits=n_bits,
+ per_channel=False)
+ else:
+ # per-channel
+ if quantization_method == QuantizationMethod.UNIFORM:
+ return lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=threshold[:, 0],
+ range_max=threshold[:, 1],
+ n_bits=n_bits,
+ per_channel=True)
+ else:
+ return lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=0 if not signed else -threshold,
+ range_max=threshold,
+ n_bits=n_bits,
+ per_channel=True)

  quant_method_error_function_mapping = {
  qc.QuantizationErrorMethod.MSE: lambda x, y, threshold: compute_mse(x, y, norm=norm, axis=axis),
  qc.QuantizationErrorMethod.MAE: lambda x, y, threshold: compute_mae(x, y, norm=norm, axis=axis),
  qc.QuantizationErrorMethod.LP: lambda x, y, threshold: compute_lp_norm(x, y, p=p, norm=norm, axis=axis),
- qc.QuantizationErrorMethod.KL:
- lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=threshold[:,0], range_max=threshold[:,1],
- n_bits=n_bits) if quantization_method == QuantizationMethod.UNIFORM
- else _kl_error_function_wrapper(x, range_min=0 if not signed else -threshold, range_max=threshold, n_bits=n_bits)
  }

  return quant_method_error_function_mapping[quant_error_method]
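The rewritten KL handling above pulls the KL case out of the dictionary and selects the quantization range up front: for UNIFORM quantization the threshold already holds explicit (min, max) bounds (threshold[0]/threshold[1] per-tensor, threshold[:, 0]/threshold[:, 1] per-channel), while for the other methods the range is [0, threshold] or [-threshold, threshold] depending on signedness. A standalone sketch of that range selection (the function name and flags below are illustrative, not the toolkit's API):

import numpy as np

def select_kl_ranges(threshold, uniform: bool, signed: bool, per_channel: bool):
    # Reproduces the range selection logic introduced in the hunk above.
    if uniform:
        # UNIFORM: threshold carries explicit (min, max) bounds.
        if per_channel:
            return threshold[:, 0], threshold[:, 1]
        return threshold[0], threshold[1]
    # Symmetric / power-of-two style: range derived from threshold and signedness.
    return (-threshold if signed else 0), threshold

# Per-channel UNIFORM: one (min, max) row per channel.
t_uniform = np.array([[-0.5, 0.5], [-1.0, 2.0]])
print(select_kl_ranges(t_uniform, uniform=True, signed=True, per_channel=True))

# Per-tensor symmetric, signed: range becomes [-t, t].
t_sym = np.array([1.0])
print(select_kl_ranges(t_sym, uniform=False, signed=True, per_channel=False))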