tico 0.1.0.dev250722__py3-none-any.whl → 0.1.0.dev250724__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/experimental/quantization/__init__.py +5 -0
- tico/experimental/quantization/algorithm/pt2e/annotation/annotator.py +1 -6
- tico/experimental/quantization/algorithm/pt2e/annotation/op/linear.py +1 -1
- tico/experimental/quantization/algorithm/pt2e/utils.py +0 -1
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +1 -1
- tico/experimental/quantization/evaluation/evaluate.py +1 -1
- tico/experimental/quantization/passes/fold_quant_ops.py +0 -1
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +1 -1
- tico/experimental/quantization/passes/quantize_bias.py +0 -1
- tico/experimental/quantization/passes/remove_weight_dequant_op.py +1 -1
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_mixed_type_args.py +2 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +0 -2
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/remove_redundant_expand.py +0 -5
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/serialize/circle_graph.py +1 -1
- tico/serialize/circle_serializer.py +234 -141
- tico/serialize/operators/op_any.py +0 -3
- tico/serialize/operators/op_clamp.py +2 -5
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_transpose_conv.py +0 -2
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +1 -1
- tico/utils/graph.py +1 -1
- tico/utils/padding.py +0 -2
- tico/utils/serialize.py +0 -3
- tico/utils/utils.py +1 -2
- tico/utils/validate_args_kwargs.py +1 -3
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/METADATA +1 -1
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/RECORD +49 -49
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250722.dist-info → tico-0.1.0.dev250724.dist-info}/top_level.txt +0 -0
@@ -18,12 +18,7 @@ from typing import Dict
|
|
18
18
|
import flatbuffers
|
19
19
|
import torch
|
20
20
|
from circle_schema import circle
|
21
|
-
from torch.export.exported_program import (
|
22
|
-
ConstantArgument,
|
23
|
-
ExportedProgram,
|
24
|
-
InputKind,
|
25
|
-
TensorArgument,
|
26
|
-
)
|
21
|
+
from torch.export.exported_program import ConstantArgument, ExportedProgram, InputKind
|
27
22
|
|
28
23
|
from tico.serialize.circle_mapping import to_circle_dtype
|
29
24
|
from tico.serialize.operators import *
|
@@ -39,147 +34,38 @@ multiple_output_ops = [
|
|
39
34
|
torch.ops.aten.max.dim,
|
40
35
|
]
|
41
36
|
|
42
|
-
# Build circle model from ExportedProgram
|
43
|
-
# Return raw bytes of circle model
|
44
|
-
def build_circle(edge_program: ExportedProgram) -> bytes:
|
45
|
-
logger = logging.getLogger(__name__)
|
46
37
|
|
47
|
-
|
38
|
+
def _initialize_model() -> tuple[CircleModel, CircleSubgraph]:
|
39
|
+
"""Initialize a new Circle model and subgraph.
|
48
40
|
|
49
|
-
|
41
|
+
Returns:
|
42
|
+
Tuple containing the model and subgraph
|
43
|
+
"""
|
50
44
|
model = CircleModel()
|
51
|
-
|
52
|
-
# Add empty buffer at the front (convention)
|
53
|
-
model.add_buffer(circle.Buffer.BufferT())
|
54
|
-
|
55
|
-
# Create an empty subgraph (assume a single subgraph)
|
45
|
+
model.add_buffer(circle.Buffer.BufferT()) # Add empty buffer at the front
|
56
46
|
graph = CircleSubgraph(model)
|
47
|
+
return model, graph
|
57
48
|
|
58
|
-
# Export tensors
|
59
|
-
logger.debug("---------------Export tensors--------------")
|
60
|
-
buf_name_to_data = {name: buf for name, buf in edge_program.named_buffers()}
|
61
|
-
for node in edge_program.graph.nodes:
|
62
|
-
if node.op == "call_function":
|
63
|
-
if node.target in multiple_output_ops:
|
64
|
-
continue
|
65
|
-
node_val = node.meta["val"]
|
66
|
-
if node_val.layout != torch.strided:
|
67
|
-
raise RuntimeError(
|
68
|
-
f"Only support dense tensors (node layout: {node_val.layout})"
|
69
|
-
)
|
70
|
-
graph.add_tensor_from_node(node)
|
71
|
-
logger.debug(f"call_function: {node.name} tensor exported.")
|
72
|
-
|
73
|
-
# placeholder: function input (including parameters, buffers, constant tensors)
|
74
|
-
elif node.op == "placeholder":
|
75
|
-
# placeholder invariants
|
76
|
-
assert node.args is None or len(node.args) == 0 # Not support default param
|
77
|
-
|
78
|
-
# parameters
|
79
|
-
if node.name in edge_program.graph_signature.inputs_to_parameters:
|
80
|
-
param_name = edge_program.graph_signature.inputs_to_parameters[
|
81
|
-
node.name
|
82
|
-
]
|
83
|
-
param_data = edge_program.state_dict[param_name]
|
84
|
-
|
85
|
-
assert isinstance(
|
86
|
-
param_data, torch.Tensor
|
87
|
-
), "Expect parameters to be a tensor"
|
88
|
-
param_value = param_data.cpu().detach().numpy()
|
89
|
-
|
90
|
-
graph.add_tensor_from_node(node, param_value)
|
91
|
-
logger.debug(f"placeholder(param): {node.name} tensor exported.")
|
92
|
-
elif node.name in edge_program.graph_signature.inputs_to_buffers:
|
93
|
-
buffer_name = edge_program.graph_signature.inputs_to_buffers[node.name]
|
94
|
-
assert buffer_name in buf_name_to_data
|
95
|
-
buffer_data = buf_name_to_data[buffer_name]
|
96
|
-
assert isinstance(
|
97
|
-
buffer_data, torch.Tensor
|
98
|
-
), "Expect buffers to be a tensor"
|
99
|
-
buffer_value = buffer_data.cpu().detach().numpy()
|
100
|
-
|
101
|
-
graph.add_tensor_from_node(node, buffer_value)
|
102
|
-
logger.debug(f"placeholder(buffer): {node.name} tensor exported.")
|
103
|
-
elif (
|
104
|
-
node.name
|
105
|
-
in edge_program.graph_signature.inputs_to_lifted_tensor_constants
|
106
|
-
):
|
107
|
-
ctensor_name = (
|
108
|
-
edge_program.graph_signature.inputs_to_lifted_tensor_constants[
|
109
|
-
node.name
|
110
|
-
]
|
111
|
-
)
|
112
|
-
ctensor_data = edge_program.constants[ctensor_name]
|
113
|
-
|
114
|
-
assert isinstance(
|
115
|
-
ctensor_data, torch.Tensor
|
116
|
-
), "Expect constant tensor to be a tensor"
|
117
|
-
ctensor_value = ctensor_data.cpu().detach().numpy()
|
118
|
-
|
119
|
-
graph.add_tensor_from_node(node, ctensor_value)
|
120
|
-
logger.debug(
|
121
|
-
f"placeholder(constant tensor): {node.name} tensor exported."
|
122
|
-
)
|
123
|
-
else:
|
124
|
-
user_inputs = [
|
125
|
-
specs
|
126
|
-
for specs in edge_program.graph_signature.input_specs
|
127
|
-
if specs.kind == InputKind.USER_INPUT
|
128
|
-
]
|
129
|
-
constant_inputs = [
|
130
|
-
specs
|
131
|
-
for specs in user_inputs
|
132
|
-
if isinstance(specs.arg, ConstantArgument)
|
133
|
-
]
|
134
|
-
name_to_value = {
|
135
|
-
specs.arg.name: specs.arg.value for specs in constant_inputs
|
136
|
-
}
|
137
|
-
# NoneType ConstantArgument is ignored.
|
138
|
-
if node.name in name_to_value and name_to_value[node.name] == None:
|
139
|
-
continue
|
140
|
-
graph.add_tensor_from_node(node)
|
141
|
-
logger.debug(f"placeholder: {node.name} tensor exported.")
|
142
|
-
|
143
|
-
# get_attr: retrieve parameter
|
144
|
-
elif node.op == "get_attr":
|
145
|
-
# node.name: Place where fetched attribute is saved
|
146
|
-
# node.target: Attribute in the module
|
147
|
-
attr_tensor = getattr(node.graph.owning_module, node.target)
|
148
|
-
assert isinstance(attr_tensor, torch.Tensor)
|
149
49
|
|
150
|
-
|
151
|
-
|
152
|
-
shape=list(attr_tensor.shape),
|
153
|
-
dtype=to_circle_dtype(attr_tensor.dtype),
|
154
|
-
source_node=node,
|
155
|
-
)
|
50
|
+
def build_circle(ep: ExportedProgram) -> bytes:
|
51
|
+
"""Convert ExportedProgram to Circle format.
|
156
52
|
|
157
|
-
|
53
|
+
Args:
|
54
|
+
ep: The exported PyTorch program to convert
|
158
55
|
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
assert graph.has_tensor(output.name)
|
166
|
-
continue
|
167
|
-
|
168
|
-
# call_method: call method
|
169
|
-
elif node.op == "call_method":
|
170
|
-
raise AssertionError("Not yet implemented")
|
171
|
-
|
172
|
-
# call_module: call 'forward' of module
|
173
|
-
elif node.op == "call_module":
|
174
|
-
raise AssertionError("Not yet implemented")
|
56
|
+
Returns:
|
57
|
+
bytes: Raw bytes of the Circle model
|
58
|
+
"""
|
59
|
+
logger = logging.getLogger(__name__)
|
60
|
+
builder = flatbuffers.Builder()
|
61
|
+
model, graph = _initialize_model()
|
175
62
|
|
176
|
-
|
177
|
-
|
178
|
-
raise AssertionError(f"Unknown fx.Node op {node.op}")
|
63
|
+
# Export tensors
|
64
|
+
_export_tensors(graph, ep)
|
179
65
|
|
180
66
|
# Register inputs
|
181
67
|
logger.debug("---------------Register inputs--------------")
|
182
|
-
for in_spec in edge_program.graph_signature.input_specs:
|
68
|
+
for in_spec in ep.graph_signature.input_specs:
|
183
69
|
if in_spec.kind != InputKind.USER_INPUT:
|
184
70
|
continue
|
185
71
|
# NoneType ConstantArgument is ignored.
|
@@ -191,9 +77,9 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
|
|
191
77
|
|
192
78
|
# Register outputs
|
193
79
|
logger.debug("---------------Register outputs--------------")
|
194
|
-
for user_output in edge_program.graph_signature.user_outputs:
|
80
|
+
for user_output in ep.graph_signature.user_outputs:
|
195
81
|
if user_output == None:
|
196
|
-
logger.debug(
|
82
|
+
logger.debug("Ignore 'None' output")
|
197
83
|
continue
|
198
84
|
|
199
85
|
graph.add_output(user_output)
|
@@ -203,7 +89,7 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
|
|
203
89
|
logger.debug("---------------Export operators--------------")
|
204
90
|
op_codes: Dict[OpCode, int] = {}
|
205
91
|
visitors = get_node_visitors(op_codes, graph)
|
206
|
-
for node in edge_program.graph.nodes:
|
92
|
+
for node in ep.graph.nodes:
|
207
93
|
if node.op != "call_function":
|
208
94
|
continue
|
209
95
|
|
@@ -227,10 +113,8 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
|
|
227
113
|
code for code, _ in sorted(op_codes.items(), key=lambda x: x[1])
|
228
114
|
]
|
229
115
|
|
230
|
-
#
|
116
|
+
# Final model settings
|
231
117
|
model.description = "circle"
|
232
|
-
|
233
|
-
# Set version
|
234
118
|
model.version = 0
|
235
119
|
|
236
120
|
# Finish model
|
@@ -238,3 +122,212 @@ def build_circle(edge_program: ExportedProgram) -> bytes:
|
|
238
122
|
buf = builder.Output()
|
239
123
|
|
240
124
|
return bytes(buf)
|
125
|
+
|
126
|
+
|
127
|
+
def _export_tensors(graph: CircleSubgraph, ep: ExportedProgram) -> None:
|
128
|
+
"""Export all tensors from the exported program to the circle graph.
|
129
|
+
|
130
|
+
Args:
|
131
|
+
graph: The CircleSubgraph to add tensors to
|
132
|
+
ep: The exported PyTorch program
|
133
|
+
"""
|
134
|
+
logger = logging.getLogger(__name__)
|
135
|
+
logger.debug("---------------Export tensors--------------")
|
136
|
+
buf_name_to_data = {name: buf for name, buf in ep.named_buffers()}
|
137
|
+
|
138
|
+
for node in ep.graph.nodes:
|
139
|
+
if node.op == "call_function":
|
140
|
+
if node.target in multiple_output_ops:
|
141
|
+
continue
|
142
|
+
node_val = node.meta["val"]
|
143
|
+
if node_val.layout != torch.strided:
|
144
|
+
raise RuntimeError(
|
145
|
+
f"Only support dense tensors (node layout: {node_val.layout})"
|
146
|
+
)
|
147
|
+
graph.add_tensor_from_node(node)
|
148
|
+
logger.debug(f"call_function: {node.name} tensor exported.")
|
149
|
+
|
150
|
+
elif node.op == "placeholder":
|
151
|
+
_handle_placeholder_node(graph, node, ep, buf_name_to_data)
|
152
|
+
|
153
|
+
elif node.op == "get_attr":
|
154
|
+
_handle_get_attr_node(graph, node)
|
155
|
+
|
156
|
+
elif node.op == "output":
|
157
|
+
for output in node.args[0]:
|
158
|
+
if isinstance(output, torch.fx.Node):
|
159
|
+
assert graph.has_tensor(output.name)
|
160
|
+
continue
|
161
|
+
|
162
|
+
elif node.op == "call_method":
|
163
|
+
raise AssertionError("Not yet implemented")
|
164
|
+
|
165
|
+
elif node.op == "call_module":
|
166
|
+
raise AssertionError("Not yet implemented")
|
167
|
+
|
168
|
+
else:
|
169
|
+
raise AssertionError(f"Unknown fx.Node op {node.op}")
|
170
|
+
|
171
|
+
|
172
|
+
def _handle_placeholder_node(
|
173
|
+
graph: CircleSubgraph,
|
174
|
+
node: torch.fx.Node,
|
175
|
+
ep: ExportedProgram,
|
176
|
+
buf_name_to_data: dict,
|
177
|
+
) -> None:
|
178
|
+
"""Handle a placeholder node during tensor export."""
|
179
|
+
# placeholder invariants
|
180
|
+
assert node.args is None or len(node.args) == 0 # Not support default param
|
181
|
+
|
182
|
+
if node.name in ep.graph_signature.inputs_to_parameters:
|
183
|
+
_handle_parameter_node(graph, node, ep)
|
184
|
+
elif node.name in ep.graph_signature.inputs_to_buffers:
|
185
|
+
_handle_buffer_node(graph, node, ep, buf_name_to_data)
|
186
|
+
elif node.name in ep.graph_signature.inputs_to_lifted_tensor_constants:
|
187
|
+
_handle_constant_tensor_node(graph, node, ep)
|
188
|
+
else:
|
189
|
+
_handle_user_input_node(graph, node, ep)
|
190
|
+
|
191
|
+
|
192
|
+
def _handle_parameter_node(
|
193
|
+
graph: CircleSubgraph,
|
194
|
+
node: torch.fx.Node,
|
195
|
+
ep: ExportedProgram,
|
196
|
+
) -> None:
|
197
|
+
"""Handle a parameter placeholder node by exporting its tensor data.
|
198
|
+
|
199
|
+
Args:
|
200
|
+
graph: CircleSubgraph to add tensor to
|
201
|
+
node: The parameter node to process
|
202
|
+
ep: ExportedProgram containing parameter data
|
203
|
+
"""
|
204
|
+
param_name = ep.graph_signature.inputs_to_parameters[node.name]
|
205
|
+
param_data = ep.state_dict[param_name]
|
206
|
+
|
207
|
+
if not isinstance(param_data, torch.Tensor):
|
208
|
+
raise ValueError(f"Parameter {param_name} is not a tensor")
|
209
|
+
|
210
|
+
tensor_value = param_data.cpu().detach().numpy()
|
211
|
+
graph.add_tensor_from_node(node, tensor_value)
|
212
|
+
|
213
|
+
logger = logging.getLogger(__name__)
|
214
|
+
logger.debug(f"Exported parameter tensor: {node.name}")
|
215
|
+
|
216
|
+
|
217
|
+
def _handle_buffer_node(
|
218
|
+
graph: CircleSubgraph,
|
219
|
+
node: torch.fx.Node,
|
220
|
+
ep: ExportedProgram,
|
221
|
+
buf_name_to_data: dict,
|
222
|
+
) -> None:
|
223
|
+
"""Handle a buffer placeholder node by exporting its tensor data.
|
224
|
+
|
225
|
+
Args:
|
226
|
+
graph: CircleSubgraph to add tensor to
|
227
|
+
node: The buffer node to process
|
228
|
+
ep: ExportedProgram containing buffer info
|
229
|
+
buf_name_to_data: Mapping of buffer names to data
|
230
|
+
"""
|
231
|
+
buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
|
232
|
+
|
233
|
+
if buffer_name not in buf_name_to_data:
|
234
|
+
raise ValueError(f"Buffer {buffer_name} not found in buffer data")
|
235
|
+
|
236
|
+
buffer_data = buf_name_to_data[buffer_name]
|
237
|
+
|
238
|
+
if not isinstance(buffer_data, torch.Tensor):
|
239
|
+
raise ValueError(f"Buffer {buffer_name} is not a tensor")
|
240
|
+
|
241
|
+
tensor_value = buffer_data.cpu().detach().numpy()
|
242
|
+
graph.add_tensor_from_node(node, tensor_value)
|
243
|
+
|
244
|
+
logger = logging.getLogger(__name__)
|
245
|
+
logger.debug(f"Exported buffer tensor: {node.name}")
|
246
|
+
|
247
|
+
|
248
|
+
def _handle_constant_tensor_node(
|
249
|
+
graph: CircleSubgraph,
|
250
|
+
node: torch.fx.Node,
|
251
|
+
ep: ExportedProgram,
|
252
|
+
) -> None:
|
253
|
+
"""Handle a constant tensor placeholder node by exporting its tensor data.
|
254
|
+
|
255
|
+
Args:
|
256
|
+
graph: CircleSubgraph to add tensor to
|
257
|
+
node: The constant tensor node to process
|
258
|
+
ep: ExportedProgram containing constant data
|
259
|
+
"""
|
260
|
+
ctensor_name = ep.graph_signature.inputs_to_lifted_tensor_constants[node.name]
|
261
|
+
|
262
|
+
if ctensor_name not in ep.constants:
|
263
|
+
raise ValueError(f"Constant tensor {ctensor_name} not found")
|
264
|
+
|
265
|
+
ctensor_data = ep.constants[ctensor_name]
|
266
|
+
|
267
|
+
if not isinstance(ctensor_data, torch.Tensor):
|
268
|
+
raise ValueError(f"Constant tensor {ctensor_name} is not a tensor")
|
269
|
+
|
270
|
+
tensor_value = ctensor_data.cpu().detach().numpy()
|
271
|
+
graph.add_tensor_from_node(node, tensor_value)
|
272
|
+
|
273
|
+
logger = logging.getLogger(__name__)
|
274
|
+
logger.debug(f"Exported constant tensor: {node.name}")
|
275
|
+
|
276
|
+
|
277
|
+
def _handle_user_input_node(
|
278
|
+
graph: CircleSubgraph,
|
279
|
+
node: torch.fx.Node,
|
280
|
+
ep: ExportedProgram,
|
281
|
+
) -> None:
|
282
|
+
"""Handle a user input placeholder node by exporting its tensor data.
|
283
|
+
|
284
|
+
Args:
|
285
|
+
graph: CircleSubgraph to add tensor to
|
286
|
+
node: The user input node to process
|
287
|
+
ep: ExportedProgram containing input specs
|
288
|
+
"""
|
289
|
+
user_inputs = [
|
290
|
+
specs
|
291
|
+
for specs in ep.graph_signature.input_specs
|
292
|
+
if specs.kind == InputKind.USER_INPUT
|
293
|
+
]
|
294
|
+
constant_inputs = [
|
295
|
+
specs for specs in user_inputs if isinstance(specs.arg, ConstantArgument)
|
296
|
+
]
|
297
|
+
name_to_value = {specs.arg.name: specs.arg.value for specs in constant_inputs}
|
298
|
+
|
299
|
+
# Skip NoneType ConstantArgument
|
300
|
+
if node.name in name_to_value and name_to_value[node.name] is None:
|
301
|
+
return
|
302
|
+
|
303
|
+
graph.add_tensor_from_node(node)
|
304
|
+
|
305
|
+
logger = logging.getLogger(__name__)
|
306
|
+
logger.debug(f"Exported user input tensor: {node.name}")
|
307
|
+
|
308
|
+
|
309
|
+
def _handle_get_attr_node(
|
310
|
+
graph: CircleSubgraph,
|
311
|
+
node: torch.fx.Node,
|
312
|
+
) -> None:
|
313
|
+
"""Handle a get_attr node by exporting its tensor data.
|
314
|
+
|
315
|
+
Args:
|
316
|
+
graph: CircleSubgraph to add tensor to
|
317
|
+
node: The get_attr node to process
|
318
|
+
"""
|
319
|
+
assert isinstance(node.target, str)
|
320
|
+
attr_tensor = getattr(node.graph.owning_module, node.target)
|
321
|
+
|
322
|
+
if not isinstance(attr_tensor, torch.Tensor):
|
323
|
+
raise ValueError(f"Attribute {node.target} is not a tensor")
|
324
|
+
|
325
|
+
graph.add_tensor_from_scratch(
|
326
|
+
prefix=node.name,
|
327
|
+
shape=list(attr_tensor.shape),
|
328
|
+
dtype=to_circle_dtype(attr_tensor.dtype),
|
329
|
+
source_node=node,
|
330
|
+
)
|
331
|
+
|
332
|
+
logger = logging.getLogger(__name__)
|
333
|
+
logger.debug(f"Exported attribute tensor: {node.name}")
|
@@ -22,7 +22,6 @@ from circle_schema import circle
|
|
22
22
|
from tico.serialize.circle_graph import CircleSubgraph
|
23
23
|
from tico.serialize.circle_mapping import (
|
24
24
|
circle_legalize_dtype_to,
|
25
|
-
extract_circle_dtype,
|
26
25
|
extract_shape,
|
27
26
|
extract_torch_dtype,
|
28
27
|
)
|
@@ -100,8 +99,6 @@ class AnyVisitor(NodeVisitor):
|
|
100
99
|
keepdim = args.keepdim
|
101
100
|
|
102
101
|
input_shape = list(extract_shape(input))
|
103
|
-
output_shape = list(extract_shape(node))
|
104
|
-
|
105
102
|
dim_i32 = None
|
106
103
|
if dim is None:
|
107
104
|
dims = tuple(i for i in range(0, len(input_shape)))
|
@@ -21,12 +21,9 @@ import torch
|
|
21
21
|
from circle_schema import circle
|
22
22
|
|
23
23
|
from tico.passes import ops
|
24
|
+
from tico.serialize.circle_graph import CircleSubgraph
|
24
25
|
|
25
|
-
from tico.serialize.
|
26
|
-
CircleSubgraph,
|
27
|
-
extract_circle_dtype,
|
28
|
-
extract_shape,
|
29
|
-
)
|
26
|
+
from tico.serialize.circle_mapping import extract_circle_dtype, extract_shape
|
30
27
|
from tico.serialize.operators.hashable_opcode import OpCode
|
31
28
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
32
29
|
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
@@ -21,10 +21,8 @@ import torch
|
|
21
21
|
from circle_schema import circle
|
22
22
|
|
23
23
|
from tico.serialize.circle_graph import CircleSubgraph
|
24
|
-
from tico.serialize.circle_mapping import to_circle_dtype
|
25
24
|
from tico.serialize.operators.hashable_opcode import OpCode
|
26
25
|
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
|
27
|
-
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
|
28
26
|
from tico.utils.validate_args_kwargs import FullLikeArgs
|
29
27
|
|
30
28
|
|
@@ -73,12 +73,6 @@ class InstanceNormVisitor(NodeVisitor):
|
|
73
73
|
eps = args.eps
|
74
74
|
|
75
75
|
# Ignore training-related args
|
76
|
-
running_mean = args.running_mean
|
77
|
-
running_var = args.running_var
|
78
|
-
use_input_stats = args.use_input_stats
|
79
|
-
momentum = args.momentum
|
80
|
-
cudnn_enabled = args.cudnn_enabled
|
81
|
-
|
82
76
|
input_shape = list(extract_shape(input))
|
83
77
|
assert len(input_shape) == 4, len(input_shape)
|
84
78
|
|
@@ -66,10 +66,7 @@ class MulTensorVisitor(BaseMulVisitor):
|
|
66
66
|
self,
|
67
67
|
node: torch.fx.Node,
|
68
68
|
) -> circle.Operator.OperatorT:
|
69
|
-
|
70
|
-
input = args.input
|
71
|
-
other = args.other
|
72
|
-
|
69
|
+
_ = MulTensorArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
73
70
|
operator = super().define_node(
|
74
71
|
node,
|
75
72
|
)
|
@@ -88,10 +85,7 @@ class MulScalarVisitor(BaseMulVisitor):
|
|
88
85
|
self,
|
89
86
|
node: torch.fx.Node,
|
90
87
|
) -> circle.Operator.OperatorT:
|
91
|
-
|
92
|
-
input = args.input
|
93
|
-
other = args.other
|
94
|
-
|
88
|
+
_ = MulScalarArgs(*node.args, **node.kwargs) # type: ignore[arg-type]
|
95
89
|
operator = super().define_node(
|
96
90
|
node,
|
97
91
|
)
|
tico/serialize/quant_param.py
CHANGED
@@ -12,6 +12,11 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from typing import List, Optional
|
17
|
+
|
18
|
+
import torch
|
19
|
+
|
15
20
|
"""
|
16
21
|
This is a key for torch.fx.Node's meta dict to save QuantParam
|
17
22
|
|
@@ -19,11 +24,6 @@ QuantParam can be retrieved as node.meta[QPARAM_KEY]
|
|
19
24
|
"""
|
20
25
|
QPARAM_KEY = "_quantization_parameters_"
|
21
26
|
|
22
|
-
from dataclasses import dataclass
|
23
|
-
from typing import List, Optional
|
24
|
-
|
25
|
-
import torch
|
26
|
-
|
27
27
|
|
28
28
|
@dataclass
|
29
29
|
class QuantParam:
|
tico/utils/convert.py
CHANGED
@@ -157,7 +157,7 @@ def check_unsupported_target(exported_program: ExportedProgram):
|
|
157
157
|
for n in exported_program.graph.nodes:
|
158
158
|
if n.op != "call_function":
|
159
159
|
continue
|
160
|
-
if
|
160
|
+
if n.target not in supported_target:
|
161
161
|
unsupported.append(n)
|
162
162
|
|
163
163
|
if unsupported:
|
tico/utils/graph.py
CHANGED
@@ -24,7 +24,7 @@ import torch
|
|
24
24
|
from torch.export import ExportedProgram
|
25
25
|
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
|
26
26
|
|
27
|
-
from tico.utils.utils import get_fake_mode
|
27
|
+
from tico.utils.utils import get_fake_mode
|
28
28
|
|
29
29
|
|
30
30
|
def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
|
tico/utils/padding.py
CHANGED
tico/utils/serialize.py
CHANGED
@@ -12,9 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Optional
|
16
|
-
|
17
|
-
import torch
|
18
15
|
|
19
16
|
from tico.serialize.circle_graph import CircleSubgraph
|
20
17
|
from tico.utils.graph import get_module_name_chain
|
tico/utils/utils.py
CHANGED
@@ -21,7 +21,6 @@ from typing import List
|
|
21
21
|
|
22
22
|
import torch
|
23
23
|
from circle_schema import circle
|
24
|
-
from packaging.version import Version
|
25
24
|
from torch._guards import detect_fake_mode
|
26
25
|
from torch.export import ExportedProgram
|
27
26
|
from torch.utils import _pytree as pytree
|
@@ -131,7 +130,7 @@ def enforce_type(callable):
|
|
131
130
|
|
132
131
|
return True
|
133
132
|
|
134
|
-
if typing.get_origin(type_hint)
|
133
|
+
if typing.get_origin(type_hint) is dict:
|
135
134
|
if not isinstance(value, typing.get_origin(type_hint)):
|
136
135
|
return False
|
137
136
|
|
@@ -16,10 +16,8 @@ from dataclasses import dataclass, field
|
|
16
16
|
from typing import List, Optional, TYPE_CHECKING, Union
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
19
|
-
import torch._ops
|
20
19
|
import torch.fx
|
21
20
|
import torch
|
22
|
-
import torch.fx.node
|
23
21
|
|
24
22
|
from tico.utils.utils import enforce_type
|
25
23
|
|
@@ -137,7 +135,7 @@ class AvgPool2dArgs:
|
|
137
135
|
assert len(self.padding) == 2, len(self.padding)
|
138
136
|
if self.divisor_override is not None:
|
139
137
|
assert isinstance(self.divisor_override, int), type(self.divisor_override)
|
140
|
-
assert self.divisor_override != 0,
|
138
|
+
assert self.divisor_override != 0, "Divisor must be not zero."
|
141
139
|
|
142
140
|
|
143
141
|
@enforce_type
|