ai-edge-quantizer-nightly 0.0.1.dev20250211__py3-none-any.whl → 0.0.1.dev20250213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,559 +19,161 @@ from typing import Any, Optional
19
19
  import numpy as np
20
20
  from ai_edge_quantizer import qtyping
21
21
  from ai_edge_quantizer.algorithms.uniform_quantize import uniform_quantize_tensor
22
- from ai_edge_quantizer.algorithms.utils import min_max_quantize_utils as utils
22
+ from ai_edge_quantizer.algorithms.utils import common_utils
23
23
  from ai_edge_quantizer.utils import tfl_flatbuffer_utils
24
24
 
25
25
  ALGORITHM_KEY = "min_max_uniform_quantize"
26
26
  _TFLOpName = qtyping.TFLOperationName
27
27
  _QuantTransformation = qtyping.QuantTransformation
28
- _OpQuantConstraint = utils.OpQuantConstraint
29
- _ComputePrecision = qtyping.ComputePrecision
28
+ _IntType = uniform_quantize_tensor.IntType
30
29
 
31
30
 
32
- def check_op_quantization_config(
33
- op_name: _TFLOpName,
34
- op_quant_config: qtyping.OpQuantizationConfig,
35
- config_check_policy: qtyping.ConfigCheckPolicyDict,
36
- ) -> None:
37
- """Checks the op quantization config.
38
-
39
- Args:
40
- op_name: The name of the op.
41
- op_quant_config: The quantization config for the op.
42
- config_check_policy: The policy to check the op quantization config.
43
-
44
- Raises:
45
- ValueError: If the op quantization config is invalid.
46
- """
47
- if op_quant_config.weight_tensor_config is None:
48
- raise ValueError(
49
- "Weight tensor quantization is required for min/max uniform"
50
- " quantization."
51
- )
52
- if op_quant_config.weight_tensor_config.dtype != qtyping.TensorDataType.INT:
53
- raise ValueError(
54
- "Weights need to have integer type for min/max uniform quantization. If"
55
- " you wish to perform float casting quantization (e.g., fp16 weight"
56
- " only), please set algorithm key as 'float_casting'."
57
- )
58
-
59
- if op_quant_config.min_weight_elements < 0:
60
- raise ValueError(
61
- f"min_weight_elements must be non-negative for op: {op_name} with"
62
- f" config: {op_quant_config}."
63
- )
64
-
65
- if op_quant_config.compute_precision in [
66
- _ComputePrecision.INTEGER,
67
- _ComputePrecision.FLOAT,
68
- ]:
69
- # Use policy-based mechanism to validate op.
70
- utils.check_if_valid_op_config(
71
- op_name, op_quant_config, config_check_policy
72
- )
73
- utils.check_subchannel_config(op_name, op_quant_config)
74
-
75
-
76
- def materialize_input(
77
- op_info: qtyping.OpInfo,
78
- graph_info: qtyping.GraphInfo,
79
- tensor_name_to_qsv: dict[str, Any],
80
- ) -> list[qtyping.TensorTransformationParams]:
81
- """Materialize tensors in the virtual input op."""
82
- return utils.materialize_standard_op(
83
- op_info,
84
- graph_info,
85
- tensor_name_to_qsv,
86
- )
87
-
88
-
89
- def materialize_output(
90
- op_info: qtyping.OpInfo,
91
- graph_info: qtyping.GraphInfo,
92
- tensor_name_to_qsv: dict[str, Any],
93
- ) -> list[qtyping.TensorTransformationParams]:
94
- """Materialize tensors in the virtual output op."""
95
- return utils.materialize_standard_op(
96
- op_info,
97
- graph_info,
98
- tensor_name_to_qsv,
99
- )
100
-
101
-
102
- def materialize_add(
103
- op_info: qtyping.OpInfo,
104
- graph_info: qtyping.GraphInfo,
105
- tensor_name_to_qsv: dict[str, Any],
106
- ) -> list[qtyping.TensorTransformationParams]:
107
- """Materialize tensors in tfl.add."""
108
- return utils.materialize_standard_op(
109
- op_info,
110
- graph_info,
111
- tensor_name_to_qsv,
112
- )
113
-
114
-
115
- def materialize_sub(
116
- op_info: qtyping.OpInfo,
117
- graph_info: qtyping.GraphInfo,
118
- tensor_name_to_qsv: dict[str, Any],
119
- ) -> list[qtyping.TensorTransformationParams]:
120
- """Materialize tensors in tfl.sub."""
121
- return utils.materialize_standard_op(
122
- op_info,
123
- graph_info,
124
- tensor_name_to_qsv,
125
- )
126
-
127
-
128
- def materialize_mul(
129
- op_info: qtyping.OpInfo,
130
- graph_info: qtyping.GraphInfo,
131
- tensor_name_to_qsv: dict[str, Any],
132
- ) -> list[qtyping.TensorTransformationParams]:
133
- """Materialize tensors in tfl.mul."""
134
- return utils.materialize_standard_op(
135
- op_info,
136
- graph_info,
137
- tensor_name_to_qsv,
138
- )
139
-
140
-
141
- def materialize_softmax_and_logistic(
142
- op_info: qtyping.OpInfo,
143
- graph_info: qtyping.GraphInfo,
144
- tensor_name_to_qsv: dict[str, Any],
145
- ) -> list[qtyping.TensorTransformationParams]:
146
- """Materialize tensors in tfl.softmax and tfl.logistic."""
147
- # Hard code scales and zp values as they are hard coded in TFL kernels.
148
- # Softmax:
149
- # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L548
150
- # Logistic:
151
- # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/activations.cc#L421
152
- output_activation_constraints = {
153
- 8: qtyping.UniformQuantParams(
154
- num_bits=8,
155
- quantized_dimension=None,
156
- scale=np.array(1.0 / 256),
157
- zero_point=np.array(-128),
158
- symmetric=False,
159
- ),
160
- 16: qtyping.UniformQuantParams(
161
- num_bits=16,
162
- quantized_dimension=None,
163
- scale=np.array(1.0 / 32768),
164
- zero_point=np.array(0),
165
- ),
166
- }
167
-
168
- return utils.materialize_op_with_output_activation_constraint(
169
- op_info,
170
- graph_info,
171
- tensor_name_to_qsv,
172
- output_activation_constraints,
173
- )
174
-
175
-
176
- def materialize_batch_matmul(
177
- op_info: qtyping.OpInfo,
178
- graph_info: qtyping.GraphInfo,
179
- tensor_name_to_qsv: dict[str, Any],
180
- ) -> list[qtyping.TensorTransformationParams]:
181
- """Materialize tensors in tfl.batch_matmul."""
182
- return utils.materialize_standard_op(
183
- op_info,
184
- graph_info,
185
- tensor_name_to_qsv,
186
- )
187
-
188
-
189
- def materialize_embedding_lookup(
190
- op_info: qtyping.OpInfo,
191
- graph_info: qtyping.GraphInfo,
192
- tensor_name_to_qsv: dict[str, Any],
193
- ) -> list[qtyping.TensorTransformationParams]:
194
- """Materialize tensors in tfl.embedding_lookup."""
195
- return utils.materialize_standard_op(
196
- op_info,
197
- graph_info,
198
- tensor_name_to_qsv,
199
- inputs_to_ignore=[0], # Lookup index does not need to be quantized.
200
- )
201
-
202
-
203
- def materialize_reshape(
204
- op_info: qtyping.OpInfo,
205
- graph_info: qtyping.GraphInfo,
206
- tensor_name_to_qsv: dict[str, Any],
207
- ) -> list[qtyping.TensorTransformationParams]:
208
- """Materialize tensors in tfl.reshape."""
209
- return utils.materialize_standard_op(
210
- op_info,
211
- graph_info,
212
- tensor_name_to_qsv,
213
- constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
214
- inputs_to_ignore=[1], # Shape tensor does not need to be quantized.
215
- )
216
-
217
-
218
- def materialize_average_pool_2d(
219
- op_info: qtyping.OpInfo,
220
- graph_info: qtyping.GraphInfo,
221
- tensor_name_to_qsv: dict[str, Any],
222
- ) -> list[qtyping.TensorTransformationParams]:
223
- """Materialize tensors in tfl.average_pool_2d."""
224
- return utils.materialize_standard_op(
225
- op_info,
226
- graph_info,
227
- tensor_name_to_qsv,
228
- constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
229
- )
230
-
231
-
232
- def _materialize_bias_for_conv_ops(
31
+ def _init_tensor_min_max(
32
+ tensor_data: Optional[np.ndarray],
233
33
  op_info: qtyping.OpInfo,
234
- graph_info: qtyping.GraphInfo,
235
- op_tensor_params: list[qtyping.TensorTransformationParams],
236
- op_input_index: int = 0,
237
- op_weight_index: int = 1,
238
- op_bias_index: int = 2,
239
- ):
240
- """Materializes bias tensors in conv ops by updating `op_tensor_params`.
241
-
242
- Args:
243
- op_info: Aggregated information about the op (e.g., quantization config).
244
- graph_info: Graph information needed to perform quantization for the op.
245
- op_tensor_params: Partially populated quantization configuration for the
246
- tensors associated with the op in the order of input, weight, output.
247
- op_input_index: Index for the input tensor in the op.
248
- op_weight_index: Index for the weight tensor in the op.
249
- op_bias_index: Index for the bias tensor in the op.
250
- """
251
- _, _, bias_tensor, _ = tfl_flatbuffer_utils.parse_fc_bmm_conv_tensors(
252
- op_info.op,
253
- graph_info.subgraph_tensors,
254
- op_input_index,
255
- op_weight_index,
256
- op_bias_index,
257
- )
258
- if bias_tensor is not None:
259
- bias_quant_params = None
260
- # Fused bias needs to be quantized for SRQ.
261
- # Check if SRQ.
34
+ ) -> qtyping.QSV:
35
+ """Initialize the min/max for a tensor."""
36
+ if tensor_data is None:
37
+ return {}
38
+ else:
39
+ quantized_dim = None
262
40
  if (
263
- op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER
264
- and op_info.op_quant_config.activation_tensor_config is not None
41
+ op_info.op_quant_config.weight_tensor_config is not None
42
+ and op_info.op_quant_config.weight_tensor_config.granularity
43
+ == qtyping.QuantGranularity.BLOCKWISE
265
44
  ):
