tico 0.1.0.dev250722__py3-none-any.whl → 0.1.0.dev250724__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. tico/__init__.py +9 -1
  2. tico/config/base.py +1 -1
  3. tico/experimental/quantization/__init__.py +5 -0
  4. tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +1 -6
  5. tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +1 -1
  6. tico/experimental/quantization/algorithm/pt2e/utils.py +0 -1
  7. tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +1 -1
  8. tico/experimental/quantization/evaluation/evaluate.py +1 -1
  9. tico/experimental/quantization/passes/fold_quant_ops.py +0 -1
  10. tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +1 -1
  11. tico/experimental/quantization/passes/quantize_bias.py +0 -1
  12. tico/experimental/quantization/passes/remove_weight_dequant_op.py +1 -1
  13. tico/passes/cast_aten_where_arg_type.py +1 -1
  14. tico/passes/cast_mixed_type_args.py +2 -2
  15. tico/passes/const_prop_pass.py +1 -1
  16. tico/passes/convert_conv1d_to_conv2d.py +1 -1
  17. tico/passes/decompose_addmm.py +0 -3
  18. tico/passes/decompose_batch_norm.py +2 -2
  19. tico/passes/decompose_fake_quantize.py +0 -3
  20. tico/passes/decompose_fake_quantize_tensor_qparams.py +0 -2
  21. tico/passes/decompose_group_norm.py +0 -3
  22. tico/passes/legalize_predefined_layout_operators.py +2 -11
  23. tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
  24. tico/passes/lower_to_slice.py +1 -1
  25. tico/passes/merge_consecutive_cat.py +1 -1
  26. tico/passes/remove_redundant_expand.py +0 -5
  27. tico/passes/remove_redundant_reshape.py +5 -5
  28. tico/passes/segment_index_select.py +1 -1
  29. tico/serialize/circle_graph.py +1 -1
  30. tico/serialize/circle_serializer.py +234 -141
  31. tico/serialize/operators/op_any.py +0 -3
  32. tico/serialize/operators/op_clamp.py +2 -5
  33. tico/serialize/operators/op_full_like.py +0 -2
  34. tico/serialize/operators/op_instance_norm.py +0 -6
  35. tico/serialize/operators/op_mul.py +2 -8
  36. tico/serialize/operators/op_transpose_conv.py +0 -2
  37. tico/serialize/quant_param.py +5 -5
  38. tico/utils/convert.py +1 -1
  39. tico/utils/graph.py +1 -1
  40. tico/utils/padding.py +0 -2
  41. tico/utils/serialize.py +0 -3
  42. tico/utils/utils.py +1 -2
  43. tico/utils/validate_args_kwargs.py +1 -3
  44. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/METADATA +1 -1
  45. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/RECORD +49 -49
  46. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/LICENSE +0 -0
  47. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/WHEEL +0 -0
  48. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/entry_points.txt +0 -0
  49. {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/top_level.txt +0 -0
tico/serialize/circle_serializer.py CHANGED
@@ -18,12 +18,7 @@ from typing import Dict
 import flatbuffers
 import torch
 from circle_schema import circle
-from torch.export.exported_program import (
-    ConstantArgument,
-    ExportedProgram,
-    InputKind,
-    TensorArgument,
-)
+from torch.export.exported_program import ConstantArgument, ExportedProgram, InputKind
 
 from tico.serialize.circle_mapping import to_circle_dtype
 from tico.serialize.operators import *
@@ -39,147 +34,38 @@ multiple_output_ops = [
     torch.ops.aten.max.dim,
 ]
 
-# Build circle model from ExportedProgram
-# Return raw bytes of circle model
-def build_circle(edge_program: ExportedProgram) -> bytes:
-    logger = logging.getLogger(__name__)
 
-    builder = flatbuffers.Builder()
+def _initialize_model() -> tuple[CircleModel, CircleSubgraph]:
+    """Initialize a new Circle model and subgraph.
 
-    # Init Model
+    Returns:
+        Tuple containing the model and subgraph
+    """
     model = CircleModel()
-
-    # Add empty buffer at the front (convention)
-    model.add_buffer(circle.Buffer.BufferT())
-
-    # Create an empty subgraph (assume a single subgraph)
+    model.add_buffer(circle.Buffer.BufferT())  # Add empty buffer at the front
     graph = CircleSubgraph(model)
+    return model, graph
 
-    # Export tensors
-    logger.debug("---------------Export tensors--------------")
-    buf_name_to_data = {name: buf for name, buf in edge_program.named_buffers()}
-    for node in edge_program.graph.nodes:
-        if node.op == "call_function":
-            if node.target in multiple_output_ops:
-                continue
-            node_val = node.meta["val"]
-            if node_val.layout != torch.strided:
-                raise RuntimeError(
-                    f"Only support dense tensors (node layout: {node_val.layout})"
-                )
-            graph.add_tensor_from_node(node)
-            logger.debug(f"call_function: {node.name} tensor exported.")
-
-        # placeholder: function input (including parameters, buffers, constant tensors)
-        elif node.op == "placeholder":
-            # placeholder invariants
-            assert node.args is None or len(node.args) == 0  # Not support default param
-
-            # parameters
-            if node.name in edge_program.graph_signature.inputs_to_parameters:
-                param_name = edge_program.graph_signature.inputs_to_parameters[
-                    node.name
-                ]
-                param_data = edge_program.state_dict[param_name]
-
-                assert isinstance(
-                    param_data, torch.Tensor
-                ), "Expect parameters to be a tensor"
-                param_value = param_data.cpu().detach().numpy()
-
-                graph.add_tensor_from_node(node, param_value)
-                logger.debug(f"placeholder(param): {node.name} tensor exported.")
-            elif node.name in edge_program.graph_signature.inputs_to_buffers:
-                buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
-                assert buffer_name in buf_name_to_data
-                buffer_data = buf_name_to_data[buffer_name]
-                assert isinstance(
-                    buffer_data, torch.Tensor
-                ), "Expect buffers to be a tensor"
-                buffer_value = buffer_data.cpu().detach().numpy()
-
-                graph.add_tensor_from_node(node, buffer_value)
-                logger.debug(f"placeholder(buffer): {node.name} tensor exported.")
-            elif (
-                node.name
-                in edge_program.graph_signature.inputs_to_lifted_tensor_constants
-            ):
-                ctensor_name = (
-                    edge_program.graph_signature.inputs_to_lifted_tensor_constants[
-                        node.name
-                    ]
-                )
-                ctensor_data = edge_program.constants[ctensor_name]
-
-                assert isinstance(
-                    ctensor_data, torch.Tensor
-                ), "Expect constant tensor to be a tensor"
-                ctensor_value = ctensor_data.cpu().detach().numpy()
-
-                graph.add_tensor_from_node(node, ctensor_value)
-                logger.debug(
-                    f"placeholder(constant tensor): {node.name} tensor exported."
-                )
-            else:
-                user_inputs = [
-                    specs
-                    for specs in edge_program.graph_signature.input_specs
-                    if specs.kind == InputKind.USER_INPUT
-                ]
-                constant_inputs = [
-                    specs
-                    for specs in user_inputs
-                    if isinstance(specs.arg, ConstantArgument)
-                ]
-                name_to_value = {
-                    specs.arg.name: specs.arg.value for specs in constant_inputs
-                }
-                # NoneType ConstantArgument is ignored.
-                if node.name in name_to_value and name_to_value[node.name] == None:
-                    continue
-                graph.add_tensor_from_node(node)
-                logger.debug(f"placeholder: {node.name} tensor exported.")
-
-        # get_attr: retrieve parameter
-        elif node.op == "get_attr":
-            # node.name: Place where fetched attribute is saved
-            # node.target: Attribute in the module
-            attr_tensor = getattr(node.graph.owning_module, node.target)
-            assert isinstance(attr_tensor, torch.Tensor)
 
-            graph.add_tensor_from_scratch(
-                prefix=node.name,
-                shape=list(attr_tensor.shape),
-                dtype=to_circle_dtype(attr_tensor.dtype),
-                source_node=node,
-            )
+def build_circle(ep: ExportedProgram) -> bytes:
+    """Convert ExportedProgram to Circle format.
 
-            logger.debug(f"get_attr: {node.name} tensor exported.")
+    Args:
+        ep: The exported PyTorch program to convert
 
-        # output: function output
-        elif node.op == "output":
-            # output node itself does not need a buffer
-            # argument of output node is assumed to be exported beforehand
-            for output in node.args[0]:
-                if isinstance(output, torch.fx.Node):
-                    assert graph.has_tensor(output.name)
-                continue
-
-        # call_method: call method
-        elif node.op == "call_method":
-            raise AssertionError("Not yet implemented")
-
-        # call_module: call 'forward' of module
-        elif node.op == "call_module":
-            raise AssertionError("Not yet implemented")
+    Returns:
+        bytes: Raw bytes of the Circle model
+    """
+    logger = logging.getLogger(__name__)
+    builder = flatbuffers.Builder()
+    model, graph = _initialize_model()
 
-        else:
-            # Add more if fx.Node is extended
-            raise AssertionError(f"Unknown fx.Node op {node.op}")
+    # Export tensors
+    _export_tensors(graph, ep)
 
     # Register inputs
     logger.debug("---------------Register inputs--------------")
-    for in_spec in edge_program.graph_signature.input_specs:
+    for in_spec in ep.graph_signature.input_specs:
         if in_spec.kind != InputKind.USER_INPUT:
             continue
         # NoneType ConstantArgument is ignored.
@@ -191,9 +77,9 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
 
     # Register outputs
     logger.debug("---------------Register outputs--------------")
-    for user_output in edge_program.graph_signature.user_outputs:
+    for user_output in ep.graph_signature.user_outputs:
         if user_output == None:
-            logger.debug(f"Ignore 'None' output")
+            logger.debug("Ignore 'None' output")
             continue
 
         graph.add_output(user_output)
@@ -203,7 +89,7 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
     logger.debug("---------------Export operators--------------")
    op_codes: Dict[OpCode, int] = {}
    visitors = get_node_visitors(op_codes, graph)
-    for node in edge_program.graph.nodes:
+    for node in ep.graph.nodes:
        if node.op != "call_function":
            continue
 
@@ -227,10 +113,8 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
         code for code, _ in sorted(op_codes.items(), key=lambda x: x[1])
     ]
 
-    # Description
+    # Final model settings
     model.description = "circle"
-
-    # Set version
     model.version = 0
 
     # Finish model
@@ -238,3 +122,212 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
     buf = builder.Output()
 
     return bytes(buf)
+
+
+def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None:
+    """Export all tensors from the exported program to the circle graph.
+
+    Args:
+        graph: The CircleSubgraph to add tensors to
+        ep: The exported PyTorch program
+    """
+    logger = logging.getLogger(__name__)
+    logger.debug("---------------Export tensors--------------")
+    buf_name_to_data = {name: buf for name, buf in ep.named_buffers()}
+
+    for node in ep.graph.nodes:
+        if node.op == "call_function":
+            if node.target in multiple_output_ops:
+                continue
+            node_val = node.meta["val"]
+            if node_val.layout != torch.strided:
+                raise RuntimeError(
+                    f"Only support dense tensors (node layout: {node_val.layout})"
+                )
+            graph.add_tensor_from_node(node)
+            logger.debug(f"call_function: {node.name} tensor exported.")
+
+        elif node.op == "placeholder":
+            _handle_placeholder_node(graph, node, ep, buf_name_to_data)
+
+        elif node.op == "get_attr":
+            _handle_get_attr_node(graph, node)
+
+        elif node.op == "output":
+            for output in node.args[0]:
+                if isinstance(output, torch.fx.Node):
+                    assert graph.has_tensor(output.name)
+                continue
+
+        elif node.op == "call_method":
+            raise AssertionError("Not yet implemented")
+
+        elif node.op == "call_module":
+            raise AssertionError("Not yet implemented")
+
+        else:
+            raise AssertionError(f"Unknown fx.Node op {node.op}")
+
+
+def _handle_placeholder_node(
+    graph: CircleSubgraph,
+    node: torch.fx.Node,
+    ep: ExportedProgram,
+    buf_name_to_data: dict,
+) -> None:
+    """Handle a placeholder node during tensor export."""
+    # placeholder invariants
+    assert node.args is None or len(node.args) == 0  # Not support default param
+
+    if node.name in ep.graph_signature.inputs_to_parameters:
+        _handle_parameter_node(graph, node, ep)
+    elif node.name in ep.graph_signature.inputs_to_buffers:
+        _handle_buffer_node(graph, node, ep, buf_name_to_data)
+    elif node.name in ep.graph_signature.inputs_to_lifted_tensor_constants:
+        _handle_constant_tensor_node(graph, node, ep)
+    else:
+        _handle_user_input_node(graph, node, ep)
+
+
+def _handle_parameter_node(
+    graph: CircleSubgraph,
+    node: torch.fx.Node,
+    ep: ExportedProgram,
+) -> None:
+    """Handle a parameter placeholder node by exporting its tensor data.
+
+    Args:
+        graph: CircleSubgraph to add tensor to
+        node: The parameter node to process
+        ep: ExportedProgram containing parameter data
+    """
+    param_name = ep.graph_signature.inputs_to_parameters[node.name]
+    param_data = ep.state_dict[param_name]
+
+    if not isinstance(param_data, torch.Tensor):
+        raise ValueError(f"Parameter {param_name} is not a tensor")
+
+    tensor_value = param_data.cpu().detach().numpy()
+    graph.add_tensor_from_node(node, tensor_value)
+
+    logger = logging.getLogger(__name__)
+    logger.debug(f"Exported parameter tensor: {node.name}")
+
+
+def _handle_buffer_node(
+    graph: CircleSubgraph,
+    node: torch.fx.Node,
+    ep: ExportedProgram,
+    buf_name_to_data: dict,
+) -> None:
+    """Handle a buffer placeholder node by exporting its tensor data.
+
+    Args:
+        graph: CircleSubgraph to add tensor to
+        node: The buffer node to process
+        ep: ExportedProgram containing buffer info
+        buf_name_to_data: Mapping of buffer names to data
+    """
+    buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
+
+    if buffer_name not in buf_name_to_data:
+        raise ValueError(f"Buffer {buffer_name} not found in buffer data")
+
+    buffer_data = buf_name_to_data[buffer_name]
+
+    if not isinstance(buffer_data, torch.Tensor):
+        raise ValueError(f"Buffer {buffer_name} is not a tensor")
+
+    tensor_value = buffer_data.cpu().detach().numpy()
+    graph.add_tensor_from_node(node, tensor_value)
+
+    logger = logging.getLogger(__name__)
+    logger.debug(f"Exported buffer tensor: {node.name}")
+
+
+def _handle_constant_tensor_node(
+    graph: CircleSubgraph,
+    node: torch.fx.Node,
+    ep: ExportedProgram,
+) -> None:
+    """Handle a constant tensor placeholder node by exporting its tensor data.
+
+    Args:
+        graph: CircleSubgraph to add tensor to
+        node: The constant tensor node to process
+        ep: ExportedProgram containing constant data
+    """
+    ctensor_name = ep.graph_signature.inputs_to_lifted_tensor_constants[node.name]
+
+    if ctensor_name not in ep.constants:
+        raise ValueError(f"Constant tensor {ctensor_name} not found")
+
+    ctensor_data = ep.constants[ctensor_name]
+
+    if not isinstance(ctensor_data, torch.Tensor):
+        raise ValueError(f"Constant tensor {ctensor_name} is not a tensor")
+
+    tensor_value = ctensor_data.cpu().detach().numpy()
+    graph.add_tensor_from_node(node, tensor_value)
+
+    logger = logging.getLogger(__name__)
+    logger.debug(f"Exported constant tensor: {node.name}")
+
+
+def _handle_user_input_node(
+    graph: CircleSubgraph,
+    node: torch.fx.Node,
+    ep: ExportedProgram,
+) -> None:
+    """Handle a user input placeholder node by exporting its tensor data.
+
+    Args:
+        graph: CircleSubgraph to add tensor to
+        node: The user input node to process
+        ep: ExportedProgram containing input specs
+    """
+    user_inputs = [
+        specs
+        for specs in ep.graph_signature.input_specs
+        if specs.kind == InputKind.USER_INPUT
+    ]
+    constant_inputs = [
+        specs for specs in user_inputs if isinstance(specs.arg, ConstantArgument)
+    ]
+    name_to_value = {specs.arg.name: specs.arg.value for specs in constant_inputs}
+
+    # Skip NoneType ConstantArgument
+    if node.name in name_to_value and name_to_value[node.name] is None:
+        return
+
+    graph.add_tensor_from_node(node)
+
+    logger = logging.getLogger(__name__)
+    logger.debug(f"Exported user input tensor: {node.name}")
+
+
+def _handle_get_attr_node(
+    graph: CircleSubgraph,
+    node: torch.fx.Node,
+) -> None:
+    """Handle a get_attr node by exporting its tensor data.
+
+    Args:
+        graph: CircleSubgraph to add tensor to
+        node: The get_attr node to process
+    """
+    assert isinstance(node.target, str)
+    attr_tensor = getattr(node.graph.owning_module, node.target)
+
+    if not isinstance(attr_tensor, torch.Tensor):
+        raise ValueError(f"Attribute {node.target} is not a tensor")
+
+    graph.add_tensor_from_scratch(
+        prefix=node.name,
+        shape=list(attr_tensor.shape),
+        dtype=to_circle_dtype(attr_tensor.dtype),
+        source_node=node,
+    )
+
+    logger = logging.getLogger(__name__)
+    logger.debug(f"Exported attribute tensor: {node.name}")
tico/serialize/operators/op_any.py CHANGED
@@ -22,7 +22,6 @@ from circle_schema import circle
 from tico.serialize.circle_graph import CircleSubgraph
 from tico.serialize.circle_mapping import (
     circle_legalize_dtype_to,
-    extract_circle_dtype,
     extract_shape,
     extract_torch_dtype,
 )
@@ -100,8 +99,6 @@ class AnyVisitor(NodeVisitor):
         keepdim = args.keepdim
 
         input_shape = list(extract_shape(input))
-        output_shape = list(extract_shape(node))
-
         dim_i32 = None
         if dim is None:
             dims = tuple(i for i in range(0, len(input_shape)))
tico/serialize/operators/op_clamp.py CHANGED
@@ -21,12 +21,9 @@ import torch
 from circle_schema import circle
 
 from tico.passes import ops
+from tico.serialize.circle_graph import CircleSubgraph
 
-from tico.serialize.circle_graph import (
-    CircleSubgraph,
-    extract_circle_dtype,
-    extract_shape,
-)
+from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
tico/serialize/operators/op_full_like.py CHANGED
@@ -21,10 +21,8 @@ import torch
 from circle_schema import circle
 
 from tico.serialize.circle_graph import CircleSubgraph
-from tico.serialize.circle_mapping import to_circle_dtype
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
-from tico.serialize.operators.utils import create_builtin_operator, get_op_index
 from tico.utils.validate_args_kwargs import FullLikeArgs
 
 
tico/serialize/operators/op_instance_norm.py CHANGED
@@ -73,12 +73,6 @@ class InstanceNormVisitor(NodeVisitor):
         eps = args.eps
 
         # Ignore training-related args
-        running_mean = args.running_mean
-        running_var = args.running_var
-        use_input_stats = args.use_input_stats
-        momentum = args.momentum
-        cudnn_enabled = args.cudnn_enabled
-
         input_shape = list(extract_shape(input))
         assert len(input_shape) == 4, len(input_shape)
 
tico/serialize/operators/op_mul.py CHANGED
@@ -66,10 +66,7 @@ class MulTensorVisitor(BaseMulVisitor):
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-        args = MulTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
-        input = args.input
-        other = args.other
-
+        _ = MulTensorArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         operator = super().define_node(
             node,
         )
@@ -88,10 +85,7 @@ class MulScalarVisitor(BaseMulVisitor):
         self,
         node: torch.fx.Node,
     ) -> circle.Operator.OperatorT:
-        args = MulScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
-        input = args.input
-        other = args.other
-
+        _ = MulScalarArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
         operator = super().define_node(
             node,
         )
tico/serialize/operators/op_transpose_conv.py CHANGED
@@ -76,9 +76,7 @@ class TransposeConvVisitor(NodeVisitor):
         bias = args.bias
         stride = args.stride
         padding = args.padding
-        output_padding = args.output_padding
         groups = args.groups
-        dilation = args.dilation
 
         assert groups == 1, "Only support group 1"
 
tico/serialize/quant_param.py CHANGED
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
 """
 This is a key for torch.fx.Node's meta dict to save QuantParam
 
@@ -19,11 +24,6 @@ QuantParam can be retrieved as node.meta[QPARAM_KEY]
 """
 QPARAM_KEY = "_quantization_parameters_"
 
-from dataclasses import dataclass
-from typing import List, Optional
-
-import torch
-
 
 @dataclass
 class QuantParam:
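
Note: the quant_param.py hunks only hoist the imports above the QPARAM_KEY docstring; the key string and the QuantParam dataclass are untouched. For orientation, a hedged sketch of how that meta key can be read off an FX graph follows. The QuantParam fields are not visible in this diff, so the sketch only checks for presence.

# Illustrative sketch, not from the diff: inspect tico's quantization annotations.
import torch

from tico.serialize.quant_param import QPARAM_KEY


def dump_qparams(gm: torch.fx.GraphModule) -> None:
    for node in gm.graph.nodes:
        qparam = node.meta.get(QPARAM_KEY)  # None when the node carries no QuantParam
        if qparam is not None:
            print(node.name, qparam)
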
tico/utils/convert.py CHANGED
@@ -157,7 +157,7 @@ def check_unsupported_target(exported_program: ExportedProgram):
     for n in exported_program.graph.nodes:
         if n.op != "call_function":
             continue
-        if not n.target in supported_target:
+        if n.target not in supported_target:
             unsupported.append(n)
 
     if unsupported:
tico/utils/graph.py CHANGED
@@ -24,7 +24,7 @@ import torch
 from torch.export import ExportedProgram
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
 
-from tico.utils.utils import get_fake_mode, set_new_meta_val
+from tico.utils.utils import get_fake_mode
 
 
 def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
tico/utils/padding.py CHANGED
@@ -15,8 +15,6 @@
 from enum import IntEnum
 from typing import NamedTuple, Optional, Sequence, Tuple, Union
 
-import torch
-
 from tico.utils.errors import InvalidArgumentError
 
 
tico/utils/serialize.py CHANGED
@@ -12,9 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
-
-import torch
 
 from tico.serialize.circle_graph import CircleSubgraph
 from tico.utils.graph import get_module_name_chain
tico/utils/utils.py CHANGED
@@ -21,7 +21,6 @@ from typing import List
 
 import torch
 from circle_schema import circle
-from packaging.version import Version
 from torch._guards import detect_fake_mode
 from torch.export import ExportedProgram
 from torch.utils import _pytree as pytree
@@ -131,7 +130,7 @@ def enforce_type(callable):
 
             return True
 
-        if typing.get_origin(type_hint) == dict:
+        if typing.get_origin(type_hint) is dict:
             if not isinstance(value, typing.get_origin(type_hint)):
                 return False
 
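
Note: the enforce_type hunk swaps == for is when comparing typing.get_origin(type_hint) against dict. Since get_origin returns the builtin class object itself (or None for non-generic hints), identity comparison is the idiomatic, slightly stricter check. A standalone illustration, independent of tico:

# Behavior of typing.get_origin that the 'is dict' comparison relies on.
import typing
from typing import Dict, List

assert typing.get_origin(Dict[str, int]) is dict  # origin is the builtin dict class
assert typing.get_origin(List[int]) is list
assert typing.get_origin(int) is None  # non-generic hints have no origin
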
tico/utils/validate_args_kwargs.py CHANGED
@@ -16,10 +16,8 @@ from dataclasses import dataclass, field
 from typing import List, Optional, TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
-    import torch._ops
     import torch.fx
 import torch
-import torch.fx.node
 
 from tico.utils.utils import enforce_type
 
@@ -137,7 +135,7 @@ class AvgPool2dArgs:
         assert len(self.padding) == 2, len(self.padding)
         if self.divisor_override is not None:
             assert isinstance(self.divisor_override, int), type(self.divisor_override)
-            assert self.divisor_override != 0, f"Divisor must be not zero."
+            assert self.divisor_override != 0, "Divisor must be not zero."
 
 
 @enforce_type
{tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: tico
-Version: 0.1.0.dev250722
+Version: 0.1.0.dev250724
 Summary: Convert exported Torch module to circle
 Home-page: UNKNOWN
 License: UNKNOWN