tico 0.1.0.dev250717__py3-none-any.whl → 0.1.0.dev250721__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -21,7 +21,7 @@ from tico.config import CompileConfigV1, get_default_config
21
21
  from tico.utils.convert import convert, convert_from_exported_program, convert_from_pt2
22
22
 
23
23
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
24
- __version__ = "0.1.0.dev250717"
24
+ __version__ = "0.1.0.dev250721"
25
25
 
26
26
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
27
27
  SECURE_TORCH_VERSION = "2.6.0"
@@ -18,6 +18,9 @@ if TYPE_CHECKING:
18
18
  import torch.fx
19
19
  import copy
20
20
 
21
+ from collections import defaultdict
22
+ from typing import Any, Callable
23
+
21
24
  import torch
22
25
  from torch.export import ExportedProgram
23
26
 
@@ -92,6 +95,302 @@ def _u8_to_i16(qparam: QuantParam) -> QuantParam:
92
95
  return new_qparam
93
96
 
94
97
 
98
+ def _insert_quantize_op_before(node, inp):
99
+ graph = node.graph
100
+ qparam: QuantParam = node.meta[QPARAM_KEY]
101
+ assert qparam.scale is not None
102
+ assert qparam.zero_point is not None
103
+ scale = qparam.scale[0]
104
+ zerop = qparam.zero_point[0]
105
+ min_, max_ = quant_min_max(qparam.dtype)
106
+ dtype = getattr(torch, qparam.dtype)
107
+
108
+ with graph.inserting_before(node):
109
+ q_args = (inp, scale, zerop, min_, max_, dtype)
110
+ quantize = create_node(
111
+ graph,
112
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
113
+ args=q_args,
114
+ origin=node,
115
+ )
116
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
117
+ set_new_meta_val(quantize)
118
+
119
+ node.replace_input_with(inp, quantize)
120
+
121
+ return quantize
122
+
123
+
124
+ def _insert_quantize_op_after(node):
125
+ graph = node.graph
126
+ qparam: QuantParam = node.meta[QPARAM_KEY]
127
+ assert qparam.scale is not None
128
+ assert qparam.zero_point is not None
129
+ scale = qparam.scale[0]
130
+ zerop = qparam.zero_point[0]
131
+ min_, max_ = quant_min_max(qparam.dtype)
132
+ dtype = getattr(torch, qparam.dtype)
133
+ with graph.inserting_after(node):
134
+ q_args = (node, scale, zerop, min_, max_, dtype)
135
+ quantize = create_node(
136
+ graph,
137
+ torch.ops.quantized_decomposed.quantize_per_tensor.default,
138
+ args=q_args,
139
+ )
140
+
141
+ node.replace_all_uses_with(quantize, propagate_meta=True)
142
+ quantize.replace_input_with(quantize, node)
143
+
144
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
145
+
146
+ return quantize
147
+
148
+
149
+ def _linear_handler(node, logger):
150
+ lin_args = LinearArgs(*node.args, **node.kwargs)
151
+ inp = lin_args.input
152
+
153
+ if QPARAM_KEY not in inp.meta:
154
+ return
155
+
156
+ if QPARAM_KEY not in node.meta:
157
+ return
158
+
159
+ if qparam_dtype(inp) == qparam_dtype(node):
160
+ return
161
+
162
+ if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
163
+ quantize = _insert_quantize_op_after(node)
164
+
165
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
166
+
167
+ # Update node's qparam from i16 to u8
168
+ # NOTE This would severely degrade accuracy. It is
169
+ # important to mitigate this accuracy drop in backend.
170
+ node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
171
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
172
+ else:
173
+ raise NotYetSupportedError(
174
+ f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
175
+ )
176
+
177
+
178
+ def _add_handler(node, logger):
179
+ add_args = AddTensorArgs(*node.args, **node.kwargs)
180
+ x = add_args.input
181
+ y = add_args.other
182
+
183
+ if not isinstance(x, torch.fx.Node):
184
+ return
185
+ if not isinstance(y, torch.fx.Node):
186
+ return
187
+
188
+ if QPARAM_KEY not in x.meta:
189
+ return
190
+ if QPARAM_KEY not in y.meta:
191
+ return
192
+ if QPARAM_KEY not in node.meta:
193
+ return
194
+
195
+ if qparam_dtype(x) == qparam_dtype(node):
196
+ return
197
+
198
+ if qparam_dtype(x) != qparam_dtype(y):
199
+ return
200
+
201
+ if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
202
+ quantize = _insert_quantize_op_after(node)
203
+
204
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
205
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
206
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
207
+ else:
208
+ raise NotYetSupportedError("Unsupported dtype")
209
+
210
+
211
+ def _mul_handler(node, logger):
212
+ mul_args = MulTensorArgs(*node.args, **node.kwargs)
213
+ x = mul_args.input
214
+ y = mul_args.other
215
+
216
+ if not isinstance(x, torch.fx.Node):
217
+ return
218
+ if not isinstance(y, torch.fx.Node):
219
+ return
220
+
221
+ if QPARAM_KEY not in x.meta:
222
+ return
223
+ if QPARAM_KEY not in y.meta:
224
+ return
225
+ if QPARAM_KEY not in node.meta:
226
+ return
227
+
228
+ if qparam_dtype(x) == qparam_dtype(node):
229
+ return
230
+
231
+ if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
232
+ quantize = _insert_quantize_op_after(node)
233
+
234
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
235
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
236
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
237
+ else:
238
+ raise NotYetSupportedError("Unsupported dtype")
239
+
240
+
241
+ def _cat_handler(node, logger):
242
+ cat_args = CatArgs(*node.args, **node.kwargs)
243
+ tensors = cat_args.tensors
244
+
245
+ if any(QPARAM_KEY not in x.meta for x in tensors):
246
+ return
247
+
248
+ if QPARAM_KEY not in node.meta:
249
+ return
250
+
251
+ assert len(tensors) > 0
252
+ in_dtype = qparam_dtype(tensors[0])
253
+ if in_dtype == qparam_dtype(node):
254
+ return
255
+
256
+ if any(qparam_dtype(x) != in_dtype for x in tensors):
257
+ return
258
+
259
+ if in_dtype == "int16" and qparam_dtype(node) == "uint8":
260
+ quantize = _insert_quantize_op_after(node)
261
+
262
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
263
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
264
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
265
+ else:
266
+ raise NotYetSupportedError("Unsupported dtype")
267
+
268
+
269
+ def _bmm_handler(node, logger):
270
+ bmm_args = BmmArgs(*node.args, **node.kwargs)
271
+ x = bmm_args.input
272
+ y = bmm_args.mat2
273
+
274
+ if QPARAM_KEY not in x.meta:
275
+ return
276
+ if QPARAM_KEY not in y.meta:
277
+ return
278
+ if QPARAM_KEY not in node.meta:
279
+ return
280
+
281
+ if qparam_dtype(x) == qparam_dtype(node):
282
+ return
283
+
284
+ if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
285
+ quantize = _insert_quantize_op_after(node)
286
+
287
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
288
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
289
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
290
+ elif qparam_dtype(x) == "uint8" and qparam_dtype(node) == "int16":
291
+ quantize = _insert_quantize_op_after(node)
292
+
293
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
294
+ node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
295
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
296
+ else:
297
+ raise NotYetSupportedError("Unsupported dtype")
298
+
299
+
300
+ def _permute_handler(node, logger):
301
+ per_args = PermuteArgs(*node.args, **node.kwargs)
302
+ inp = per_args.input
303
+
304
+ if QPARAM_KEY not in inp.meta:
305
+ return
306
+
307
+ if QPARAM_KEY not in node.meta:
308
+ return
309
+
310
+ if qparam_dtype(inp) == qparam_dtype(node):
311
+ return
312
+
313
+ if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
314
+ # A new Quantize Op (s16 to u8) is inserted before (not after)
315
+ # permute Op to reduce tensor size earlier
316
+ quantize = _insert_quantize_op_before(node, inp)
317
+
318
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
319
+ logger.debug(f"quantize_per_tensor.default is inserted before {node.name}.")
320
+ elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
321
+ quantize = _insert_quantize_op_after(node)
322
+
323
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
324
+ node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
325
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
326
+ else:
327
+ raise NotYetSupportedError("Unsupported dtype")
328
+
329
+
330
+ def _reshape_handler(node, logger):
331
+ reshape_args = ReshapeArgs(*node.args, **node.kwargs)
332
+ inp = reshape_args.input
333
+
334
+ if QPARAM_KEY not in inp.meta:
335
+ return
336
+
337
+ if QPARAM_KEY not in node.meta:
338
+ return
339
+
340
+ if qparam_dtype(inp) == qparam_dtype(node):
341
+ return
342
+
343
+ if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
344
+ # A new Quantize Op (s16 to u8) is inserted before (not after)
345
+ # reshape Op to reduce tensor size earlier
346
+ quantize = _insert_quantize_op_before(node, inp)
347
+
348
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
349
+ logger.debug(f"quantize_per_tensor.default is inserted before {node.name}.")
350
+ elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
351
+ quantize = _insert_quantize_op_after(node)
352
+
353
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
354
+ node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
355
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
356
+ else:
357
+ raise NotYetSupportedError("Unsupported dtype")
358
+
359
+
360
+ def _relu_handler(node, logger):
361
+ relu_args = ReluArgs(*node.args, **node.kwargs)
362
+ inp = relu_args.input
363
+
364
+ if QPARAM_KEY not in inp.meta:
365
+ return
366
+
367
+ if QPARAM_KEY not in node.meta:
368
+ return
369
+
370
+ if qparam_dtype(inp) == qparam_dtype(node):
371
+ return
372
+
373
+ if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
374
+ quantize = _insert_quantize_op_after(node)
375
+
376
+ quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
377
+ node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
378
+ logger.debug(f"quantize_per_tensor.default is inserted after {node.name}.")
379
+ else:
380
+ raise NotYetSupportedError("Unsupported dtype")
381
+
382
+
383
+ _op_handler: defaultdict[Any, Any | None] = defaultdict(lambda: None)
384
+ _op_handler[torch.ops.aten.linear.default] = _linear_handler
385
+ _op_handler[torch.ops.aten.add.Tensor] = _add_handler
386
+ _op_handler[torch.ops.aten.mul.Tensor] = _mul_handler
387
+ _op_handler[torch.ops.aten.cat.default] = _cat_handler
388
+ _op_handler[torch.ops.aten.bmm.default] = _bmm_handler
389
+ _op_handler[torch.ops.aten.permute.default] = _permute_handler
390
+ _op_handler[torch.ops.aten.reshape.default] = _reshape_handler
391
+ _op_handler[torch.ops.aten.relu.default] = _relu_handler
392
+
393
+
95
394
  @trace_graph_diff_on_pass
