emx-onnx-cgen 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of emx-onnx-cgen might be problematic (details are available on the registry page).
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +34 -0
- emx_onnx_cgen/cli.py +372 -64
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +3932 -1398
- emx_onnx_cgen/codegen/emitter.py +5 -0
- emx_onnx_cgen/compiler.py +169 -343
- emx_onnx_cgen/ir/context.py +87 -0
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +193 -0
- emx_onnx_cgen/ir/op_context.py +65 -0
- emx_onnx_cgen/ir/ops/__init__.py +130 -0
- emx_onnx_cgen/ir/ops/elementwise.py +146 -0
- emx_onnx_cgen/ir/ops/misc.py +421 -0
- emx_onnx_cgen/ir/ops/nn.py +580 -0
- emx_onnx_cgen/ir/ops/reduce.py +95 -0
- emx_onnx_cgen/lowering/__init__.py +79 -1
- emx_onnx_cgen/lowering/adagrad.py +114 -0
- emx_onnx_cgen/lowering/arg_reduce.py +1 -1
- emx_onnx_cgen/lowering/attention.py +1 -1
- emx_onnx_cgen/lowering/average_pool.py +1 -1
- emx_onnx_cgen/lowering/batch_normalization.py +1 -1
- emx_onnx_cgen/lowering/cast.py +1 -1
- emx_onnx_cgen/lowering/common.py +406 -11
- emx_onnx_cgen/lowering/concat.py +1 -1
- emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
- emx_onnx_cgen/lowering/conv.py +1 -1
- emx_onnx_cgen/lowering/conv_transpose.py +301 -0
- emx_onnx_cgen/lowering/cumsum.py +1 -1
- emx_onnx_cgen/lowering/depth_space.py +1 -1
- emx_onnx_cgen/lowering/dropout.py +1 -1
- emx_onnx_cgen/lowering/einsum.py +153 -0
- emx_onnx_cgen/lowering/elementwise.py +152 -4
- emx_onnx_cgen/lowering/expand.py +1 -1
- emx_onnx_cgen/lowering/eye_like.py +1 -1
- emx_onnx_cgen/lowering/flatten.py +1 -1
- emx_onnx_cgen/lowering/gather.py +1 -1
- emx_onnx_cgen/lowering/gather_elements.py +2 -4
- emx_onnx_cgen/lowering/gather_nd.py +79 -0
- emx_onnx_cgen/lowering/gemm.py +1 -1
- emx_onnx_cgen/lowering/global_max_pool.py +59 -0
- emx_onnx_cgen/lowering/grid_sample.py +1 -1
- emx_onnx_cgen/lowering/group_normalization.py +1 -1
- emx_onnx_cgen/lowering/hardmax.py +53 -0
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/instance_normalization.py +1 -1
- emx_onnx_cgen/lowering/layer_normalization.py +1 -1
- emx_onnx_cgen/lowering/logsoftmax.py +6 -2
- emx_onnx_cgen/lowering/lp_normalization.py +1 -1
- emx_onnx_cgen/lowering/lp_pool.py +141 -0
- emx_onnx_cgen/lowering/lrn.py +1 -1
- emx_onnx_cgen/lowering/lstm.py +1 -1
- emx_onnx_cgen/lowering/matmul.py +7 -8
- emx_onnx_cgen/lowering/maxpool.py +1 -1
- emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
- emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +13 -13
- emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
- emx_onnx_cgen/lowering/nonzero.py +42 -0
- emx_onnx_cgen/lowering/one_hot.py +120 -0
- emx_onnx_cgen/lowering/pad.py +1 -1
- emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
- emx_onnx_cgen/lowering/quantize_linear.py +126 -0
- emx_onnx_cgen/lowering/range.py +1 -1
- emx_onnx_cgen/lowering/reduce.py +6 -7
- emx_onnx_cgen/lowering/registry.py +24 -5
- emx_onnx_cgen/lowering/reshape.py +224 -52
- emx_onnx_cgen/lowering/resize.py +1 -1
- emx_onnx_cgen/lowering/rms_normalization.py +1 -1
- emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
- emx_onnx_cgen/lowering/scatter_nd.py +82 -0
- emx_onnx_cgen/lowering/shape.py +6 -25
- emx_onnx_cgen/lowering/size.py +1 -1
- emx_onnx_cgen/lowering/slice.py +1 -1
- emx_onnx_cgen/lowering/softmax.py +6 -2
- emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
- emx_onnx_cgen/lowering/split.py +1 -1
- emx_onnx_cgen/lowering/squeeze.py +6 -6
- emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
- emx_onnx_cgen/lowering/tile.py +1 -1
- emx_onnx_cgen/lowering/topk.py +134 -0
- emx_onnx_cgen/lowering/transpose.py +1 -1
- emx_onnx_cgen/lowering/trilu.py +89 -0
- emx_onnx_cgen/lowering/unsqueeze.py +6 -6
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +1 -1
- emx_onnx_cgen/onnx_import.py +4 -0
- emx_onnx_cgen/onnxruntime_utils.py +11 -0
- emx_onnx_cgen/ops.py +4 -0
- emx_onnx_cgen/runtime/evaluator.py +785 -43
- emx_onnx_cgen/testbench.py +23 -0
- emx_onnx_cgen/verification.py +31 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +33 -6
- emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
- shared/scalar_functions.py +60 -17
- shared/ulp.py +65 -0
- emx_onnx_cgen-0.2.0.dist-info/RECORD +0 -76
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.2.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/compiler.py CHANGED
@@ -10,141 +10,24 @@ import onnx
 
 from shared.scalar_types import ScalarType
 
+from .onnxruntime_utils import make_deterministic_session_options
 from .codegen.c_emitter import (
-    AttentionOp,
-    AveragePoolOp,
-    BatchNormOp,
-    LpNormalizationOp,
-    InstanceNormalizationOp,
-    GroupNormalizationOp,
-    LayerNormalizationOp,
-    MeanVarianceNormalizationOp,
-    RMSNormalizationOp,
-    BinaryOp,
-    MultiInputBinaryOp,
-    CastOp,
-    ClipOp,
     CEmitter,
     ConstTensor,
-    ConvOp,
-    ConcatOp,
-    ConstantOfShapeOp,
-    CumSumOp,
-    GemmOp,
-    GatherOp,
-    GatherElementsOp,
-    ExpandOp,
-    RangeOp,
-    LrnOp,
-    LstmOp,
-    LogSoftmaxOp,
-    NegativeLogLikelihoodLossOp,
-    NodeInfo,
-    PadOp,
-    SplitOp,
-    SoftmaxCrossEntropyLossOp,
     LoweredModel,
     ModelHeader,
-
-    MaxPoolOp,
-    ReduceOp,
-    ArgReduceOp,
-    ReshapeOp,
-    ResizeOp,
-    GridSampleOp,
-    SoftmaxOp,
-    ShapeOp,
-    SliceOp,
-    TransposeOp,
-    UnaryOp,
-    WhereOp,
+    NodeInfo,
 )
 from .dtypes import dtype_info
 from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
-from .ir.
-from .
-from .
-
-
-
-from .lowering.
-from .lowering.cast import lower_cast
-from .lowering.concat import lower_concat
-from .lowering.common import (
-    ensure_supported_dtype,
-    node_dtype,
-    shape_product,
-    value_dtype,
-    value_shape,
-)
-from .lowering.conv import ConvSpec, resolve_conv_spec
-from .lowering.constant_of_shape import lower_constant_of_shape
-from .lowering.dropout import lower_dropout
-from .lowering import cumsum as _cumsum # noqa: F401
-from .lowering.flatten import lower_flatten
-from .lowering.gather import lower_gather
-from .lowering.gather_elements import lower_gather_elements
-from .lowering.gemm import resolve_gemm_spec, validate_gemm_bias_shape
-from .lowering.lrn import LrnSpec, resolve_lrn_spec
-from .lowering.logsoftmax import lower_logsoftmax
-from .lowering import group_normalization as _group_normalization # noqa: F401
-from .lowering import instance_normalization as _instance_normalization # noqa: F401
-from .lowering import layer_normalization as _layer_normalization # noqa: F401
-from .lowering import lp_normalization as _lp_normalization # noqa: F401
-from .lowering import mean_variance_normalization as _mean_variance_normalization # noqa: F401
-from .lowering.negative_log_likelihood_loss import (
-    lower_negative_log_likelihood_loss,
-)
-from .lowering.expand import lower_expand
-from .lowering.range import lower_range
-from .lowering.split import lower_split
-from .lowering.softmax_cross_entropy_loss import (
-    lower_softmax_cross_entropy_loss,
-)
-from .lowering.matmul import lower_matmul
-from .lowering.maxpool import MaxPoolSpec, resolve_maxpool_spec
-from .lowering import pad as _pad # noqa: F401
-from .lowering.reduce import (
-    REDUCE_KIND_BY_OP,
-    REDUCE_OUTPUTS_FLOAT_ONLY,
-)
-from .lowering import arg_reduce as _arg_reduce # noqa: F401
-from .lowering.reshape import lower_reshape
-from .lowering.resize import lower_resize
-from .lowering.grid_sample import lower_grid_sample
-from .lowering.slice import lower_slice
-from .lowering.squeeze import lower_squeeze
-from .lowering import depth_space as _depth_space # noqa: F401
-from .lowering import eye_like as _eye_like # noqa: F401
-from .lowering import identity as _identity # noqa: F401
-from .lowering import tile as _tile # noqa: F401
-from .lowering.shape import lower_shape
-from .lowering.size import lower_size
-from .lowering.softmax import lower_softmax
-from .lowering.transpose import lower_transpose
-from .lowering.unsqueeze import lower_unsqueeze
-from .lowering.where import lower_where
-from .lowering.elementwise import (
-    lower_celu,
-    lower_clip,
-    lower_isinf,
-    lower_isnan,
-    lower_shrink,
-    lower_swish,
-)
-from .lowering import variadic as _variadic # noqa: F401
-from .lowering import rms_normalization as _rms_normalization # noqa: F401
-from .lowering.registry import get_lowering_registry, resolve_dispatch
+from .ir.context import GraphContext
+from .ir.model import Graph, TensorType, Value
+from .ir.op_base import OpBase
+from .ir.op_context import OpContext
+from .lowering import load_lowering_registry
+from .lowering.common import ensure_supported_dtype, shape_product, value_dtype
+from .lowering.registry import get_lowering_registry
 from .onnx_import import import_onnx
-from .ops import (
-    BINARY_OP_TYPES,
-    COMPARE_FUNCTIONS,
-    UNARY_OP_TYPES,
-    binary_op_symbol,
-    unary_op_symbol,
-    validate_unary_attrs,
-)
-from shared.scalar_functions import ScalarFunction, ScalarFunctionError
 from .runtime.evaluator import Evaluator
 
 
@@ -157,6 +40,16 @@ class CompilerOptions:
     model_checksum: str | None = None
     restrict_arrays: bool = True
     testbench_inputs: Mapping[str, np.ndarray] | None = None
+    truncate_weights_after: int | None = None
+    large_temp_threshold_bytes: int = 1024
+    large_weight_threshold: int = 1024 * 1024
+
+
+def _onnx_elem_type(dtype: np.dtype) -> int:
+    for elem_type, info in onnx._mapping.TENSOR_TYPE_MAP.items():
+        if info.np_dtype == dtype:
+            return elem_type
+    raise UnsupportedOpError(f"Unsupported dtype {dtype} for ONNX output")
 
 
 class Compiler:
@@ -165,11 +58,17 @@ class Compiler:
             options = CompilerOptions(template_dir=Path("templates"))
         self._options = options
         self._emitter = CEmitter(
-            options.template_dir,
+            options.template_dir,
+            restrict_arrays=options.restrict_arrays,
+            truncate_weights_after=options.truncate_weights_after,
+            large_temp_threshold_bytes=options.large_temp_threshold_bytes,
+            large_weight_threshold=options.large_weight_threshold,
         )
+        load_lowering_registry()
 
     def compile(self, model: onnx.ModelProto) -> str:
         graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
         testbench_inputs = self._resolve_testbench_inputs(graph)
         variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
             graph
@@ -185,6 +84,7 @@ class Compiler:
 
     def compile_with_data_file(self, model: onnx.ModelProto) -> tuple[str, str]:
         graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
         testbench_inputs = self._resolve_testbench_inputs(graph)
         variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
             graph
@@ -198,6 +98,46 @@ class Compiler:
             variable_dim_outputs=variable_dim_outputs,
         )
 
+    def compile_with_weight_data(
+        self, model: onnx.ModelProto
+    ) -> tuple[str, bytes | None]:
+        graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
+        testbench_inputs = self._resolve_testbench_inputs(graph)
+        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+            graph
+        )
+        lowered = self._lower_model(model, graph)
+        generated = self._emitter.emit_model(
+            lowered,
+            emit_testbench=self._options.emit_testbench,
+            testbench_inputs=testbench_inputs,
+            variable_dim_inputs=variable_dim_inputs,
+            variable_dim_outputs=variable_dim_outputs,
+        )
+        weight_data = self._emitter.collect_weight_data(lowered.constants)
+        return generated, weight_data
+
+    def compile_with_data_file_and_weight_data(
+        self, model: onnx.ModelProto
+    ) -> tuple[str, str, bytes | None]:
+        graph = import_onnx(model)
+        graph = self._concretize_graph_shapes(model, graph)
+        testbench_inputs = self._resolve_testbench_inputs(graph)
+        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
+            graph
+        )
+        lowered = self._lower_model(model, graph)
+        generated, data_source = self._emitter.emit_model_with_data_file(
+            lowered,
+            emit_testbench=self._options.emit_testbench,
+            testbench_inputs=testbench_inputs,
+            variable_dim_inputs=variable_dim_inputs,
+            variable_dim_outputs=variable_dim_outputs,
+        )
+        weight_data = self._emitter.collect_weight_data(lowered.constants)
+        return generated, data_source, weight_data
+
     @staticmethod
     def _collect_variable_dims(
         graph: Graph,
@@ -219,7 +159,8 @@ class Compiler:
         return collect(graph.inputs), collect(graph.outputs)
 
     def _lower_model(self, model: onnx.ModelProto, graph: Graph) -> LoweredModel:
-
+        ctx = GraphContext(graph)
+        constants = _lowered_constants(ctx)
         self._validate_graph(graph)
         (
             input_names,
@@ -229,7 +170,14 @@ class Compiler:
             output_shapes,
             output_dtypes,
         ) = self._collect_io_specs(graph)
-        ops, node_infos = self._lower_nodes(
+        ops, node_infos = self._lower_nodes(ctx)
+        op_ctx = OpContext(ctx)
+        for op in ops:
+            op.validate(op_ctx)
+        for op in ops:
+            op.infer_types(op_ctx)
+        for op in ops:
+            op.infer_shapes(op_ctx)
         header = self._build_header(model, graph)
         return LoweredModel(
             name=self._options.model_name,
@@ -243,6 +191,7 @@ class Compiler:
             ops=tuple(ops),
             node_infos=tuple(node_infos),
             header=header,
+            op_context=op_ctx,
         )
 
     def _resolve_testbench_inputs(
@@ -282,15 +231,93 @@ class Compiler:
             resolved[name] = tuple(array.ravel().tolist())
         return resolved
 
+    def _concretize_graph_shapes(
+        self, model: onnx.ModelProto, graph: Graph
+    ) -> Graph:
+        if not self._options.testbench_inputs:
+            return graph
+        if not any(value.type.dim_params for value in graph.values):
+            if not any(value.type.dim_params for value in graph.inputs):
+                if not any(value.type.dim_params for value in graph.outputs):
+                    return graph
+        try:
+            import onnxruntime as ort
+        except Exception:
+            return graph
+        try:
+            model_with_outputs = onnx.ModelProto()
+            model_with_outputs.CopyFrom(model)
+            existing_outputs = {
+                output.name for output in model_with_outputs.graph.output
+            }
+            value_info_by_name = {
+                value_info.name: value_info
+                for value_info in model_with_outputs.graph.value_info
+            }
+            for value in graph.values:
+                if value.name in existing_outputs:
+                    continue
+                value_info = value_info_by_name.get(value.name)
+                if value_info is None:
+                    dims: list[int | str | None] = []
+                    for index, dim in enumerate(value.type.shape):
+                        dim_param = None
+                        if index < len(value.type.dim_params):
+                            dim_param = value.type.dim_params[index]
+                        dims.append(dim_param if dim_param else None)
+                    elem_type = _onnx_elem_type(value.type.dtype.np_dtype)
+                    value_info = onnx.helper.make_tensor_value_info(
+                        value.name, elem_type, dims
+                    )
+                model_with_outputs.graph.output.append(value_info)
+                existing_outputs.add(value.name)
+            output_names = [output.name for output in model_with_outputs.graph.output]
+            sess_options = make_deterministic_session_options(ort)
+            sess = ort.InferenceSession(
+                model_with_outputs.SerializeToString(),
+                sess_options=sess_options,
+                providers=["CPUExecutionProvider"],
+            )
+            output_arrays = sess.run(None, self._options.testbench_inputs)
+        except Exception:
+            return graph
+
+        shapes_by_name: dict[str, tuple[int, ...]] = {
+            name: tuple(int(dim) for dim in array.shape)
+            for name, array in zip(output_names, output_arrays)
+        }
+        for name, array in self._options.testbench_inputs.items():
+            shapes_by_name[name] = tuple(int(dim) for dim in array.shape)
+
+        def concretize_value(value: Value) -> Value:
+            shape = shapes_by_name.get(value.name)
+            if shape is None:
+                return value
+            return Value(
+                name=value.name,
+                type=TensorType(
+                    dtype=value.type.dtype,
+                    shape=shape,
+                    dim_params=(None,) * len(shape),
+                ),
+            )
+
+        return Graph(
+            inputs=tuple(concretize_value(value) for value in graph.inputs),
+            outputs=tuple(concretize_value(value) for value in graph.outputs),
+            nodes=graph.nodes,
+            initializers=graph.initializers,
+            values=tuple(concretize_value(value) for value in graph.values),
+            opset_imports=graph.opset_imports,
+        )
+
     def _validate_graph(self, graph: Graph) -> None:
         if not graph.outputs:
             raise UnsupportedOpError("Graph must have at least one output")
         if not graph.nodes:
             raise UnsupportedOpError("Graph must contain at least one node")
         for value in graph.outputs:
-
-            if element_count <= 0:
-                raise ShapeInferenceError("Output shape must be fully defined")
+            shape_product(value.type.shape)
 
     def _collect_io_specs(
         self, graph: Graph
@@ -322,107 +349,16 @@ class Compiler:
         )
 
     def _lower_nodes(
-        self,
-    ) -> tuple[
-        list[
-            BinaryOp
-            | MultiInputBinaryOp
-            | UnaryOp
-            | ClipOp
-            | CastOp
-            | MatMulOp
-            | GemmOp
-            | AttentionOp
-            | ConvOp
-            | AveragePoolOp
-            | BatchNormOp
-            | LpNormalizationOp
-            | InstanceNormalizationOp
-            | GroupNormalizationOp
-            | LayerNormalizationOp
-            | MeanVarianceNormalizationOp
-            | RMSNormalizationOp
-            | LrnOp
-            | LstmOp
-            | SoftmaxOp
-            | LogSoftmaxOp
-            | NegativeLogLikelihoodLossOp
-            | SoftmaxCrossEntropyLossOp
-            | MaxPoolOp
-            | ConcatOp
-            | GatherElementsOp
-            | GatherOp
-            | TransposeOp
-            | ConstantOfShapeOp
-            | ReshapeOp
-            | SliceOp
-            | ResizeOp
-            | GridSampleOp
-            | ReduceOp
-            | ArgReduceOp
-            | ShapeOp
-            | PadOp
-            | ExpandOp
-            | CumSumOp
-            | RangeOp
-            | SplitOp
-        ],
-        list[NodeInfo],
-    ]:
-        ops: list[
-            BinaryOp
-            | MultiInputBinaryOp
-            | UnaryOp
-            | ClipOp
-            | CastOp
-            | MatMulOp
-            | GemmOp
-            | AttentionOp
-            | ConvOp
-            | AveragePoolOp
-            | BatchNormOp
-            | LpNormalizationOp
-            | InstanceNormalizationOp
-            | GroupNormalizationOp
-            | LayerNormalizationOp
-            | MeanVarianceNormalizationOp
-            | RMSNormalizationOp
-            | LrnOp
-            | LstmOp
-            | SoftmaxOp
-            | LogSoftmaxOp
-            | NegativeLogLikelihoodLossOp
-            | SoftmaxCrossEntropyLossOp
-            | MaxPoolOp
-            | ConcatOp
-            | GatherElementsOp
-            | GatherOp
-            | TransposeOp
-            | ConstantOfShapeOp
-            | ReshapeOp
-            | SliceOp
-            | ResizeOp
-            | ReduceOp
-            | ArgReduceOp
-            | ShapeOp
-            | PadOp
-            | ExpandOp
-            | CumSumOp
-            | RangeOp
-            | SplitOp
-            | WhereOp
-        ] = []
+        self, ctx: GraphContext
+    ) -> tuple[list[OpBase], list[NodeInfo]]:
+        ops: list[OpBase] = []
         node_infos: list[NodeInfo] = []
-
-
-
-
-
-
-                binary_fallback=lambda: _lower_binary_unary,
-                unary_fallback=lambda: _lower_binary_unary,
-            )
-            ops.append(lowering(graph, node))
+        registry = get_lowering_registry()
+        for node in ctx.nodes:
+            lowering = registry.get(node.op_type)
+            if lowering is None:
+                raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+            ops.append(lowering(ctx, node))
             node_infos.append(
                 NodeInfo(
                     op_type=node.op_type,
@@ -473,7 +409,7 @@ class Compiler:
         return evaluator.run(feeds)
 
 
-def _lowered_constants(graph: Graph) -> tuple[ConstTensor, ...]:
+def _lowered_constants(graph: Graph | GraphContext) -> tuple[ConstTensor, ...]:
     constants: list[ConstTensor] = []
     for initializer in graph.initializers:
         dtype = ensure_supported_dtype(initializer.type.dtype)
@@ -489,113 +425,3 @@ def _lowered_constants(graph: Graph) -> tuple[ConstTensor, ...]:
         )
     )
     return tuple(constants)
-
-
-def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
-    if node.op_type == "BitShift":
-        if len(node.inputs) != 2 or len(node.outputs) != 1:
-            raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
-        direction_attr = node.attrs.get("direction", "LEFT")
-        if isinstance(direction_attr, bytes):
-            direction = direction_attr.decode()
-        else:
-            direction = str(direction_attr)
-        if direction not in {"LEFT", "RIGHT"}:
-            raise UnsupportedOpError(
-                "BitShift direction must be LEFT or RIGHT"
-            )
-        op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
-        if not op_dtype.is_integer:
-            raise UnsupportedOpError("BitShift expects integer inputs")
-        function = (
-            ScalarFunction.BITWISE_LEFT_SHIFT
-            if direction == "LEFT"
-            else ScalarFunction.BITWISE_RIGHT_SHIFT
-        )
-        op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
-        if op_spec is None:
-            raise UnsupportedOpError("Unsupported op BitShift")
-        output_shape = value_shape(graph, node.outputs[0], node)
-        return BinaryOp(
-            input0=node.inputs[0],
-            input1=node.inputs[1],
-            output=node.outputs[0],
-            function=function,
-            operator_kind=op_spec.kind,
-            shape=output_shape,
-            dtype=op_dtype,
-            input_dtype=op_dtype,
-        )
-    if node.op_type == "Mod":
-        fmod = int(node.attrs.get("fmod", 0))
-        if fmod not in {0, 1}:
-            raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
-        function = (
-            ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
-        )
-    else:
-        try:
-            function = ScalarFunction.from_onnx_op(node.op_type)
-        except ScalarFunctionError as exc:
-            raise UnsupportedOpError(
-                f"Unsupported op {node.op_type}"
-            ) from exc
-        validate_unary_attrs(node.op_type, node.attrs)
-    if function in COMPARE_FUNCTIONS:
-        input_dtype = node_dtype(graph, node, *node.inputs)
-        output_dtype = value_dtype(graph, node.outputs[0], node)
-        op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
-        if op_spec is None:
-            raise UnsupportedOpError(f"Unsupported op {node.op_type}")
-        if len(node.inputs) != 2 or len(node.outputs) != 1:
-            raise UnsupportedOpError(
-                f"{node.op_type} must have 2 inputs and 1 output"
-            )
-        if output_dtype != ScalarType.BOOL:
-            raise UnsupportedOpError(
-                f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
-            )
-        output_shape = value_shape(graph, node.outputs[0], node)
-        return BinaryOp(
-            input0=node.inputs[0],
-            input1=node.inputs[1],
-            output=node.outputs[0],
-            function=function,
-            operator_kind=op_spec.kind,
-            shape=output_shape,
-            dtype=output_dtype,
-            input_dtype=input_dtype,
-        )
-    op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
-    op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
-    unary_symbol = unary_op_symbol(function, dtype=op_dtype)
-    if op_spec is None and unary_symbol is None:
-        raise UnsupportedOpError(f"Unsupported op {node.op_type}")
-    if op_spec is not None:
-        if len(node.inputs) != 2 or len(node.outputs) != 1:
-            raise UnsupportedOpError(
-                f"{node.op_type} must have 2 inputs and 1 output"
-            )
-        output_shape = value_shape(graph, node.outputs[0], node)
-        return BinaryOp(
-            input0=node.inputs[0],
-            input1=node.inputs[1],
-            output=node.outputs[0],
-            function=function,
-            operator_kind=op_spec.kind,
-            shape=output_shape,
-            dtype=op_dtype,
-            input_dtype=op_dtype,
-        )
-    if len(node.inputs) != 1 or len(node.outputs) != 1:
-        raise UnsupportedOpError(f"{node.op_type} must have 1 input and 1 output")
-    output_shape = value_shape(graph, node.outputs[0], node)
-    return UnaryOp(
-        input0=node.inputs[0],
-        output=node.outputs[0],
-        function=function,
-        shape=output_shape,
-        dtype=op_dtype,
-        input_dtype=op_dtype,
-        params=(),
-    )
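For orientation, the hunks above add two entry points that return the generated C source together with the raw weight blob (compile_with_weight_data and compile_with_data_file_and_weight_data), and extend CompilerOptions with truncate_weights_after, large_temp_threshold_bytes, and large_weight_threshold. A minimal usage sketch follows; it assumes the Compiler constructor accepts a CompilerOptions instance (only the option fields and method signatures are confirmed by this diff):

from pathlib import Path

import onnx

from emx_onnx_cgen.compiler import Compiler, CompilerOptions

# Hypothetical invocation; field names and defaults are taken from the hunks
# above, but the exact constructor signature is not visible in this diff.
options = CompilerOptions(
    template_dir=Path("templates"),
    truncate_weights_after=None,         # new in 0.3.1
    large_temp_threshold_bytes=1024,     # new in 0.3.1
    large_weight_threshold=1024 * 1024,  # new in 0.3.1
)
compiler = Compiler(options)

model = onnx.load("model.onnx")

# New in 0.3.1: C source plus the collected weight bytes (or None).
c_source, weight_blob = compiler.compile_with_weight_data(model)

# New in 0.3.1: C source, a separate data-file source, and the weight bytes.
c_source, data_source, weight_blob = compiler.compile_with_data_file_and_weight_data(model)

if weight_blob is not None:
    Path("model_weights.bin").write_bytes(weight_blob)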
emx_onnx_cgen/ir/context.py ADDED
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+
+from ..errors import ShapeInferenceError, UnsupportedOpError
+from .model import Graph, Initializer, Node, Value
+from shared.scalar_types import ScalarType
+
+
+@dataclass
+class GraphContext:
+    graph: Graph
+    _dtype_cache: dict[str, ScalarType] = field(default_factory=dict)
+    _shape_cache: dict[str, tuple[int, ...]] = field(default_factory=dict)
+    _initializer_cache: dict[str, Initializer] = field(default_factory=dict)
+    _producer_cache: dict[str, Node] = field(default_factory=dict)
+
+    def find_value(self, name: str) -> Value:
+        return self.graph.find_value(name)
+
+    def dtype(self, name: str, node: Node | None = None) -> ScalarType:
+        if name in self._dtype_cache:
+            return self._dtype_cache[name]
+        try:
+            value = self.graph.find_value(name)
+        except KeyError as exc:
+            op_type = node.op_type if node is not None else "unknown"
+            raise ShapeInferenceError(
+                f"Missing dtype for value '{name}' in op {op_type}. "
+                "Hint: run ONNX shape inference or export with static shapes."
+            ) from exc
+        dtype = value.type.dtype
+        if not isinstance(dtype, ScalarType):
+            raise UnsupportedOpError(f"Unsupported dtype {dtype}")
+        self._dtype_cache[name] = dtype
+        return dtype
+
+    def set_dtype(self, name: str, dtype: ScalarType) -> None:
+        self._dtype_cache[name] = dtype
+
+    def shape(self, name: str, node: Node | None = None) -> tuple[int, ...]:
+        if name in self._shape_cache:
+            return self._shape_cache[name]
+        try:
+            value = self.graph.find_value(name)
+        except KeyError as exc:
+            op_type = node.op_type if node is not None else "unknown"
+            raise ShapeInferenceError(
+                f"Missing shape for value '{name}' in op {op_type}. "
+                "Hint: run ONNX shape inference or export with static shapes."
+            ) from exc
+        self._shape_cache[name] = value.type.shape
+        return value.type.shape
+
+    def set_shape(self, name: str, shape: tuple[int, ...]) -> None:
+        self._shape_cache[name] = shape
+
+    def initializer(self, name: str) -> Initializer | None:
+        if name in self._initializer_cache:
+            return self._initializer_cache[name]
+        for initializer in self.graph.initializers:
+            if initializer.name == name:
+                self._initializer_cache[name] = initializer
+                return initializer
+        return None
+
+    def producer(self, output_name: str) -> Node | None:
+        if output_name in self._producer_cache:
+            return self._producer_cache[output_name]
+        for node in self.graph.nodes:
+            if output_name in node.outputs:
+                self._producer_cache[output_name] = node
+                return node
+        return None
+
+    def opset_version(self, domain: str = "") -> int | None:
+        if domain in {"", "ai.onnx"}:
+            domains = {"", "ai.onnx"}
+        else:
+            domains = {domain}
+        for opset_domain, version in self.graph.opset_imports:
+            if opset_domain in domains:
+                return int(version)
+        return None
+
+    def __getattr__(self, name: str):
+        return getattr(self.graph, name)