mediapipe-nightly 0.10.10.post20240216__cp310-cp310-macosx_11_0_universal2.whl → 0.10.10.post20240220__cp310-cp310-macosx_11_0_universal2.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (22) hide show
  1. mediapipe/__init__.py +1 -1
  2. mediapipe/python/_framework_bindings.cpython-310-darwin.so +0 -0
  3. mediapipe/tasks/python/__init__.py +1 -0
  4. mediapipe/tasks/python/genai/__init__.py +14 -0
  5. mediapipe/tasks/python/genai/converter/__init__.py +24 -0
  6. mediapipe/tasks/python/genai/converter/converter_base.py +172 -0
  7. mediapipe/tasks/python/genai/converter/converter_factory.py +79 -0
  8. mediapipe/tasks/python/genai/converter/llm_converter.py +213 -0
  9. mediapipe/tasks/python/genai/converter/pytorch_converter.py +315 -0
  10. mediapipe/tasks/python/genai/converter/pytorch_converter_test.py +86 -0
  11. mediapipe/tasks/python/genai/converter/quantization_util.py +516 -0
  12. mediapipe/tasks/python/genai/converter/quantization_util_test.py +259 -0
  13. mediapipe/tasks/python/genai/converter/safetensors_converter.py +521 -0
  14. mediapipe/tasks/python/genai/converter/safetensors_converter_test.py +83 -0
  15. mediapipe/tasks/python/genai/converter/weight_bins_writer.py +111 -0
  16. mediapipe/tasks/python/genai/converter/weight_bins_writer_test.py +62 -0
  17. mediapipe/version.txt +1 -1
  18. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/METADATA +1 -1
  19. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/RECORD +21 -8
  20. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/LICENSE +0 -0
  21. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/WHEEL +0 -0
  22. {mediapipe_nightly-0.10.10.post20240216.dist-info → mediapipe_nightly-0.10.10.post20240220.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,516 @@
1
+ # Copyright 2024 The MediaPipe Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for quantizing tensors.
16
+
17
+ Note that this is a reduced fork version of the praxis libraries to provide a
18
+ self-contained library for packaging.
19
+ """
20
+
21
+ from typing import Any, List, Optional, Sequence, Tuple, Union
22
+
23
+ import jax
24
+ from jax import lax
25
+ from jax import numpy as jnp
26
+ import numpy as np
27
+
28
+
29
+ JTensor = jax.Array
30
+ _UINT4_ZP = 8 # Default zero point for unsigned 4-bit.
31
+
32
+
33
def _get_scan_range() -> np.ndarray:
  """Returns the candidate shrink factors tried when searching for a bound.

  The factors run from 1.0 down to 0.5 (both inclusive) in 11 evenly spaced
  steps.
  """
  return np.linspace(start=1.0, stop=0.5, num=11)
36
+
37
+
38
def _get_mean_error(bound, t, min_value, max_value, p_value):
  """Returns the p-mean error of fake-quantizing `t` with the given bound.

  `t` is divided by `bound / max_value`, rounded, clipped to
  [min_value, max_value] and scaled back; the result is the mean of
  `|restored - t| ** p_value`.
  """
  step = bound / max_value
  quantized = jnp.clip(jnp.round(jnp.divide(t, step)), min_value, max_value)
  restored = jnp.multiply(quantized, step)
  return jnp.mean(jnp.abs(jnp.subtract(restored, t)) ** p_value)
45
+
46
+
47
def _get_best_bound_per_tensor(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
) -> JTensor:
  """Picks one shrink factor for `bound` that minimizes the p-mean error.

  Every factor from `_get_scan_range()` is applied to `bound`, the resulting
  quantization error of the whole tensor 't' is measured with
  `_get_mean_error`, and the factor with the smallest error wins.

  Args:
    t: The input float tensor.
    bound: The hard max value(s) for tensor 't'.
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.

  Returns:
    `bound` multiplied by the single best factor.
  """
  candidates = _get_scan_range()

  def _error_for(factor):
    return _get_mean_error(bound * factor, t, min_value, max_value, p_value)

  # Evaluate all candidate factors in one vectorized pass.
  errors = jax.vmap(_error_for)(candidates)
  winner = candidates[jnp.argmin(errors)]
  return bound * winner.astype(bound.dtype)
80
+
81
+
82
def _quantrow(
    vec: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float,
    factors: np.ndarray,
) -> JTensor:
  """Selects the best rescaling factor for a single channel.

  Args:
    vec: The vector in a channel.
    bound: The hard bound (max(abs(vec))) of the vector.
    min_value: The target min value.
    max_value: The target max value.
    p_value: Exponent of the p-mean error metric.
    factors: The values to be applied on top of bound.

  Returns:
    `bound` multiplied by the factor whose p-mean error is smallest.
  """
  candidate_bounds = bound * factors
  errors = jax.vmap(
      lambda b: _get_mean_error(b, vec, min_value, max_value, p_value)
  )(candidate_bounds)
  return bound * factors[jnp.argmin(errors)]
110
+
111
+
112
def _get_best_bound_per_channel(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
) -> JTensor:
  """Picks a shrink factor for each channel's bound independently.

  Each column of 't' is treated as one channel. For every channel the factors
  from `_get_scan_range()` are tried via `_quantrow` and the bound with the
  smallest p-mean error is kept.

  Args:
    t: The input float tensor; must be 2-D.
    bound: The hard max values for 't'; must have shape [1, t.shape[1]].
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.

  Returns:
    A NumPy array with the same shape as `bound` holding the best bound of
    every channel.
  """
  assert len(t.shape) == 2
  assert len(bound.shape) == 2
  assert t.shape[1] == bound.shape[1]
  assert bound.shape[0] == 1
  factors = _get_scan_range()

  channel_bounds = bound[0, :]
  best = np.zeros(channel_bounds.shape)
  # Rows of the transposed tensor are the individual channels.
  for idx, channel in enumerate(t.transpose()):
    best[idx] = _quantrow(
        channel, channel_bounds[idx], min_value, max_value, p_value, factors
    )
  return best.reshape(bound.shape)
155
+
156
+
157
def get_best_bound(
    t: JTensor,
    bound: JTensor,
    min_value: float,
    max_value: float,
    p_value: float = 1.0,
    per_channel: bool = False,
) -> JTensor:
  """Scans multiple factors on the max value to get the best bound value.

  Dispatches to the per-channel or the per-tensor search, both of which look
  for the bound value(s) that minimize the p-mean error between the original
  tensor 't' and its quantized version.

  Args:
    t: The input float tensor.
    bound: The hard max value(s) for tensor 't'.
    min_value: Minimal value for the quantization bound.
    max_value: Maximal value for the quantization bound.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    per_channel: Whether to search a bound per channel instead of one for the
      entire tensor.

  Returns:
    The best bound values for 't', that minimize p-mean error.
  """
  search = (
      _get_best_bound_per_channel if per_channel else _get_best_bound_per_tensor
  )
  return search(t, bound, min_value, max_value, p_value)
186
+
187
+
188
def get_min_max(
    bits: int = 8,
    unsigned: bool = False,
    use_fp: bool = False,
) -> Tuple[float, float]:
  """Returns the representable (min, max) pair for a quantized number format.

  Args:
    bits: Target number of bits for quantization.
    unsigned: If True compute min and max for unsigned number, else for signed.
    use_fp: If True return the fixed floating point range and ignore `bits`
      and `unsigned`.

  Returns:
    The (min, max) values for the requested format.
  """
  if use_fp:
    # TODO: support other fp types.
    return -448.0, 448.0
  # The range is computed (rather than read from jax.iinfo) so that bit widths
  # other than 4 and 8 are supported.
  if unsigned:
    # E.g. [0, 255] for 8 bits.
    return 0, 2**bits - 1
  # E.g. [-128, 127] for 8 bits.
  half_range = 2 ** (bits - 1)
  return -1 * half_range, half_range - 1
213
+
214
+
215
def pass_through(x: JTensor, fn: Any) -> JTensor:
  """Applies `fn` to `x` in the forward pass with an identity gradient.

  The value equals `fn(x)` because `x - stop_gradient(x)` is exactly zero
  (Sterbenz lemma), while the only differentiable term is `x` itself, so the
  gradient with respect to `x` is exactly one.
  """
  frozen_input = jax.lax.stop_gradient(x)
  frozen_output = jax.lax.stop_gradient(fn(x))
  return x - frozen_input + frozen_output
219
+
220
+
221
def reduce_precision(
    t: JTensor,
    contract_dims: Optional[Sequence[int]],
    need_gradient: bool = False,
    bits: int = 8,
    optimization_on_bound: bool = False,
    p_value: float = 1.0,
    percentile: float = 1.0,
    use_symmetric: bool = True,
    use_fp: bool = False,
    add_scale_eps: bool = False,
    per_channel: bool = False,
    random_rounding: bool = False,
    key: Optional[jax.Array] = None,
) -> Tuple[JTensor, JTensor, Optional[JTensor]]:
  """Reduces the precision of a tensor.

  Generic for all tensors.

  Args:
    t: Input tensor.
    contract_dims: Specifies contracting dimensions of the input tensor; the
      value range is reduced over these dimensions (kept as size-1 dims).
    need_gradient: If gradient is needed out of this function. When True the
      result is rounded with a straight-through estimator and is not cast to
      an integer container type.
    bits: Target number of bits.
    optimization_on_bound: If the p-mean bound optimizer is used.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    percentile: Percentile factor to apply on the min/max range. Setting this
      to less than 1.0 disables optimization_on_bound.
    use_symmetric: If the input tensor is quantized symmetrically.
    use_fp: Use floating point (float8_e4m3fn bit-cast into int8).
    add_scale_eps: Add eps value or replace zero value by 1 to avoid division
      by zero.
    per_channel: Use per-channel clipping optimization.
    random_rounding: Round with uniform random noise. Only applied when
      neither `use_fp` nor `need_gradient` is set.
    key: RNG key for rounding; used only when random rounding is applied.

  Returns:
    A tuple of quantized tensor, quantization scale
    and quantization zero point (None for symmetric quantization).
  """
  qmin, qmax = get_min_max(bits, use_fp=use_fp)

  # Measure the value range of `t` and the number of levels it maps onto.
  lowest = None
  if use_symmetric:
    span = jnp.max(jnp.abs(t), axis=contract_dims, keepdims=True)
    levels = qmax
  else:
    highest = jnp.max(t, axis=contract_dims, keepdims=True)
    lowest = jnp.min(t, axis=contract_dims, keepdims=True)
    span = highest - lowest
    levels = qmax - qmin

  # An explicit percentile takes precedence over the bound search.
  if percentile < 1.0:
    span = jnp.multiply(span, percentile)
  elif optimization_on_bound:
    span = get_best_bound(
        t, span, qmin, qmax, p_value, per_channel=per_channel
    )

  scale = span / levels
  if add_scale_eps:
    # Add epsilon to avoid divide-by-zero.
    scale = scale + jnp.finfo(t.dtype).eps
  else:
    scale = jnp.where(scale == 0.0, 1.0, scale)

  scaled = jnp.divide(t, scale)
  if use_symmetric:
    zp = None
  else:
    offset = qmin - lowest / scale
    scaled = scaled + offset
    # The zero point is reported in the float domain of the input.
    zp = jnp.multiply(scale, offset)

  if use_fp:
    # No need to round.
    as_fp8 = jnp.clip(scaled, qmin, qmax).astype(jnp.float8_e4m3fn)
    # TODO: refactor to remove this logic.
    return jax.lax.bitcast_convert_type(as_fp8, new_dtype=jnp.int8), scale, zp

  if need_gradient:
    rounded = pass_through(scaled, jnp.round)
    return jnp.clip(rounded, qmin, qmax), scale, zp

  if random_rounding:
    scaled = scaled + jax.random.uniform(
        key=key, shape=scaled.shape, minval=-0.5, maxval=0.5
    )
  # Smallest signed integer container that can hold `bits` bits.
  if bits <= 8:
    container_dtype = jnp.int8
  elif bits <= 16:
    container_dtype = jnp.int16
  else:
    container_dtype = jnp.int32
  quantized = jnp.clip(jnp.round(scaled), qmin, qmax).astype(container_dtype)
  return quantized, scale, zp
316
+
317
+
318
def quantize_tensor(
    var: np.ndarray,
    axis: List[int],
    factor: float = 1.0,
    sym: bool = True,
    number_bits: int = 8,
    use_fp: bool = False,
    add_scale_eps: bool = False,
    optimization_on_bound: bool = False,
    p_value: float = 1.0,
    per_channel: bool = False,
    block_size: int = 0,
) -> Union[
    Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray, np.ndarray]
]:
  """Quantize a tensor.

  Args:
    var: The variable to be quantized.
    axis: The axis along which variable will be quantized. The list is never
      modified by this function.
    factor: The clipping factor.
    sym: Symmetric or asymmetric quantize the variable.
    number_bits: Number of bits for quantized value. Must be 4 or 8.
    use_fp: do fp with number of bits (i.e. fp8)
    add_scale_eps: add epsilon to scale to avoid division by zero, else it will
      replace zero scale by 1.
    optimization_on_bound: If p-mean bound optimizer is used.
    p_value: Exponent of the p-mean error metric. Default to 1.0 which is MAE.
    per_channel: use per-channel clipping optimization.
    block_size: block size for sub-channel quantization. Defaults to 0, which
      means off. When positive, `axis` must hold exactly one entry whose
      dimension is divisible by `block_size`; that dimension is split into
      [sub_channels, block_size] and quantization runs over the block dim, so
      the returned quantized tensor has one more dimension than `var`.

  Returns:
    (quantized, scale) for symmetric quantization, otherwise
    (quantized, scale, zero_point).
  """
  # TODO: support jnp.float8_e5m2
  assert number_bits == 8 or number_bits == 4
  jnp_var = jnp.asarray(var)
  # Work on a private copy so the caller's `axis` list is left untouched.
  contract_dims = list(axis)
  # When using sub-channel, the contracting dim is split into a sub-channel
  # dim followed by the block dim. Therefore the contracting dim
  # (quantize_axis) should increment by one, and the corresponding pack_dim
  # should also increment by one.
  if block_size > 0:
    shape = list(jnp_var.shape)
    assert len(contract_dims) == 1, 'Only support 1D sub-channel quantization'
    quantize_axis = contract_dims[0]
    # Resolve a negative axis first: after the split the tensor has one more
    # dimension, so a negative index would otherwise point at the wrong dim.
    if quantize_axis < 0:
      quantize_axis += len(shape)
    sub_channels, rem = divmod(shape[quantize_axis], block_size)
    assert rem == 0
    shape[quantize_axis : quantize_axis + 1] = [sub_channels, block_size]
    contract_dims = [quantize_axis + 1]
    jnp_var = jnp.reshape(jnp_var, shape)

  qvar, scale, zp = reduce_precision(
      jnp_var,
      contract_dims=contract_dims,
      need_gradient=False,
      bits=number_bits,
      optimization_on_bound=optimization_on_bound,
      percentile=factor,
      use_symmetric=sym,
      use_fp=use_fp,
      add_scale_eps=add_scale_eps,
      p_value=p_value,
      per_channel=per_channel,
  )
  if sym:
    return np.array(qvar), np.array(
        jnp.squeeze(scale, axis=tuple(contract_dims))
    )
  else:
    return (
        np.array(qvar),
        # CAVEAT: the following squeezes should squeeze along the quantization
        # axis only; as written they drop every size-1 dimension.
        np.array(jnp.squeeze(scale)),
        np.array(jnp.squeeze(zp)),
    )
393
+
394
+
395
def pack_4bit(
    x: np.ndarray, pack_dim: int, packed_dtype: jnp.dtype = jnp.int32
) -> np.ndarray:
  """Packs 4-bit values stored as int8/uint8 into int32 or int8 nibbles.

  Consecutive elements along `pack_dim` are combined into one packed element,
  with the first element occupying the least significant nibble.

  Args:
    x: Original int8 or uint8 tensor to pack. Only the low 4 bits of every
      element are kept.
    pack_dim: Dimension to pack along. x.shape[pack_dim] must be divisible by 8
      when packed_dtype is int32 and divisible by 2 when packed_dtype is int8.
      Negative values count from the last dimension. After resolving negative
      values, pack_dim must be < x.ndim - 1.
    packed_dtype: Target type to pack to, int32 or int8.

  Returns:
    int32 or int8 packed tensor where the pack_dim size is divided by 8 (int32)
    or by 2 (int8) compared to the original tensor x.

  Raises:
    ValueError: If the dtypes, `pack_dim` or the shape of `x` do not satisfy
      the constraints above.
  """
  x = jnp.asarray(x)
  if packed_dtype == jnp.int8 and x.dtype == jnp.uint8:
    # It doesn't make sense to pack uint8 numbers into int4, as the range
    # overlap between uint8 and int4 is only [0..7].
    raise ValueError(
        'only int8 input dtype is supported when packing into int8. '
        f'Given {x.dtype}'
    )

  if x.dtype != jnp.int8 and x.dtype != jnp.uint8:
    raise ValueError(
        f'input dtype must be either int8 or uint8. Given {x.dtype}'
    )
  # Resolve negative dims up front: the shape bookkeeping below inserts a new
  # dimension, which would silently shift the meaning of a negative index.
  resolved_dim = pack_dim + x.ndim if pack_dim < 0 else pack_dim
  if resolved_dim < 0 or resolved_dim >= x.ndim - 1:
    raise ValueError(
        f'pack_dim must be < input ndim - 1. input shape {x.shape} and pack_dim'
        f' {pack_dim}'
    )
  if packed_dtype != jnp.int32 and packed_dtype != jnp.int8:
    raise ValueError(
        f'packed_dtype must be either int32 or int8. Given {packed_dtype}'
    )
  if packed_dtype == jnp.int32 and x.shape[resolved_dim] % 8 != 0:
    raise ValueError(
        'input shape[pack_dim] must be divisible by 8 when target_type '
        f'is int32. Given shape {x.shape}'
    )
  if packed_dtype == jnp.int8 and x.shape[resolved_dim] % 2 != 0:
    raise ValueError(
        'input shape[pack_dim] must be divisible by 2 when target_type '
        f'is int8. Given shape {x.shape}'
    )

  int4s_per_packed_type = 8 if packed_dtype == jnp.int32 else 2

  # Split pack_dim into [packed_size, int4s_per_packed_type].
  rep_shape = list(x.shape)
  rep_shape.insert(resolved_dim + 1, int4s_per_packed_type)
  rep_shape[resolved_dim] //= int4s_per_packed_type

  # Bit offset of every nibble inside its packed element: 0, 4, 8, ...
  shifts = lax.broadcasted_iota(packed_dtype, rep_shape, resolved_dim + 1)
  shifts <<= 2

  # Keep the low nibble only; the mask also promotes x to packed_dtype.
  x = x & jnp.array(0x0F, packed_dtype)
  x = lax.reshape(x, rep_shape)
  x = x << shifts
  # The nibbles occupy disjoint bits, so adding them merges them.
  x = lax.reduce(x, jnp.array(0x0, packed_dtype), lax.add, [resolved_dim + 1])
  return np.asarray(x)
459
+
460
+
461
def update_to_uint4(
    qx: np.ndarray, scale: np.ndarray, zp: Optional[np.ndarray] = None
):
  """Updates the quantized weights from int4 to uint4.

  This is a conversion function designed for XNNPack as it expects the 4-bit
  quantized weight to be represented differently from the original Pax setting.
  Specifically, the differences are:
    1) The dynamic range of weight values: int4 (Pax) vs. uint4 (XNNPack).
    2) The dynamic range of zero-point: float (Pax) vs. uint4 (XNNPack).
    3) The number of zero-point: per-channel (Pax) vs. per-tensor (XNNPack).

  Args:
    qx: np.array of shape [..., channel], which is the quantized weight values
      from Pax. The values are in the dynamic range of int4 but are hosted as
      int8 type. Note that if the first dimension is 3, it means the qkv
      matrices are concatenated together and should be treated differently.
    scale: np.array of shape [1(3), channel] as np.float type, which are the
      scaling factors for dequantization per channel.
    zp: (optional) np.array of shape [1 (or 3), channel] as np.float type, which
      are the zero points for dequantization per channel.

  Returns:
    A tuple (qx, scale, zp):
      qx: The updated np.array of shape [..., channel] with updated dynamic
        range as uint4 (with 8 as the default zero points).
      scale: The input scale cast to np.float32.
      zp: np.array of shape [] (or [3] for concatenated qkv) as np.int32 type
        with the zero point values in the dynamic range of uint4. When the
        input `zp` is None, the default zero point is used.

  Raises:
    ValueError: If `qx` is not int8 or `scale` is not a floating point array.
  """
  if qx.dtype != np.int8 or ('float' not in str(scale.dtype)):
    raise ValueError(
        'Unexpected dtype qx:' + str(qx.dtype) + ' scale:' + str(scale.dtype)
    )

  scale = scale.astype(np.float32)

  def get_new_zp(old_zp, old_scale):
    # Convert the float zero points to quantized units, collapse them to a
    # single per-tensor value and shift it into the uint4 range.
    new_zp = old_zp / (old_scale + np.finfo(np.float32).eps)
    per_tensor_zp = np.mean(new_zp)
    return per_tensor_zp.astype(np.int8) + _UINT4_ZP

  is_stacked_qkv = qx.shape[0] == 3
  if zp is None:
    per_tensor_zp = (
        _UINT4_ZP * np.ones(shape=(3)) if is_stacked_qkv else _UINT4_ZP
    )
  elif is_stacked_qkv:
    # Each of the q/k/v zero-point rows must be divided by its own scale row;
    # dividing by the whole scale matrix would mix the three scales into
    # every mean. Fall back to broadcasting when scale has no matching rows.
    rows_match = scale.ndim >= 2 and scale.shape[0] == np.shape(zp)[0]
    per_tensor_zp = np.stack(
        [
            get_new_zp(szp, scale[i] if rows_match else scale)
            for i, szp in enumerate(zp)
        ],
        axis=0,
    )
  else:
    per_tensor_zp = get_new_zp(zp, scale)

  qx = qx + _UINT4_ZP
  return qx, scale, np.array(per_tensor_zp, dtype=np.int32)