96
395
  class InsertQuantizeOnDtypeMismatch(PassBase):
97
396
  """
@@ -138,296 +437,13 @@ class InsertQuantizeOnDtypeMismatch(PassBase):
138
437
  graph_module = exported_program.graph_module
139
438
  graph: torch.fx.Graph = graph_module.graph
140
439
 
141
- def _insert_quantize_op_before(node, inp):
142
- qparam: QuantParam = node.meta[QPARAM_KEY]
143
- assert qparam.scale is not None
144
- assert qparam.zero_point is not None
145
- scale = qparam.scale[0]
146
- zerop = qparam.zero_point[0]
147
- min_, max_ = quant_min_max(qparam.dtype)
148
- dtype = getattr(torch, qparam.dtype)
149
-
150
- with graph.inserting_before(node):
151
- q_args = (inp, scale, zerop, min_, max_, dtype)
152
- quantize = create_node(
153
- graph,
154
- torch.ops.quantized_decomposed.quantize_per_tensor.default,
155
- args=q_args,
156
- origin=node,
157
- )
158
- quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
159
- set_new_meta_val(quantize)
160
-
161
- node.replace_input_with(inp, quantize)
162
-
163
- return quantize
164
-
165
- def _insert_quantize_op_after(node):
166
- qparam: QuantParam = node.meta[QPARAM_KEY]
167
- assert qparam.scale is not None
168
- assert qparam.zero_point is not None
169
- scale = qparam.scale[0]
170
- zerop = qparam.zero_point[0]
171
- min_, max_ = quant_min_max(qparam.dtype)
172
- dtype = getattr(torch, qparam.dtype)
173
- with graph.inserting_after(node):
174
- q_args = (node, scale, zerop, min_, max_, dtype)
175
- quantize = create_node(
176
- graph,
177
- torch.ops.quantized_decomposed.quantize_per_tensor.default,
178
- args=q_args,
179
- )
180
-
181
- node.replace_all_uses_with(quantize, propagate_meta=True)
182
- quantize.replace_input_with(quantize, node)
183
-
184
- quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)
185
-
186
- return quantize
187
-
188
440
  for node in graph.nodes:
189
441
  if node.op != "call_function":
190
442
  continue
191
- if node.target == torch.ops.aten.linear.default:
192
- lin_args = LinearArgs(*node.args, **node.kwargs)
193
- inp = lin_args.input
194
-
195
- if QPARAM_KEY not in inp.meta:
196
- continue
197
-
198
- if QPARAM_KEY not in node.meta:
199
- continue
200
-
201
- if qparam_dtype(inp) == qparam_dtype(node):
202
- continue
203
-
204
- if qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
205
- quantize = _insert_quantize_op_after(node)
206
-
207
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
208
-
209
- # Update node's qparam from i16 to u8
210
- # NOTE This would severely degrade accuracy. It is
211
- # important to mitigate this accuracy drop in backend.
212
- node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
213
- logger.debug(
214
- f"quantize_per_tensor.default is inserted after {node.name}."
215
- )
216
- else:
217
- raise NotYetSupportedError(
218
- f"Unsupported dtype: From {qparam_dtype(inp)} to {qparam_dtype(node)}"
219
- )
220
-
221
- elif node.target == torch.ops.aten.add.Tensor:
222
- add_args = AddTensorArgs(*node.args, **node.kwargs)
223
- x = add_args.input
224
- y = add_args.other
225
-
226
- if not isinstance(x, torch.fx.Node):
227
- continue
228
- if not isinstance(y, torch.fx.Node):
229
- continue
230
-
231
- if QPARAM_KEY not in x.meta:
232
- continue
233
- if QPARAM_KEY not in y.meta:
234
- continue
235
- if QPARAM_KEY not in node.meta:
236
- continue
237
-
238
- if qparam_dtype(x) == qparam_dtype(node):
239
- continue
240
-
241
- if qparam_dtype(x) != qparam_dtype(y):
242
- continue
243
-
244
- if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
245
- quantize = _insert_quantize_op_after(node)
246
-
247
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
248
- node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
249
- logger.debug(
250
- f"quantize_per_tensor.default is inserted after {node.name}."
251
- )
252
- else:
253
- raise NotYetSupportedError("Unsupported dtype")
254
-
255
- elif node.target == torch.ops.aten.mul.Tensor:
256
- mul_args = MulTensorArgs(*node.args, **node.kwargs)
257
- x = mul_args.input
258
- y = mul_args.other
259
-
260
- if not isinstance(x, torch.fx.Node):
261
- continue
262
- if not isinstance(y, torch.fx.Node):
263
- continue
264
-
265
- if QPARAM_KEY not in x.meta:
266
- continue
267
- if QPARAM_KEY not in y.meta:
268
- continue
269
- if QPARAM_KEY not in node.meta:
270
- continue
271
-
272
- if qparam_dtype(x) == qparam_dtype(node):
273
- continue
274
-
275
- if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
276
- quantize = _insert_quantize_op_after(node)
277
-
278
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
279
- node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
280
- logger.debug(
281
- f"quantize_per_tensor.default is inserted after {node.name}."
282
- )
283
- else:
284
- raise NotYetSupportedError("Unsupported dtype")
285
-
286
- elif node.target == torch.ops.aten.cat.default:
287
- cat_args = CatArgs(*node.args, **node.kwargs)
288
- tensors = cat_args.tensors
289
-
290
- if any(QPARAM_KEY not in x.meta for x in tensors):
291
- continue
292
-
293
- if QPARAM_KEY not in node.meta:
294
- continue
295
-
296
- assert len(tensors) > 0
297
- in_dtype = qparam_dtype(tensors[0])
298
- if in_dtype == qparam_dtype(node):
299
- continue
300
-
301
- if any(qparam_dtype(x) != in_dtype for x in tensors):
302
- continue
303
-
304
- if in_dtype == "int16" and qparam_dtype(node) == "uint8":
305
- quantize = _insert_quantize_op_after(node)
306
-
307
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
308
- node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
309
- logger.debug(
310
- f"quantize_per_tensor.default is inserted after {node.name}."
311
- )
312
- else:
313
- raise NotYetSupportedError("Unsupported dtype")
314
-
315
- elif node.target == torch.ops.aten.bmm.default:
316
- bmm_args = BmmArgs(*node.args, **node.kwargs)
317
- x = bmm_args.input
318
- y = bmm_args.mat2
319
-
320
- if QPARAM_KEY not in x.meta:
321
- continue
322
- if QPARAM_KEY not in y.meta:
323
- continue
324
- if QPARAM_KEY not in node.meta:
325
- continue
326
-
327
- if qparam_dtype(x) == qparam_dtype(node):
328
- continue
329
-
330
- if qparam_dtype(x) == "int16" and qparam_dtype(node) == "uint8":
331
- quantize = _insert_quantize_op_after(node)
332
-
333
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
334
- node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
335
- logger.debug(
336
- f"quantize_per_tensor.default is inserted after {node.name}."
337
- )
338
- else:
339
- raise NotYetSupportedError("Unsupported dtype")
340
-
341
- elif node.target == torch.ops.aten.permute.default:
342
- per_args = PermuteArgs(*node.args, **node.kwargs)
343
- inp = per_args.input
344
-
345
- if QPARAM_KEY not in inp.meta:
346
- continue
347
-
348
- if QPARAM_KEY not in node.meta:
349
- continue
350
-
351
- if qparam_dtype(inp) == qparam_dtype(node):
352
- continue
353
-
354
- if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
355
- # A new Quantize Op (s16 to u8) is inserted before (not after)
356
- # permute Op to reduce tensor size ealier
357
- quantize = _insert_quantize_op_before(node, inp)
358
-
359
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
360
- logger.debug(
361
- f"quantize_per_tensor.default is inserted before {node.name}."
362
- )
363
- elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
364
- quantize = _insert_quantize_op_after(node)
365
-
366
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
367
- node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
368
- logger.debug(
369
- f"quantize_per_tensor.default is inserted after {node.name}."
370
- )
371
- else:
372
- raise NotYetSupportedError("Unsupported dtype")
373
- elif node.target == torch.ops.aten.reshape.default:
374
- reshape_args = ReshapeArgs(*node.args, **node.kwargs)
375
- inp = reshape_args.input
376
-
377
- if QPARAM_KEY not in inp.meta:
378
- continue
379
-
380
- if QPARAM_KEY not in node.meta:
381
- continue
382
-
383
- if qparam_dtype(inp) == qparam_dtype(node):
384
- continue
385
-
386
- if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
387
- # A new Quantize Op (s16 to u8) is inserted before (not after)
388
- # reshape Op to reduce tensor size ealier
389
- quantize = _insert_quantize_op_before(node, inp)
390
-
391
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
392
- logger.debug(
393
- f"quantize_per_tensor.default is inserted before {node.name}."
394
- )
395
- elif qparam_dtype(inp) == "uint8" and qparam_dtype(node) == "int16":
396
- quantize = _insert_quantize_op_after(node)
397
-
398
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
399
- node.meta[QPARAM_KEY] = _i16_to_u8(node.meta[QPARAM_KEY])
400
- logger.debug(
401
- f"quantize_per_tensor.default is inserted after {node.name}."
402
- )
403
- else:
404
- raise NotYetSupportedError("Unsupported dtype")
405
-
406
- elif node.target == torch.ops.aten.relu.default:
407
- relu_args = ReluArgs(*node.args, **node.kwargs)
408
- inp = relu_args.input
409
-
410
- if QPARAM_KEY not in inp.meta:
411
- continue
412
-
413
- if QPARAM_KEY not in node.meta:
414
- continue
415
-
416
- if qparam_dtype(inp) == qparam_dtype(node):
417
- continue
418
-
419
- if qparam_dtype(inp) == "int16" and qparam_dtype(node) == "uint8":
420
- quantize = _insert_quantize_op_after(node)
421
-
422
- quantize.meta[QPARAM_KEY] = copy.deepcopy(node.meta[QPARAM_KEY])
423
- node.meta[QPARAM_KEY] = _u8_to_i16(node.meta[QPARAM_KEY])
424
- logger.debug(
425
- f"quantize_per_tensor.default is inserted after {node.name}."
426
- )
427
- else:
428
- raise NotYetSupportedError("Unsupported dtype")
429
-
430
- # TODO Support more ops.
443
+
444
+ handler = _op_handler[node.target]
445
+ if handler is not None:
446
+ handler(node, logger)
431
447
 
432
448
  graph.eliminate_dead_code()
433
449
  graph.lint()
@@ -152,6 +152,15 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
152
152
  assert node.meta.get("val") is not None
153
153
  tensor.type = extract_circle_dtype(node)
154
154
  tensor.shape = list(extract_shape(node))
155
+
156
+ # Handle dynamic shape
157
+ if any(isinstance(s, torch.SymInt) for s in tensor.shape):
158
+ tensor.shapeSignature = tensor.shape.copy()
159
+ for idx, s in enumerate(tensor.shape):
160
+ if isinstance(s, torch.SymInt):
161
+ tensor.shape[idx] = 1
162
+ tensor.shapeSignature[idx] = -1
163
+
155
164
  if QPARAM_KEY in node.meta:
156
165
  tensor.quantization = to_circle_qparam(node.meta[QPARAM_KEY])
157
166
  tensor.type = str_to_circle_dtype(node.meta[QPARAM_KEY].dtype)
@@ -241,6 +250,15 @@ class CircleSubgraph(circle.SubGraph.SubGraphT):
241
250
  if source_node is not None:
242
251
  self.name_to_node[tensor.name] = source_node
243
252
  tensor.shape = shape
253
+
254
+ # Handle dynamic shape
255
+ if any(isinstance(s, torch.SymInt) for s in tensor.shape):
256
+ tensor.shapeSignature = tensor.shape.copy()
257
+ for idx, s in enumerate(tensor.shape):
258
+ if isinstance(s, torch.SymInt):
259
+ tensor.shape[idx] = 1
260
+ tensor.shapeSignature[idx] = -1
261
+
244
262
  if qparam is not None:
245
263
  tensor.quantization = to_circle_qparam(qparam)
246
264
  tensor.type = str_to_circle_dtype(qparam.dtype)
@@ -24,25 +24,18 @@ from tico.serialize.operators.hashable_opcode import OpCode
24
24
  from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
25
25
  from tico.serialize.operators.utils import create_builtin_operator, get_op_index
26
26
  from tico.utils.errors import NotYetSupportedError
27
- from tico.utils.utils import HAS_TORCH_OVER_25
28
27
  from tico.utils.validate_args_kwargs import SafeSoftmaxArgs, SoftmaxArgs
29
28
 
30
29
 
31
30
  @register_node_visitor
32
31
  class SoftMaxVisitor(NodeVisitor):
33
- target: List[torch._ops.OpOverload] = (
34
- [
35
- torch.ops.aten._softmax.default,
36
- # NOTE: Let's treat _safe_softmax as normal _softmax as its usage is for training.
37
- # In order for optimization during inference, it can be replaced to softmax.
38
- # ref: https://github.com/pytorch/pytorch/pull/133882
39
- torch.ops.aten._safe_softmax.default,
40
- ]
41
- if HAS_TORCH_OVER_25
42
- else [
43
- torch.ops.aten._softmax.default,
44
- ]
45
- )
32
+ target: List[torch._ops.OpOverload] = [
33
+ torch.ops.aten._softmax.default,
34
+ # NOTE: Let's treat _safe_softmax as normal _softmax as its usage is for training.
35
+ # To optimize inference, it can be replaced with softmax.
36
+ # ref: https://github.com/pytorch/pytorch/pull/133882
37
+ torch.ops.aten._safe_softmax.default,
38
+ ]
46
39
 
47
40
  def __init__(self, op_codes: Dict[OpCode, int], graph):
48
41
  super().__init__(op_codes, graph)
tico/utils/convert.py CHANGED
@@ -106,6 +106,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
106
106
  torch.ops.aten._safe_softmax.default,
107
107
  torch.ops.aten.relu6.default, # Do not decompose to hardtanh
108
108
  torch.ops.aten.linear.default,
109
+ torch.ops.aten.upsample_nearest2d.vec,
109
110
  )
110
111
  ep = ep.run_decompositions(_preserve_ops=_preserve_ops)
111
112
 
@@ -124,6 +125,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
124
125
  torch.ops.aten.relu6.default, # Do not decompose to hardtanh
125
126
  torch.ops.aten.prelu.default,
126
127
  torch.ops.aten.linear.default,
128
+ torch.ops.aten.upsample_nearest2d.vec,
127
129
  )
128
130
  for op in _preserve_ops:
129
131
  if op in _decomp_table:
@@ -138,6 +140,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
138
140
  torch.__version__.startswith("2.6")
139
141
  or torch.__version__.startswith("2.7")
140
142
  or torch.__version__.startswith("2.8")
143
+ or torch.__version__.startswith("2.9")
141
144
  ):
142
145
  return run_decompositions(exported_program)
143
146
  else:
@@ -293,6 +296,7 @@ def convert(
293
296
  mod: torch.nn.Module,
294
297
  args: Tuple[Any, ...],
295
298
  kwargs: Optional[Dict[str, Any]] = None,
299
+ dynamic_shapes: Optional[dict] = None,
296
300
  strict: bool = True,
297
301
  config: CompileConfigBase = get_default_config(),
298
302
  ) -> CircleModel:
@@ -303,7 +307,9 @@ def convert(
303
307
  )
304
308
 
305
309
  with torch.no_grad():
306
- exported_program = export(mod, args, kwargs, strict=strict)
310
+ exported_program = export(
311
+ mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
312
+ )
307
313
 
308
314
  circle_binary = convert_exported_module_to_circle(exported_program, config=config)
309
315
 
tico/utils/dtype.py ADDED
@@ -0,0 +1,20 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ NUMPY_TO_TORCH_DTYPE_DICT = {
5
+ np.dtype("float32"): torch.float32,
6
+ np.dtype("float64"): torch.float64,
7
+ np.dtype("float16"): torch.float16,
8
+ np.dtype("complex64"): torch.complex64,
9
+ np.dtype("complex128"): torch.complex128,
10
+ np.dtype("int64"): torch.int64,
11
+ np.dtype("int32"): torch.int32,
12
+ np.dtype("int16"): torch.int16,
13
+ np.dtype("int8"): torch.int8,
14
+ np.dtype("uint8"): torch.uint8,
15
+ np.dtype("bool"): torch.bool,
16
+ }
17
+
18
+
19
+ def numpy_dtype_to_torch_dtype(np_dtype: np.dtype) -> torch.dtype:
20
+ return NUMPY_TO_TORCH_DTYPE_DICT[np_dtype]
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Runtime **capability-detection helpers** for the `torch.export` stack.
17
+
18
+ Instead of sprinkling version checks like `torch.__version__ >= "2.9"` throughout
19
+ the codebase, import these helpers once and branch on the feature you need.
20
+
21
+ Each probe executes only **once per process** thanks to `functools.lru_cache`,
22
+ so the overhead is negligible.
23
+ """
24
+
25
+ import functools
26
+
27
+ import torch
28
+
29
+
30
+ @functools.lru_cache(maxsize=None)
31
+ def export_produces_slice() -> bool:
32
+ """
33
+ Compile a minimal model with `torch.export.export` and inspect its FX graph
34
+ to see whether an `aten.slice.Tensor` node appears.
35
+
36
+ Returns
37
+ -------
38
+ bool
39
+ * ``True`` — downstream passes should expect redundant **slice** nodes.
40
+ * ``False`` — downstream passes should expect only a **select** node.
41
+ """
42
+
43
+ class _Probe(torch.nn.Module):
44
+ def forward(self, x): # simple slice: keep all dims except 3rd
45
+ return x[:, :, 1]
46
+
47
+ def get_example_inputs(self):
48
+ return (torch.randn(1, 4, 4),)
49
+
50
+ m = _Probe()
51
+ ep = torch.export.export(m, m.get_example_inputs())
52
+ return any(n.target == torch.ops.aten.slice.Tensor for n in ep.graph.nodes)
tico/utils/utils.py CHANGED
@@ -29,10 +29,6 @@ from torch.utils import _pytree as pytree
29
29
  from tico.serialize.quant_param import QuantParam
