emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/compiler.py CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 from dataclasses import dataclass, fields
 import hashlib
 from pathlib import Path
-from typing import Mapping
+import time
+from typing import Callable, Mapping, TypeVar

 import numpy as np
 import onnx
@@ -29,21 +30,27 @@ from .lowering import load_lowering_registry
 from .lowering.common import ensure_supported_dtype, shape_product, value_dtype
 from .lowering.registry import get_lowering_registry
 from .onnx_import import import_onnx
-from .runtime.evaluator import Evaluator


 @dataclass(frozen=True)
 class CompilerOptions:
-    template_dir: Path
+    template_dir: Path | None = None
     model_name: str = "model"
     emit_testbench: bool = False
     command_line: str | None = None
     model_checksum: str | None = None
     restrict_arrays: bool = True
+    fp32_accumulation_strategy: str = "fp64"
+    fp16_accumulation_strategy: str = "fp32"
     testbench_inputs: Mapping[str, np.ndarray] | None = None
+    testbench_optional_inputs: Mapping[str, bool] | None = None
     truncate_weights_after: int | None = None
     large_temp_threshold_bytes: int = 1024
-    large_weight_threshold: int = 1024 * 1024
+    large_weight_threshold: int = 100 * 1024
+    timings: dict[str, float] | None = None
+
+
+_T = TypeVar("_T")


 def _onnx_elem_type(dtype: np.dtype) -> int:
@@ -53,90 +60,155 @@ def _onnx_elem_type(dtype: np.dtype) -> int:
     raise UnsupportedOpError(f"Unsupported dtype {dtype} for ONNX output")


+def _optional_flag_name(name: str) -> str:
+    return f"{name}_present"
+
+
 class Compiler:
     def __init__(self, options: CompilerOptions | None = None) -> None:
         if options is None:
-            options = CompilerOptions(template_dir=Path("templates"))
+            options = CompilerOptions()
         self._options = options
         self._emitter = CEmitter(
             options.template_dir,
             restrict_arrays=options.restrict_arrays,
+            fp32_accumulation_strategy=options.fp32_accumulation_strategy,
+            fp16_accumulation_strategy=options.fp16_accumulation_strategy,
             truncate_weights_after=options.truncate_weights_after,
             large_temp_threshold_bytes=options.large_temp_threshold_bytes,
             large_weight_threshold=options.large_weight_threshold,
         )
         load_lowering_registry()

+    def _time_step(self, label: str, func: Callable[[], _T]) -> _T:
+        timings = self._options.timings
+        if timings is None:
+            return func()
+        started = time.perf_counter()
+        result = func()
+        timings[label] = time.perf_counter() - started
+        return result
+
     def compile(self, model: onnx.ModelProto) -> str:
-        graph = import_onnx(model)
-        graph = self._concretize_graph_shapes(model, graph)
-        testbench_inputs = self._resolve_testbench_inputs(graph)
-        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
-            graph
+        graph = self._time_step("import_onnx", lambda: import_onnx(model))
+        graph = self._time_step(
+            "concretize_shapes",
+            lambda: self._concretize_graph_shapes(model, graph),
+        )
+        testbench_inputs = self._time_step(
+            "resolve_testbench_inputs", lambda: self._resolve_testbench_inputs(graph)
+        )
+        variable_dim_inputs, variable_dim_outputs = self._time_step(
+            "collect_variable_dims", lambda: self._collect_variable_dims(graph)
         )
-        lowered = self._lower_model(model, graph)
-        return self._emitter.emit_model(
-            lowered,
-            emit_testbench=self._options.emit_testbench,
-            testbench_inputs=testbench_inputs,
-            variable_dim_inputs=variable_dim_inputs,
-            variable_dim_outputs=variable_dim_outputs,
+        lowered = self._time_step(
+            "lower_model", lambda: self._lower_model(model, graph)
+        )
+        return self._time_step(
+            "emit_model",
+            lambda: self._emitter.emit_model(
+                lowered,
+                emit_testbench=self._options.emit_testbench,
+                testbench_inputs=testbench_inputs,
+                testbench_optional_inputs=self._options.testbench_optional_inputs,
+                variable_dim_inputs=variable_dim_inputs,
+                variable_dim_outputs=variable_dim_outputs,
+            ),
         )

     def compile_with_data_file(self, model: onnx.ModelProto) -> tuple[str, str]:
-        graph = import_onnx(model)
-        graph = self._concretize_graph_shapes(model, graph)
-        testbench_inputs = self._resolve_testbench_inputs(graph)
-        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
-            graph
+        graph = self._time_step("import_onnx", lambda: import_onnx(model))
+        graph = self._time_step(
+            "concretize_shapes",
+            lambda: self._concretize_graph_shapes(model, graph),
+        )
+        testbench_inputs = self._time_step(
+            "resolve_testbench_inputs", lambda: self._resolve_testbench_inputs(graph)
+        )
+        variable_dim_inputs, variable_dim_outputs = self._time_step(
+            "collect_variable_dims", lambda: self._collect_variable_dims(graph)
+        )
+        lowered = self._time_step(
+            "lower_model", lambda: self._lower_model(model, graph)
         )
-        lowered = self._lower_model(model, graph)
-        return self._emitter.emit_model_with_data_file(
-            lowered,
-            emit_testbench=self._options.emit_testbench,
-            testbench_inputs=testbench_inputs,
-            variable_dim_inputs=variable_dim_inputs,
-            variable_dim_outputs=variable_dim_outputs,
+        return self._time_step(
+            "emit_model_with_data_file",
+            lambda: self._emitter.emit_model_with_data_file(
+                lowered,
+                emit_testbench=self._options.emit_testbench,
+                testbench_inputs=testbench_inputs,
+                testbench_optional_inputs=self._options.testbench_optional_inputs,
+                variable_dim_inputs=variable_dim_inputs,
+                variable_dim_outputs=variable_dim_outputs,
+            ),
         )

     def compile_with_weight_data(
         self, model: onnx.ModelProto
     ) -> tuple[str, bytes | None]:
-        graph = import_onnx(model)
-        graph = self._concretize_graph_shapes(model, graph)
-        testbench_inputs = self._resolve_testbench_inputs(graph)
-        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
-            graph
-        )
-        lowered = self._lower_model(model, graph)
-        generated = self._emitter.emit_model(
-            lowered,
-            emit_testbench=self._options.emit_testbench,
-            testbench_inputs=testbench_inputs,
-            variable_dim_inputs=variable_dim_inputs,
-            variable_dim_outputs=variable_dim_outputs,
-        )
-        weight_data = self._emitter.collect_weight_data(lowered.constants)
+        graph = self._time_step("import_onnx", lambda: import_onnx(model))
+        graph = self._time_step(
+            "concretize_shapes",
+            lambda: self._concretize_graph_shapes(model, graph),
+        )
+        testbench_inputs = self._time_step(
+            "resolve_testbench_inputs", lambda: self._resolve_testbench_inputs(graph)
+        )
+        variable_dim_inputs, variable_dim_outputs = self._time_step(
+            "collect_variable_dims", lambda: self._collect_variable_dims(graph)
+        )
+        lowered = self._time_step(
+            "lower_model", lambda: self._lower_model(model, graph)
+        )
+        generated = self._time_step(
+            "emit_model",
+            lambda: self._emitter.emit_model(
+                lowered,
+                emit_testbench=self._options.emit_testbench,
+                testbench_inputs=testbench_inputs,
+                testbench_optional_inputs=self._options.testbench_optional_inputs,
+                variable_dim_inputs=variable_dim_inputs,
+                variable_dim_outputs=variable_dim_outputs,
+            ),
+        )
+        weight_data = self._time_step(
+            "collect_weight_data",
+            lambda: self._emitter.collect_weight_data(lowered.constants),
+        )
         return generated, weight_data

     def compile_with_data_file_and_weight_data(
         self, model: onnx.ModelProto
     ) -> tuple[str, str, bytes | None]:
-        graph = import_onnx(model)
-        graph = self._concretize_graph_shapes(model, graph)
-        testbench_inputs = self._resolve_testbench_inputs(graph)
-        variable_dim_inputs, variable_dim_outputs = self._collect_variable_dims(
-            graph
-        )
-        lowered = self._lower_model(model, graph)
-        generated, data_source = self._emitter.emit_model_with_data_file(
-            lowered,
-            emit_testbench=self._options.emit_testbench,
-            testbench_inputs=testbench_inputs,
-            variable_dim_inputs=variable_dim_inputs,
-            variable_dim_outputs=variable_dim_outputs,
-        )
-        weight_data = self._emitter.collect_weight_data(lowered.constants)
+        graph = self._time_step("import_onnx", lambda: import_onnx(model))
+        graph = self._time_step(
+            "concretize_shapes",
+            lambda: self._concretize_graph_shapes(model, graph),
+        )
+        testbench_inputs = self._time_step(
+            "resolve_testbench_inputs", lambda: self._resolve_testbench_inputs(graph)
+        )
+        variable_dim_inputs, variable_dim_outputs = self._time_step(
+            "collect_variable_dims", lambda: self._collect_variable_dims(graph)
+        )
+        lowered = self._time_step(
+            "lower_model", lambda: self._lower_model(model, graph)
+        )
+        generated, data_source = self._time_step(
+            "emit_model_with_data_file",
+            lambda: self._emitter.emit_model_with_data_file(
+                lowered,
+                emit_testbench=self._options.emit_testbench,
+                testbench_inputs=testbench_inputs,
+                testbench_optional_inputs=self._options.testbench_optional_inputs,
+                variable_dim_inputs=variable_dim_inputs,
+                variable_dim_outputs=variable_dim_outputs,
+            ),
+        )
+        weight_data = self._time_step(
+            "collect_weight_data",
+            lambda: self._emitter.collect_weight_data(lowered.constants),
+        )
         return generated, data_source, weight_data

     @staticmethod
@@ -165,9 +237,11 @@ class Compiler:
         self._validate_graph(graph)
         (
             input_names,
+            input_optional_names,
             input_shapes,
             input_dtypes,
             output_names,
+            output_optional_names,
             output_shapes,
             output_dtypes,
         ) = self._collect_io_specs(graph)
@@ -220,9 +294,11 @@ class Compiler:
         return LoweredModel(
             name=self._options.model_name,
             input_names=input_names,
+            input_optional_names=input_optional_names,
             input_shapes=input_shapes,
             input_dtypes=input_dtypes,
             output_names=output_names,
+            output_optional_names=output_optional_names,
             output_shapes=output_shapes,
             output_dtypes=output_dtypes,
             constants=constants,
@@ -248,7 +324,6 @@ class Compiler:
                 "Testbench inputs include unknown inputs: "
                 + ", ".join(unknown_inputs)
             )
-        resolved: dict[str, tuple[float | int | bool, ...]] = {}
         for name, values in self._options.testbench_inputs.items():
             if not isinstance(values, np.ndarray):
                 raise CodegenError(
@@ -265,9 +340,7 @@ class Compiler:
                     "Testbench input "
                     f"{name} has {array.size} elements, expected {expected_count}"
                 )
-            array = array.reshape(expected_shape)
-            resolved[name] = tuple(array.ravel().tolist())
-        return resolved
+        return None

     def _concretize_graph_shapes(
         self, model: onnx.ModelProto, graph: Graph
@@ -337,6 +410,7 @@ class Compiler:
                     dtype=value.type.dtype,
                     shape=shape,
                     dim_params=(None,) * len(shape),
+                    is_optional=value.type.is_optional,
                 ),
             )

@@ -361,27 +435,39 @@
         self, graph: Graph
     ) -> tuple[
         tuple[str, ...],
+        tuple[str | None, ...],
         tuple[tuple[int, ...], ...],
         tuple[ScalarType, ...],
         tuple[str, ...],
+        tuple[str | None, ...],
         tuple[tuple[int, ...], ...],
         tuple[ScalarType, ...],
     ]:
         input_names = tuple(value.name for value in graph.inputs)
+        input_optional_names = tuple(
+            _optional_flag_name(value.name) if value.type.is_optional else None
+            for value in graph.inputs
+        )
         input_shapes = tuple(value.type.shape for value in graph.inputs)
         input_dtypes = tuple(
             value_dtype(graph, value.name) for value in graph.inputs
         )
         output_names = tuple(value.name for value in graph.outputs)
+        output_optional_names = tuple(
+            _optional_flag_name(value.name) if value.type.is_optional else None
+            for value in graph.outputs
+        )
         output_shapes = tuple(value.type.shape for value in graph.outputs)
         output_dtypes = tuple(
             value_dtype(graph, value.name) for value in graph.outputs
         )
         return (
             input_names,
+            input_optional_names,
             input_shapes,
             input_dtypes,
             output_names,
+            output_optional_names,
             output_shapes,
             output_dtypes,
         )
@@ -439,26 +525,22 @@
             initializer_count=len(graph.initializers),
         )

-    def run(
-        self, model: onnx.ModelProto, feeds: Mapping[str, np.ndarray]
-    ) -> dict[str, np.ndarray]:
-        graph = import_onnx(model)
-        evaluator = Evaluator(graph)
-        return evaluator.run(feeds)
-
-
 def _lowered_constants(graph: Graph | GraphContext) -> tuple[ConstTensor, ...]:
+    used_initializers = {value.name for value in graph.outputs}
+    for node in graph.nodes:
+        used_initializers.update(node.inputs)
     constants: list[ConstTensor] = []
     for initializer in graph.initializers:
+        if initializer.name not in used_initializers:
+            continue
         dtype = ensure_supported_dtype(initializer.type.dtype)
+        data_array = initializer.data.astype(dtype.np_dtype, copy=False)
+        data_tuple = tuple(data_array.ravel().tolist())
         constants.append(
             ConstTensor(
                 name=initializer.name,
                 shape=initializer.type.shape,
-                data=tuple(
-                    dtype.np_dtype.type(value)
-                    for value in initializer.data.ravel()
-                ),
+                data=data_tuple,
                 dtype=dtype,
             )
         )
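
The new timings option on CompilerOptions, together with the _time_step helper above, lets a caller collect per-stage compile durations. A minimal usage sketch based on the names in this diff (the model path is a placeholder):

import onnx

from emx_onnx_cgen.compiler import Compiler, CompilerOptions

timings: dict[str, float] = {}
compiler = Compiler(CompilerOptions(model_name="model", timings=timings))
c_source = compiler.compile(onnx.load("model.onnx"))  # placeholder path

# timings now maps stage labels ("import_onnx", "lower_model", "emit_model", ...)
# to elapsed seconds recorded by _time_step.
for label, seconds in sorted(timings.items(), key=lambda item: -item[1]):
    print(f"{label}: {seconds:.3f}s")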

emx_onnx_cgen/determinism.py ADDED
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+import os
+from typing import Iterator
+
+THREAD_ENV_VARS = (
+    "OMP_NUM_THREADS",
+    "OPENBLAS_NUM_THREADS",
+    "MKL_NUM_THREADS",
+    "VECLIB_MAXIMUM_THREADS",
+    "NUMEXPR_NUM_THREADS",
+    "BLIS_NUM_THREADS",
+)
+
+
+@contextmanager
+def deterministic_reference_runtime() -> Iterator[None]:
+    previous = {name: os.environ.get(name) for name in THREAD_ENV_VARS}
+    for name in THREAD_ENV_VARS:
+        os.environ[name] = "1"
+    limits_context = None
+    try:
+        try:
+            from threadpoolctl import threadpool_limits
+        except Exception:
+            threadpool_limits = None
+        if threadpool_limits is not None:
+            limits_context = threadpool_limits(limits=1)
+            limits_context.__enter__()
+        yield
+    finally:
+        if limits_context is not None:
+            limits_context.__exit__(None, None, None)
+        for name, value in previous.items():
+            if value is None:
+                os.environ.pop(name, None)
+            else:
+                os.environ[name] = value
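
deterministic_reference_runtime() pins the common BLAS/OpenMP thread-count variables to 1 for the duration of the block (and, when threadpoolctl is installed, also caps already-initialized thread pools), then restores the previous environment. A usage sketch; run_reference_model is a hypothetical stand-in for whatever produces the reference outputs:

from emx_onnx_cgen.determinism import deterministic_reference_runtime

# Single-threaded execution keeps floating-point reductions in a fixed order,
# so the reference outputs are reproducible across machines.
with deterministic_reference_runtime():
    reference_outputs = run_reference_model(model, inputs)  # hypothetical helper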

emx_onnx_cgen/ir/context.py CHANGED
@@ -14,9 +14,28 @@ class GraphContext:
     _shape_cache: dict[str, tuple[int, ...]] = field(default_factory=dict)
     _initializer_cache: dict[str, Initializer] = field(default_factory=dict)
     _producer_cache: dict[str, Node] = field(default_factory=dict)
+    _value_cache: dict[str, Value] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        for value in self.graph.inputs + self.graph.outputs + self.graph.values:
+            self._value_cache[value.name] = value
+        for initializer in self.graph.initializers:
+            if initializer.name not in self._value_cache:
+                self._value_cache[initializer.name] = Value(
+                    name=initializer.name,
+                    type=initializer.type,
+                )
+            self._initializer_cache[initializer.name] = initializer
+        for node in self.graph.nodes:
+            for output in node.outputs:
+                if output and output not in self._producer_cache:
+                    self._producer_cache[output] = node

     def find_value(self, name: str) -> Value:
-        return self.graph.find_value(name)
+        value = self._value_cache.get(name)
+        if value is None:
+            raise KeyError(name)
+        return value

     def dtype(self, name: str, node: Node | None = None) -> ScalarType:
         if name in self._dtype_cache:
@@ -55,23 +74,14 @@
     def set_shape(self, name: str, shape: tuple[int, ...]) -> None:
         self._shape_cache[name] = shape

+    def has_shape(self, name: str) -> bool:
+        return name in self._shape_cache
+
     def initializer(self, name: str) -> Initializer | None:
-        if name in self._initializer_cache:
-            return self._initializer_cache[name]
-        for initializer in self.graph.initializers:
-            if initializer.name == name:
-                self._initializer_cache[name] = initializer
-                return initializer
-        return None
+        return self._initializer_cache.get(name)

     def producer(self, output_name: str) -> Node | None:
-        if output_name in self._producer_cache:
-            return self._producer_cache[output_name]
-        for node in self.graph.nodes:
-            if output_name in node.outputs:
-                self._producer_cache[output_name] = node
-                return node
-        return None
+        return self._producer_cache.get(output_name)

     def opset_version(self, domain: str = "") -> int | None:
         if domain in {"", "ai.onnx"}:
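
GraphContext now builds its value, initializer, and producer indexes once in __post_init__, so find_value(), initializer(), and producer() become dictionary lookups instead of linear scans over the graph. The same pattern in isolation, using simplified stand-in types rather than the package's own IR classes:

from dataclasses import dataclass, field


@dataclass
class ProducerIndex:
    nodes: list  # stand-in for Graph.nodes; each node exposes .outputs
    _producer_cache: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Index every output name once; the first producer of a name wins.
        for node in self.nodes:
            for output in node.outputs:
                if output and output not in self._producer_cache:
                    self._producer_cache[output] = node

    def producer(self, output_name: str):
        # O(1) lookup instead of re-scanning the node list on every call.
        return self._producer_cache.get(output_name)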

emx_onnx_cgen/ir/model.py CHANGED
@@ -13,6 +13,7 @@ class TensorType:
     dtype: ScalarType
     shape: tuple[int, ...]
     dim_params: tuple[str | None, ...]
+    is_optional: bool = False


 @dataclass(frozen=True)

emx_onnx_cgen/ir/op_base.py CHANGED
@@ -414,19 +414,20 @@ class VariadicLikeOpBase(RenderableOpBase):

     def infer_shapes(self, ctx: OpContext) -> None:
         input_shapes = tuple(ctx.shape(name) for name in self._variadic_inputs())
-        output_shape = BroadcastingOpBase.broadcast_shapes(*input_shapes)
-        for shape in input_shapes:
-            if shape != output_shape:
-                raise UnsupportedOpError(
-                    f"{self._variadic_kind()} expects identical input/output shapes"
-                )
+        try:
+            output_shape = BroadcastingOpBase.broadcast_shapes(*input_shapes)
+        except ShapeInferenceError as exc:
+            raise UnsupportedOpError(
+                f"{self._variadic_kind()} expects broadcastable input shapes"
+            ) from exc
         try:
             expected = ctx.shape(self._variadic_output())
         except ShapeInferenceError:
             expected = None
         if expected is not None and expected != output_shape:
             raise UnsupportedOpError(
-                f"{self._variadic_kind()} expects identical input/output shapes"
+                f"{self._variadic_kind()} output shape must be {output_shape}, "
+                f"got {expected}"
             )
         ctx.set_shape(self._variadic_output(), output_shape)

@@ -469,6 +470,30 @@ class ReduceOpBase(RenderableOpBase):


 class BroadcastingOpBase(RenderableOpBase):
+    @staticmethod
+    def unidirectional_broadcastable(
+        source: tuple[int, ...],
+        target: tuple[int, ...],
+    ) -> bool:
+        if len(source) > len(target):
+            return False
+        padded = (1,) * (len(target) - len(source)) + source
+        for source_dim, target_dim in zip(padded, target):
+            if source_dim not in {1, target_dim}:
+                return False
+        return True
+
+    @staticmethod
+    def prelu_channel_axis(
+        input_shape: tuple[int, ...],
+        slope_shape: tuple[int, ...],
+    ) -> int | None:
+        if len(input_shape) < 2 or len(slope_shape) != 1:
+            return None
+        if slope_shape[0] != input_shape[1]:
+            return None
+        return 1
+
     @staticmethod
     def broadcast_shapes(
         *shapes: tuple[int, ...],
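
unidirectional_broadcastable implements ONNX-style one-way broadcasting: the source shape is left-padded with 1s and every dimension must equal the target dimension or be 1, while prelu_channel_axis only accepts a rank-1 slope whose length matches the channel dimension. A quick check of that logic, assuming the import path follows the file layout listed above:

from emx_onnx_cgen.ir.op_base import BroadcastingOpBase

# (3, 1) pads to (1, 3, 1), which broadcasts one-way into (2, 3, 4).
assert BroadcastingOpBase.unidirectional_broadcastable((3, 1), (2, 3, 4))

# The trailing 2 is neither 1 nor 4, so this pair is rejected.
assert not BroadcastingOpBase.unidirectional_broadcastable((3, 2), (2, 3, 4))

# A rank-1 slope whose length matches input_shape[1] maps to channel axis 1.
assert BroadcastingOpBase.prelu_channel_axis((1, 8, 16, 16), (8,)) == 1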

emx_onnx_cgen/ir/ops/__init__.py CHANGED
@@ -3,28 +3,35 @@ from .elementwise import (
     ClipOp,
     IdentityOp,
     MultiInputBinaryOp,
+    PowOp,
+    QLinearMulOp,
     UnaryOp,
     VariadicOp,
     WhereOp,
 )
 from .misc import (
+    BernoulliOp,
     CastOp,
     ConcatOp,
     ConstantOfShapeOp,
     CumSumOp,
     DepthToSpaceOp,
+    DequantizeLinearOp,
     ExpandOp,
     EyeLikeOp,
     GatherElementsOp,
     GatherNDOp,
     GatherOp,
     GridSampleOp,
+    HammingWindowOp,
     NonMaxSuppressionOp,
     NonZeroOp,
     OneHotOp,
+    OptionalHasElementOp,
     PadOp,
     QuantizeLinearOp,
     RangeOp,
+    HammingWindowOp,
     ReshapeOp,
     ResizeOp,
     ScatterNDOp,
@@ -34,6 +41,7 @@ from .misc import (
     SpaceToDepthOp,
     SplitOp,
     TensorScatterOp,
+    TfIdfVectorizerOp,
     TileOp,
     TransposeOp,
     TriluOp,
@@ -44,10 +52,12 @@ from .nn import (
     AveragePoolOp,
     BatchNormOp,
     ConvOp,
+    ConvIntegerOp,
     ConvTransposeOp,
     EinsumKind,
     EinsumOp,
     GemmOp,
+    GruOp,
     GroupNormalizationOp,
     HardmaxOp,
     InstanceNormalizationOp,
@@ -75,15 +85,18 @@ __all__ = [
     "AttentionOp",
     "AveragePoolOp",
     "BatchNormOp",
+    "BernoulliOp",
     "BinaryOp",
     "CastOp",
     "ClipOp",
     "ConcatOp",
     "ConstantOfShapeOp",
     "ConvOp",
+    "ConvIntegerOp",
     "ConvTransposeOp",
     "CumSumOp",
     "DepthToSpaceOp",
+    "DequantizeLinearOp",
     "EinsumKind",
     "EinsumOp",
     "ExpandOp",
@@ -93,6 +106,8 @@ __all__ = [
     "GatherOp",
     "GemmOp",
     "GridSampleOp",
+    "GruOp",
+    "HammingWindowOp",
     "GroupNormalizationOp",
     "HardmaxOp",
     "IdentityOp",
@@ -111,10 +126,14 @@ __all__ = [
     "NonMaxSuppressionOp",
     "NonZeroOp",
     "OneHotOp",
+    "OptionalHasElementOp",
     "PadOp",
+    "PowOp",
     "QuantizeLinearOp",
+    "QLinearMulOp",
     "QLinearMatMulOp",
     "RangeOp",
+    "HammingWindowOp",
     "ReduceOp",
     "ReshapeOp",
     "ResizeOp",
@@ -129,6 +148,7 @@ __all__ = [
     "SpaceToDepthOp",
     "SplitOp",
     "TensorScatterOp",
+    "TfIdfVectorizerOp",
     "TileOp",
     "TopKOp",
     "TransposeOp",