266
- bias_content = tfl_flatbuffer_utils.get_tensor_data(
267
- bias_tensor,
268
- graph_info.buffers,
45
+ # TODO(b/346612503): emulate subchannel only supports fully connected,
46
+ # will skip special handling. Once we have a spec, we can change this.
47
+ block_size = op_info.op_quant_config.weight_tensor_config.block_size
48
+ # assuming tensor is 2D, which is correct for FULLY_CONNECTED
49
+ transposed_tensor_data = np.transpose(tensor_data, (1, 0))
50
+ if transposed_tensor_data.shape[0] % block_size:
51
+ raise ValueError(
52
+ f"Block size {block_size} does not divide channel dimension"
53
+ f" {transposed_tensor_data.shape[0]}."
54
+ )
55
+ reshaped_tensor_data = np.reshape(
56
+ transposed_tensor_data,
57
+ (
58
+ 1,
59
+ int(transposed_tensor_data.shape[0] / block_size),
60
+ block_size,
61
+ transposed_tensor_data.shape[1],
62
+ ),
269
63
  )
270
- bias_quant_params = (
271
- uniform_quantize_tensor.symmetric_quantize_bias_tensor(
272
- bias_content,
273
- op_tensor_params[op_input_index].consumers[0].parameters,
274
- op_tensor_params[op_weight_index].consumers[0].parameters,
275
- )
276
- )
277
- # We only quantize bias under SRQ. Setting is_constant=True for SRQ only
278
- # to avoid quantize bias for DRQ and weight-only cases.
279
- is_constant = (
280
- # Check if SRQ.
281
- op_info.op_quant_config.compute_precision == _ComputePrecision.INTEGER
282
- and op_info.op_quant_config.activation_tensor_config is not None
283
- )
284
- op_tensor_params[op_bias_index] = utils.get_tensor_transformation_params(
285
- tfl_flatbuffer_utils.get_tensor_name(bias_tensor),
286
- op_info,
287
- is_inbounding_tensor=True,
288
- quant_params=bias_quant_params,
289
- is_constant=is_constant,
64
+ return {
65
+ "min": np.min(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
66
+ "max": np.max(reshaped_tensor_data, axis=(0, 1, 2), keepdims=True),
67
+ }
68
+ if (
69
+ op_info.op_quant_config.weight_tensor_config is not None
70
+ and op_info.op_quant_config.weight_tensor_config.granularity
71
+ == qtyping.QuantGranularity.CHANNELWISE
72
+ ):
73
+ if op_info.op_name == _TFLOpName.BATCH_MATMUL:
74
+ quantized_dim = common_utils.get_bmm_weight_quantized_dim(
75
+ tensor_data, adj_y=op_info.op.builtinOptions.adjY
76
+ )
77
+ else:
78
+ quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM.get(
79
+ op_info.op_name, None
80
+ )
81
+ reduce_dims = common_utils.get_reduce_dims(
82
+ quantized_dim, list(tensor_data.shape)
290
83
  )
84
+ return {
85
+ "min": np.min(tensor_data, axis=reduce_dims, keepdims=True),
86
+ "max": np.max(tensor_data, axis=reduce_dims, keepdims=True),
87
+ }
291
88
 
292
89
 
293
- def _are_weights_too_small(
294
- op_info: qtyping.OpInfo,
295
- graph_info: qtyping.GraphInfo,
296
- weight_index: int,
297
- ) -> bool:
298
- """Checks if weights are too small to be quantized."""
299
- tensor = graph_info.subgraph_tensors[op_info.op.inputs[weight_index]]
300
- tensor_data = tfl_flatbuffer_utils.get_tensor_data(
301
- tensor,
302
- graph_info.buffers,
303
- )
304
- return (
305
- tensor_data is not None
306
- and np.size(tensor_data) < op_info.op_quant_config.min_weight_elements
307
- )
308
-
309
-
310
- def materialize_slice(
311
- op_info: qtyping.OpInfo,
312
- graph_info: qtyping.GraphInfo,
313
- tensor_name_to_qsv: dict[str, Any],
314
- ) -> list[qtyping.TensorTransformationParams]:
315
- """Materialize tensors in tfl.slice."""
316
- return utils.materialize_standard_op(
317
- op_info,
318
- graph_info,
319
- tensor_name_to_qsv,
320
- constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
321
- inputs_to_ignore=[
322
- 1,
323
- 2,
324
- ], # Begin and size indices do not need to be quantized.
325
- )
326
-
327
-
328
- def materialize_select_v2(
329
- op_info: qtyping.OpInfo,
330
- graph_info: qtyping.GraphInfo,
331
- tensor_name_to_qsv: dict[str, Any],
332
- ) -> list[qtyping.TensorTransformationParams]:
333
- """Materialize tensors in tfl.select_v2."""
334
- return utils.materialize_standard_op(
335
- op_info,
336
- graph_info,
337
- tensor_name_to_qsv,
338
- constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
339
- inputs_to_ignore=[
340
- 0,
341
- ], # Condition tensor does not need to be quantized.
342
- )
343
-
344
-
345
- def materialize_sum(
346
- op_info: qtyping.OpInfo,
347
- graph_info: qtyping.GraphInfo,
348
- tensor_name_to_qsv: dict[str, Any],
349
- ) -> list[qtyping.TensorTransformationParams]:
350
- """Materialize tensors in tfl.sum."""
351
- return utils.materialize_standard_op(
352
- op_info,
353
- graph_info,
354
- tensor_name_to_qsv,
355
- constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
356
- inputs_to_ignore=[1], # Axis index does not need to be quantized.
357
- )
358
-
359
-
360
- def materialize_fc_conv(
361
- op_info: qtyping.OpInfo,
362
- graph_info: qtyping.GraphInfo,
363
- tensor_name_to_qsv: dict[str, Any],
364
- input_index: int = 0,
365
- weight_index: int = 1,
366
- bias_index: int = 2,
367
- ) -> list[qtyping.TensorTransformationParams]:
368
- """Materialize tensors in fully_connected, conv_2d and depthwise_conv_2d.
369
-
370
- Args:
371
- op_info: Aggregated information about the op (e.g., quantization config).
372
- graph_info: Graph information needed to perform quantization for the op.
373
- tensor_name_to_qsv: A map of tensor name to quantization parameters.
374
- input_index: Index for the input tensor in the op.
375
- weight_index: Index for the weight tensor in the op.
376
- bias_index: Index for the bias tensor in the op.
377
-
378
- Returns:
379
- Quantization configuration for the tensors associated with the op (e.g.,
380
- weights, bias).
381
- """
382
- ignored_inputs = [bias_index] # Bias tensor is quantized separately.
383
- if _are_weights_too_small(op_info, graph_info, weight_index):
384
- ignored_inputs.append(weight_index)
385
-
386
- op_tensor_params = utils.materialize_standard_op(
387
- op_info,
388
- graph_info,
389
- tensor_name_to_qsv,
390
- inputs_to_ignore=ignored_inputs,
391
- )
392
-
393
- _materialize_bias_for_conv_ops(
394
- op_info,
395
- graph_info,
396
- op_tensor_params,
397
- op_input_index=input_index,
398
- op_weight_index=weight_index,
399
- op_bias_index=bias_index,
400
- )
401
-
402
- return op_tensor_params
403
-
404
-
405
- def materialize_conv2d_transpose(
90
+ def get_tensor_quant_params(
406
91
  op_info: qtyping.OpInfo,
407
- graph_info: qtyping.GraphInfo,
408
- tensor_name_to_qsv: dict[str, Any],
409
- ) -> list[qtyping.TensorTransformationParams]:
410
- """Materialize tensors in tfl.conv2d_transpose.
92
+ tensor_quant_config: qtyping.TensorQuantizationConfig,
93
+ tensor_content: Optional[np.ndarray] = None,
94
+ tensor_qsv: Optional[dict[str, Any]] = None,
95
+ ) -> qtyping.UniformQuantParams:
96
+ """Get the quantization parameters for a tensor.
411
97
 
412
98
  Args:
413
99
  op_info: Aggregated information about the op (e.g., quantization config).
414
- graph_info: Graph information needed to perform quantization for the op.
415
- tensor_name_to_qsv: A map of tensor name to quantization parameters.
100
+ tensor_quant_config: The quantization config for the tensor.
101
+ tensor_content: The content of the tensor.
102
+ tensor_qsv: A dictionary containing the min/max of the tensor.
416
103
 
417
104
  Returns:
418
- Quantization configuration for the tensors associated with the op (e.g.,
419
- weights, bias).
105
+ The quantization parameters for the tensor.
420
106
  """
421
- ignored_shape_index = 0
422
- weight_index = 1
423
- input_index = 2
424
- bias_index = 3
425
-
426
- ignored_inputs = [
427
- ignored_shape_index,
428
- bias_index, # Bias tensor is quantized separately.
429
- ]
430
- if _are_weights_too_small(op_info, graph_info, weight_index):
431
- ignored_inputs.append(weight_index)
107
+ # Get quant params.
108
+ if tensor_qsv is None:
109
+ if tensor_content is not None:
110
+ # We need min/max to calculate quantization parameters, which
111
+ # should be collected during the calibration process. However,
112
+ # weight-only and DRQ do not require calibration, thus it is
113
+ # possible that this information is missing here. In that case we
114
+ # collect min/max on the spot.
115
+ tensor_min_max = _init_tensor_min_max(
116
+ tensor_content,
117
+ op_info,
118
+ )
119
+ else:
120
+ raise ValueError(
121
+ f"{op_info.op_name}(index: {op_info.subgraph_op_index}) not found in"
122
+ " tensor_name_to_qsv. Check if the correct calibration results are"
123
+ " passed into the ParamsGenerator."
124
+ )
125
+ else:
126
+ tensor_min_max = tensor_qsv
432
127
 
433
- op_tensor_params = utils.materialize_standard_op(
434
- op_info,
435
- graph_info,
436
- tensor_name_to_qsv,
437
- inputs_to_ignore=ignored_inputs,
438
- )
439
- if len(op_tensor_params) < 2:
128
+ if "min" not in tensor_min_max or "max" not in tensor_min_max:
440
129
  raise ValueError(
441
- "Materialize standard op should return at least two tensors for"
442
- " conv2d_transpose."
443
- )
444
- _materialize_bias_for_conv_ops(
445
- op_info,
446
- graph_info,
447
- op_tensor_params,
448
- op_input_index=input_index,
449
- op_weight_index=weight_index,
450
- op_bias_index=bias_index,
451
- )
452
-
453
- return op_tensor_params
454
-
455
-
456
- def materialize_tanh(
457
- op_info: qtyping.OpInfo,
458
- graph_info: qtyping.GraphInfo,
459
- tensor_name_to_qsv: dict[str, Any],
460
- ) -> list[qtyping.TensorTransformationParams]:
461
- """Materialize tensors in tfl.tanh."""
462
- # Hard code scales and zero point values as they are hard coded in:
463
- # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/ir/tfl_ops.td#L3430
464
- output_activation_constraints = {}
465
- for num_bits in [8, 16]:
466
- output_activation_constraints[num_bits] = qtyping.UniformQuantParams(
467
- num_bits=num_bits,
468
- quantized_dimension=None,
469
- scale=np.array(1.0 / (1 << (num_bits - 1))),
470
- zero_point=np.array(0),
471
- # Activation is always asymmetric for 8 bit and symmetric for 16 bits.
472
- symmetric=num_bits == 16,
130
+ "min and max must be provided to produce tensor quantization"
131
+ " parameters. Check if the correct calibration results are passed into"
132
+ " the ParamsGenerator."
473
133
  )
474
- return utils.materialize_op_with_output_activation_constraint(
475
- op_info, graph_info, tensor_name_to_qsv, output_activation_constraints
476
- )
477
-
478
-
479
- def materialize_transpose(
480
- op_info: qtyping.OpInfo,
481
- graph_info: qtyping.GraphInfo,
482
- tensor_name_to_qsv: dict[str, Any],
483
- ) -> list[qtyping.TensorTransformationParams]:
484
- """Materialize tensors in tfl.transpose."""
485
- return utils.materialize_standard_op(
486
- op_info,
487
- graph_info,
488
- tensor_name_to_qsv,
489
- constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
490
- inputs_to_ignore=[1], # Permutation tensor does not need to be quantized.
491
- )
492
-
493
-
494
- def materialize_gelu(
495
- op_info: qtyping.OpInfo,
496
- graph_info: qtyping.GraphInfo,
497
- tensor_name_to_qsv: dict[str, Any],
498
- ) -> list[qtyping.TensorTransformationParams]:
499
- """Materialize tensors in tfl.gelu."""
500
- return utils.materialize_standard_op(
501
- op_info,
502
- graph_info,
503
- tensor_name_to_qsv,
504
- )
505
-
506
-
507
- def materialize_strided_slice(
508
- op_info: qtyping.OpInfo,
509
- graph_info: qtyping.GraphInfo,
510
- tensor_name_to_qsv: dict[str, Any],
511
- ) -> list[qtyping.TensorTransformationParams]:
512
- """Materialize tensors in tfl.strided_slice."""
513
- return utils.materialize_standard_op(
514
- op_info,
515
- graph_info,
516
- tensor_name_to_qsv,
517
- constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
518
- inputs_to_ignore=[1, 2, 3], # Ignore the begin, end, and strides tensors.
519
- )
520
-
521
-
522
- def materialize_mean(
523
- op_info: qtyping.OpInfo,
524
- graph_info: qtyping.GraphInfo,
525
- tensor_name_to_qsv: dict[str, Any],
526
- ) -> list[qtyping.TensorTransformationParams]:
527
- """Materialize tensors in tfl.mean."""
528
- return utils.materialize_standard_op(
529
- op_info,
530
- graph_info,
531
- tensor_name_to_qsv,
532
- inputs_to_ignore=[1], # Axis tensor does not need to be quantized.
134
+ zp, scale = uniform_quantize_tensor.tensor_zp_scale_from_min_max(
135
+ tensor_min_max["min"],
136
+ tensor_min_max["max"],
137
+ tensor_quant_config.num_bits,
138
+ tensor_quant_config.symmetric,
533
139
  )
534
-
535
-
536
- def materialize_rsqrt(
537
- op_info: qtyping.OpInfo,
538
- graph_info: qtyping.GraphInfo,
539
- tensor_name_to_qsv: dict[str, Any],
540
- ) -> list[qtyping.TensorTransformationParams]:
541
- """Materialize tensors in tfl.rsqrt."""
542
- return utils.materialize_standard_op(
543
- op_info,
544
- graph_info,
545
- tensor_name_to_qsv,
546
- )
547
-
548
-
549
- def materialize_concatenation(
550
- op_info: qtyping.OpInfo,
551
- graph_info: qtyping.GraphInfo,
552
- tensor_name_to_qsv: dict[str, Any],
553
- ) -> list[qtyping.TensorTransformationParams]:
554
- """Materialize tensors in tfl.concatenation."""
555
- return utils.materialize_standard_op(
556
- op_info,
557
- graph_info,
558
- tensor_name_to_qsv,
559
- constraint=_OpQuantConstraint.SAME_AS_OUTPUT_SCALE,
140
+ quantized_dim = None
141
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.CHANNELWISE:
142
+ if op_info.op_name == _TFLOpName.BATCH_MATMUL:
143
+ quantized_dim = common_utils.get_bmm_weight_quantized_dim(
144
+ tensor_content, adj_y=op_info.op.builtinOptions.adjY
145
+ )
146
+ else:
147
+ quantized_dim = tfl_flatbuffer_utils.TFL_OP_TO_WEIGHT_QUANTIZED_DIM[
148
+ op_info.op_name
149
+ ]
150
+ quant_params = qtyping.UniformQuantParams(
151
+ scale=scale,
152
+ zero_point=zp,
153
+ num_bits=tensor_quant_config.num_bits,
154
+ symmetric=tensor_quant_config.symmetric,
155
+ quantized_dimension=quantized_dim,
560
156
  )
561
-
562
-
563
- def materialize_split(
564
- op_info: qtyping.OpInfo,
565
- graph_info: qtyping.GraphInfo,
566
- tensor_name_to_qsv: dict[str, Any],
567
- ) -> list[qtyping.TensorTransformationParams]:
568
- """Materialize tensors in tfl.split."""
569
- return utils.materialize_standard_op(
570
- op_info,
571
- graph_info,
572
- tensor_name_to_qsv,
573
- constraint=_OpQuantConstraint.SAME_AS_INPUT_SCALE,
574
- inputs_to_ignore=[0], # Split dimension does not need to be quantized.
157
+ if tensor_content is None:
158
+ return quant_params
159
+ if tensor_quant_config.granularity == qtyping.QuantGranularity.BLOCKWISE:
160
+ quantized_vars = (
161
+ uniform_quantize_tensor.uniform_quantize_for_emulated_subchannel(
162
+ tensor_content, quant_params, tensor_quant_config.block_size
163
+ )
164
+ )
165
+ else:
166
+ quantized_vars = uniform_quantize_tensor.uniform_quantize(
167
+ tensor_content, quant_params
168
+ )
169
+ # Update with quantized values.
170
+ return qtyping.UniformQuantParams(
171
+ scale=scale,
172
+ zero_point=zp,
173
+ num_bits=tensor_quant_config.num_bits,
174
+ symmetric=tensor_quant_config.symmetric,
175
+ quantized_dimension=quantized_dim,
176
+ quantized_data=quantized_vars,
575
177
  )
576
178
 
577
179
 
@@ -601,18 +203,22 @@ def init_qsvs(
601
203
  if tensor_idx != -1 and i not in inputs_to_ignore:
602
204
  tensor = graph_info.subgraph_tensors[tensor_idx]
603
205
  tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
604
- op_qsvs[tensor_name] = utils.init_tensor_min_max(
605
- tensor,
606
- graph_info,
206
+ tensor_data = tfl_flatbuffer_utils.get_tensor_data(
207
+ tensor, graph_info.buffers
208
+ )
209
+ op_qsvs[tensor_name] = _init_tensor_min_max(
210
+ tensor_data,
607
211
  op_info,
608
212
  )
609
213
  for i, tensor_idx in enumerate(op_info.op.outputs):
610
214
  if tensor_idx != -1 and i not in outputs_to_ignore:
611
215
  tensor = graph_info.subgraph_tensors[tensor_idx]
612
216
  tensor_name = tfl_flatbuffer_utils.get_tensor_name(tensor)
613
- op_qsvs[tensor_name] = utils.init_tensor_min_max(
614
- tensor,
615
- graph_info,
217
+ tensor_data = tfl_flatbuffer_utils.get_tensor_data(
218
+ tensor, graph_info.buffers
219
+ )
220
+ op_qsvs[tensor_name] = _init_tensor_min_max(
221
+ tensor_data,
616
222
  op_info,
617
223
  )
618
224
  return op_qsvs