30
30
 
31
31
 
32
- HAS_TORCH_OVER_25 = Version(torch.__version__) >= Version("2.5.0")
33
- HAS_TORCH_OVER_28_DEV = Version(torch.__version__) >= Version("2.8.0.dev")
34
-
35
-
36
32
  def get_fake_mode(exported_program: ExportedProgram):
37
33
  fake_mode = detect_fake_mode(
38
34
  tuple(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250717
3
+ Version: 0.1.0.dev250721
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=8WsnAhznDCSGOK_vrdZdi2apsz1wqJDKSfCjR4LTG8c,1743
1
+ tico/__init__.py,sha256=QY54qph93oHLtUUYMg8T_e_Bpn4Zmxfif9POfAGssHI,1743
2
2
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=anwOiJFkUxUi7Cef573JgQcjk6S-FSi6O_TLjYASW-g,1244
@@ -51,7 +51,7 @@ tico/experimental/quantization/evaluation/executor/circle_executor.py,sha256=eCC
51
51
  tico/experimental/quantization/evaluation/executor/triv24_executor.py,sha256=sUoXl6oOO2arAKaNjOBg7HiQja145_Jv6qgY7XtR7A8,5159
52
52
  tico/experimental/quantization/passes/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
53
53
  tico/experimental/quantization/passes/fold_quant_ops.py,sha256=iaBMyO49CwVkhebMz3rjkHWfWE2LhwH6fORe7n4S6XQ,7040
54
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=FfqTlGANcG1V64zw0MFcIxL9WafuxPINuzWohGdsYCg,16617
54
+ tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py,sha256=t3bnNY9Abm8CZfSWzsbvx59luylXVxqmUvCKPBVPAIE,14731
55
55
  tico/experimental/quantization/passes/propagate_qparam_backward.py,sha256=TGtyW0Z2qOTgVIasBdGRgbwH31YYd6ek7OvLTmCV614,3118
56
56
  tico/experimental/quantization/passes/propagate_qparam_forward.py,sha256=RhUHGCR2RpBO5KYkQ7Z8U5u7HEwDq2wdKHLKAJCi-5c,5138
57
57
  tico/experimental/quantization/passes/quantize_bias.py,sha256=ZQ3rETYStpW28JUbODRixbq5sDEOiIOB_qWA-Jzuu-Y,4337
@@ -96,7 +96,7 @@ tico/passes/remove_redundant_to_copy.py,sha256=tKy4XKkO2l33fMxVPQ_iFkUeFvP15kbPv
96
96
  tico/passes/restore_linear.py,sha256=xGJdNb-1CrkOKS9BnLbcblkZc6P2vVjKIi-7lRcs7Bk,4111
97
97
  tico/passes/segment_index_select.py,sha256=jn0M2sdUcDyjrvxfvM40wt5644iPQMY_ud0uvptXN84,5187
98
98
  tico/serialize/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
99
- tico/serialize/circle_graph.py,sha256=_u0vFDhPdOhEkucmaEhqILo13NKbjyVemPYFfC5YCZg,11619
99
+ tico/serialize/circle_graph.py,sha256=3t78g5eKzhHKvIBJqQ-CcwbqoV-2QwAdd_8wm4W1yXw,12317
100
100
  tico/serialize/circle_mapping.py,sha256=C9C3ORACQOdvBdnt5KRzlT8zao_TvzQklIxH794OhP0,5719
101
101
  tico/serialize/circle_serializer.py,sha256=KRx_Azx2Je9XNYe-pZuuiSMvbXEddd8M8qDATIt7XXk,8981
102
102
  tico/serialize/pack.py,sha256=5HZ9kX3x6C6CyT_FWS6FRmvx_P7Dx21orjUNQxJ2xlo,1297
@@ -169,7 +169,7 @@ tico/serialize/operators/op_select_copy.py,sha256=GPLN7QZmwSlA4WRbjfU6pLer3KVWzg
169
169
  tico/serialize/operators/op_sigmoid.py,sha256=ZubbGG1yU5uvNkEmOmbjj3eq7d9mwEaJdChRgL0OjDU,2045
170
170
  tico/serialize/operators/op_sin.py,sha256=MbttmHTVKhwKK6gH9Vbcbn5aAaxnQ71NdpmQAlTcojU,1827
171
171
  tico/serialize/operators/op_slice.py,sha256=g0r8lj5CIxpT6ixOKqUzwKiNhoiuIFwWjbpaiCoOg6w,5259
172
- tico/serialize/operators/op_softmax.py,sha256=8AwmsAVdSoIMKdfejrw9cy44TbOvvXsA0w3WQDVpI3A,3855
172
+ tico/serialize/operators/op_softmax.py,sha256=qwYke5zfhnSL89DZbzdr5Fc9SsJf0vI-LDZjT_NFpbc,3669
173
173
  tico/serialize/operators/op_split_with_sizes.py,sha256=TgYg1cu-3BSz9SsXfAhoJbo4q5ZzFaoFArkH_obsYlU,4274
174
174
  tico/serialize/operators/op_sqrt.py,sha256=9Q5jkuEPrim11XfSQHGDGVTMYk1TQhOfVqMVYD_eIrI,1871
175
175
  tico/serialize/operators/op_squeeze.py,sha256=QnNwfAdTC1xBm04C9DkVs8VB5YRN-4fCsIWn189QaPg,2416
@@ -183,9 +183,10 @@ tico/serialize/operators/op_view.py,sha256=5EMww-ve17Vm9XPuV03Tn7vJsjpU2J8U4d_FO
183
183
  tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
184
184
  tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
185
185
  tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
186
- tico/utils/convert.py,sha256=w4l7fnqbiVACOU5-OXr8Ebyl4EMeeBz6vwUSuOS_CtI,12977
186
+ tico/utils/convert.py,sha256=lNbbNswbKyCTK5E5i5CXkJBykWWfF5HDChhM3DLscWo,13222
187
187
  tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
188
188
  tico/utils/diff_graph.py,sha256=_eDGGPDPYQD4b--MXX0DLoVgSt_wLfNPt47UlolLLR4,5272
189
+ tico/utils/dtype.py,sha256=4-k1iUaHivFFXAQuDs7up6fXt5y4FqldGNokAPa3kic,603
189
190
  tico/utils/errors.py,sha256=f3csJjgbXG9W1aHhqEcou008Aor19W57X8oT5Hx8w1M,954
190
191
  tico/utils/graph.py,sha256=Y6aODsnc_-9l61oanknb7K1jqJ8B35iPypOKkM0Qkk0,9149
191
192
  tico/utils/installed_packages.py,sha256=J0FTwnkCGs0MxRWoCMYAqiwH7Z0GWFDLV--x-IndSp4,1017
@@ -196,16 +197,17 @@ tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
196
197
  tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
197
198
  tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
198
199
  tico/utils/serialize.py,sha256=AQXMBOLu-Kg2Rn-qbqsAtHndjZAZIavlKA0QFgJREHM,1420
200
+ tico/utils/torch_compat.py,sha256=oc6PztVsXdHcQ3iaVR90wLLxrGaj6zFHWZ8K9rRS6q8,1795
199
201
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
200
- tico/utils/utils.py,sha256=fnbZ2RLH6-J-wqb32O4qsR1ce4BJU0wYNrk84QXa6_E,13158
202
+ tico/utils/utils.py,sha256=kg4in3E0eH2nURpnt8XRPUS2Iam3HjJRadDZdyhUy0w,13014
201
203
  tico/utils/validate_args_kwargs.py,sha256=3dXkNll9E9eZq-p0HjYaV4YltQESqdEHBU34k-tIg1k,26733
202
204
  tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
203
205
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
204
206
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
205
207
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
206
- tico-0.1.0.dev250717.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
207
- tico-0.1.0.dev250717.dist-info/METADATA,sha256=AxsK-qqfRS2Cd0fJ8ChPI-pVZHvX5Kt1XeB7SMkdyKc,8430
208
- tico-0.1.0.dev250717.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
209
- tico-0.1.0.dev250717.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
210
- tico-0.1.0.dev250717.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
211
- tico-0.1.0.dev250717.dist-info/RECORD,,
208
+ tico-0.1.0.dev250721.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
209
+ tico-0.1.0.dev250721.dist-info/METADATA,sha256=OwOUs1qxfuulEmUn99ocFFSGB9L0XmUmCqNhZugmsCU,8430
210
+ tico-0.1.0.dev250721.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
211
+ tico-0.1.0.dev250721.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
212
+ tico-0.1.0.dev250721.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
213
+ tico-0.1.0.dev250721.dist-info/RECORD,,