emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/runtime/evaluator.py
@@ -10,6 +10,7 @@ from ..errors import ShapeInferenceError, UnsupportedOpError
 from ..ir.model import Graph, Node
 from ..lowering.attention import resolve_attention_spec
 from ..lowering.average_pool import lower_average_pool, lower_global_average_pool
+from ..lowering.adagrad import lower_adagrad
 from ..lowering.batch_normalization import lower_batch_normalization
 from ..lowering.concat import lower_concat
 from ..lowering.constant_of_shape import lower_constant_of_shape
@@ -28,6 +29,7 @@ from ..lowering.grid_sample import lower_grid_sample
 from ..lowering.instance_normalization import lower_instance_normalization
 from ..lowering.group_normalization import lower_group_normalization
 from ..lowering.layer_normalization import lower_layer_normalization
+from ..lowering.non_max_suppression import lower_non_max_suppression
 from ..lowering.mean_variance_normalization import (
     lower_mean_variance_normalization,
 )
@@ -49,6 +51,7 @@ from ..lowering.topk import lower_topk
 from ..lowering.lstm import ACTIVATION_KIND_BY_NAME, resolve_lstm_spec
 from ..lowering.lrn import resolve_lrn_spec
 from ..lowering.matmul import lower_matmul
+from ..lowering.qlinear_matmul import lower_qlinear_matmul
 from ..lowering.maxpool import resolve_maxpool_spec
 from ..lowering.reduce import (
     REDUCE_KIND_BY_OP,
@@ -58,11 +61,13 @@ from ..lowering.reduce import (
 )
 from ..lowering.reshape import lower_reshape
 from ..lowering.scatter_nd import lower_scatternd
+from ..lowering.tensor_scatter import lower_tensor_scatter
 from ..lowering.slice import _normalize_slices
 from ..lowering.shape import lower_shape
 from ..lowering.size import lower_size
 from ..lowering.softmax import lower_softmax
 from ..lowering.rms_normalization import lower_rms_normalization
+from ..lowering.rotary_embedding import lower_rotary_embedding
 from ..lowering.squeeze import lower_squeeze
 from ..lowering.transpose import lower_transpose
 from ..lowering.unsqueeze import lower_unsqueeze
@@ -158,6 +163,37 @@ def _eval_einsum(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[node.outputs[0]] = np.einsum(equation, *inputs)


+@register_evaluator("Adagrad")
+def _eval_adagrad(evaluator: Evaluator, node: Node) -> None:
+    op = lower_adagrad(evaluator.graph, node)
+    rate = evaluator.values[op.rate]
+    timestep = evaluator.values[op.timestep]
+    rate_value = (
+        np.array(rate, dtype=op.dtype.np_dtype).reshape(-1)[0].item()
+    )
+    timestep_value = (
+        np.array(timestep, dtype=np.int64).reshape(-1)[0].item()
+    )
+    r = op.dtype.np_dtype.type(
+        rate_value / (1.0 + float(timestep_value) * op.decay_factor)
+    )
+    for x_name, g_name, h_name, out_name, h_out_name in zip(
+        op.inputs,
+        op.gradients,
+        op.accumulators,
+        op.outputs,
+        op.accumulator_outputs,
+    ):
+        x = evaluator.values[x_name]
+        g = evaluator.values[g_name]
+        h = evaluator.values[h_name]
+        g_regularized = op.norm_coefficient * x + g
+        h_new = h + g_regularized * g_regularized
+        h_adaptive = np.sqrt(h_new) + op.epsilon
+        evaluator.values[out_name] = x - r * g_regularized / h_adaptive
+        evaluator.values[h_out_name] = h_new
+
+
 @register_evaluator("Clip")
 def _eval_clip(evaluator: Evaluator, node: Node) -> None:
     if not node.inputs or len(node.outputs) != 1:
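The new `_eval_adagrad` follows the ONNX Adagrad training op: a decayed learning rate, an optional L2 regularization term folded into the gradient, and a per-element squared-gradient accumulator. A minimal standalone NumPy sketch of that update rule (the hyperparameter values here are illustrative, not taken from the package):

```python
import numpy as np

def adagrad_step(x, g, h, rate, timestep,
                 decay_factor=0.0, epsilon=1e-6, norm_coefficient=0.0):
    # Learning rate decayed by the current timestep.
    r = rate / (1.0 + timestep * decay_factor)
    # L2-regularized gradient, squared-gradient accumulator, then the update.
    g_reg = norm_coefficient * x + g
    h_new = h + g_reg * g_reg
    x_new = x - r * g_reg / (np.sqrt(h_new) + epsilon)
    return x_new, h_new

x = np.array([1.0, 2.0], dtype=np.float32)
g = np.array([0.1, -0.2], dtype=np.float32)
h = np.zeros_like(x)
x, h = adagrad_step(x, g, h, rate=0.01, timestep=0)
```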
@@ -188,6 +224,79 @@ def _eval_clip(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[node.outputs[0]] = np.clip(x, min_val, max_val)


+def _max_min(lhs: float, rhs: float) -> tuple[float, float]:
+    if lhs >= rhs:
+        return rhs, lhs
+    return lhs, rhs
+
+
+def _suppress_by_iou(
+    boxes: np.ndarray,
+    box_index1: int,
+    box_index2: int,
+    *,
+    center_point_box: int,
+    iou_threshold: float,
+) -> bool:
+    box1 = boxes[box_index1]
+    box2 = boxes[box_index2]
+    if center_point_box == 0:
+        x1_min, x1_max = _max_min(float(box1[1]), float(box1[3]))
+        x2_min, x2_max = _max_min(float(box2[1]), float(box2[3]))
+        intersection_x_min = max(x1_min, x2_min)
+        intersection_x_max = min(x1_max, x2_max)
+        if intersection_x_max <= intersection_x_min:
+            return False
+
+        y1_min, y1_max = _max_min(float(box1[0]), float(box1[2]))
+        y2_min, y2_max = _max_min(float(box2[0]), float(box2[2]))
+        intersection_y_min = max(y1_min, y2_min)
+        intersection_y_max = min(y1_max, y2_max)
+        if intersection_y_max <= intersection_y_min:
+            return False
+    else:
+        box1_width_half = float(box1[2]) / 2.0
+        box1_height_half = float(box1[3]) / 2.0
+        box2_width_half = float(box2[2]) / 2.0
+        box2_height_half = float(box2[3]) / 2.0
+
+        x1_min = float(box1[0]) - box1_width_half
+        x1_max = float(box1[0]) + box1_width_half
+        x2_min = float(box2[0]) - box2_width_half
+        x2_max = float(box2[0]) + box2_width_half
+
+        y1_min = float(box1[1]) - box1_height_half
+        y1_max = float(box1[1]) + box1_height_half
+        y2_min = float(box2[1]) - box2_height_half
+        y2_max = float(box2[1]) + box2_height_half
+
+        intersection_x_min = max(x1_min, x2_min)
+        intersection_x_max = min(x1_max, x2_max)
+        if intersection_x_max <= intersection_x_min:
+            return False
+
+        intersection_y_min = max(y1_min, y2_min)
+        intersection_y_max = min(y1_max, y2_max)
+        if intersection_y_max <= intersection_y_min:
+            return False
+
+    intersection_area = (intersection_x_max - intersection_x_min) * (
+        intersection_y_max - intersection_y_min
+    )
+    if intersection_area <= 0:
+        return False
+
+    area1 = (x1_max - x1_min) * (y1_max - y1_min)
+    area2 = (x2_max - x2_min) * (y2_max - y2_min)
+    union_area = area1 + area2 - intersection_area
+
+    if area1 <= 0 or area2 <= 0 or union_area <= 0:
+        return False
+
+    intersection_over_union = intersection_area / union_area
+    return intersection_over_union > iou_threshold
+
+
 def _exclusive_cumsum(data: np.ndarray, axis: int) -> np.ndarray:
     result = np.zeros_like(data)
     if data.shape[axis] == 0:
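`_suppress_by_iou` decides suppression by pairwise intersection-over-union and supports both box encodings: corner format (`center_point_box=0`, `[y1, x1, y2, x2]`) and center format (`[x_center, y_center, width, height]`). A simplified corner-format sketch of the same computation (it omits the `_max_min` normalization of flipped corners that the real helper performs):

```python
def corner_iou(b1, b2):
    # b = [y1, x1, y2, x2], corners assumed already ordered.
    ix = max(0.0, min(b1[3], b2[3]) - max(b1[1], b2[1]))
    iy = max(0.0, min(b1[2], b2[2]) - max(b1[0], b2[0]))
    inter = ix * iy
    a1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    a2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    return inter / (a1 + a2 - inter) if inter > 0 else 0.0

# Two unit boxes offset by half a width: IoU = 0.5 / 1.5 = 1/3.
print(corner_iou([0, 0, 1, 1], [0, 0.5, 1, 1.5]))
```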
@@ -222,6 +331,100 @@ def _eval_cumsum(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[op.output] = result


+@register_evaluator("NonMaxSuppression")
+def _eval_nonmax_suppression(evaluator: Evaluator, node: Node) -> None:
+    op = lower_non_max_suppression(evaluator.graph, node)
+    boxes = evaluator.values[op.boxes]
+    scores = evaluator.values[op.scores]
+
+    max_output_boxes_per_class = 0
+    if op.max_output_boxes_per_class is not None:
+        max_output_values = evaluator.values[
+            op.max_output_boxes_per_class
+        ].astype(np.int64, copy=False)
+        max_output_values = max_output_values.reshape(-1)
+        if max_output_values.size != 1:
+            raise UnsupportedOpError(
+                "NonMaxSuppression max_output_boxes_per_class must be scalar"
+            )
+        max_output_boxes_per_class = max(int(max_output_values[0]), 0)
+
+    iou_threshold = 0.0
+    if op.iou_threshold is not None:
+        iou_values = evaluator.values[op.iou_threshold].reshape(-1)
+        if iou_values.size != 1:
+            raise UnsupportedOpError(
+                "NonMaxSuppression iou_threshold must be scalar"
+            )
+        iou_threshold = float(iou_values[0])
+
+    score_threshold = 0.0
+    score_threshold_enabled = op.score_threshold is not None
+    if op.score_threshold is not None:
+        score_values = evaluator.values[op.score_threshold].reshape(-1)
+        if score_values.size != 1:
+            raise UnsupportedOpError(
+                "NonMaxSuppression score_threshold must be scalar"
+            )
+        score_threshold = float(score_values[0])
+
+    if max_output_boxes_per_class == 0:
+        evaluator.values[op.output] = np.empty((0, 3), dtype=np.int64)
+        return
+
+    num_batches = boxes.shape[0]
+    num_boxes = boxes.shape[1]
+    num_classes = scores.shape[1]
+
+    selected_indices: list[tuple[int, int, int]] = []
+    for batch_index in range(num_batches):
+        batch_boxes = boxes[batch_index]
+        for class_index in range(num_classes):
+            class_scores = scores[batch_index, class_index]
+            candidates: list[tuple[float, int]] = []
+            if score_threshold_enabled:
+                for box_index in range(num_boxes):
+                    score = float(class_scores[box_index])
+                    if score > score_threshold:
+                        candidates.append((score, box_index))
+            else:
+                for box_index in range(num_boxes):
+                    candidates.append(
+                        (float(class_scores[box_index]), box_index)
+                    )
+            candidates.sort(key=lambda item: (item[0], -item[1]))
+            selected_boxes: list[int] = []
+            while (
+                candidates
+                and len(selected_boxes) < max_output_boxes_per_class
+            ):
+                _, box_index = candidates.pop()
+                if any(
+                    _suppress_by_iou(
+                        batch_boxes,
+                        box_index,
+                        selected_index,
+                        center_point_box=op.center_point_box,
+                        iou_threshold=iou_threshold,
+                    )
+                    for selected_index in selected_boxes
+                ):
+                    continue
+                selected_boxes.append(box_index)
+                selected_indices.append(
+                    (batch_index, class_index, box_index)
+                )
+
+    result = np.empty((len(selected_indices), 3), dtype=np.int64)
+    for idx, (batch_index, class_index, box_index) in enumerate(
+        selected_indices
+    ):
+        result[idx, 0] = batch_index
+        result[idx, 1] = class_index
+        result[idx, 2] = box_index
+    evaluator.values[op.output] = result
+
+
 @register_evaluator("Pad")
 def _eval_pad(evaluator: Evaluator, node: Node) -> None:
     op = lower_pad(evaluator.graph, node)
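One detail of the selection loop above is the candidate ordering: sorting ascending by `(score, -box_index)` and popping from the end visits boxes in descending score order, breaking score ties toward the lower box index (the tie-breaking used by the ONNX reference implementation). A tiny demonstration:

```python
candidates = [(0.9, 2), (0.9, 0), (0.5, 1)]
candidates.sort(key=lambda item: (item[0], -item[1]))
order = [candidates.pop()[1] for _ in range(len(candidates))]
print(order)  # [0, 2, 1]: box 0 beats box 2 on the 0.9 tie
```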
@@ -343,6 +546,35 @@ def _eval_scatternd(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[op.output] = output


+@register_evaluator("TensorScatter")
+def _eval_tensor_scatter(evaluator: Evaluator, node: Node) -> None:
+    op = lower_tensor_scatter(evaluator.graph, node)
+    past_cache = evaluator.values[op.past_cache]
+    update = evaluator.values[op.update]
+    if op.write_indices is None:
+        write_indices = np.zeros((past_cache.shape[0],), dtype=np.int64)
+    else:
+        write_indices = evaluator.values[op.write_indices].astype(
+            np.int64, copy=False
+        )
+    axis = op.axis
+    max_sequence_length = past_cache.shape[axis]
+    sequence_length = update.shape[axis]
+    output = np.array(past_cache, copy=True)
+    for prefix_idx in np.ndindex(past_cache.shape[:axis]):
+        batch_idx = prefix_idx[0]
+        base_index = int(write_indices[batch_idx])
+        for sequence_idx in range(sequence_length):
+            cache_idx = (*prefix_idx, base_index + sequence_idx)
+            if op.mode == "circular":
+                cache_idx = tuple(
+                    np.mod(np.asarray(cache_idx), max_sequence_length)
+                )
+            update_idx = (*prefix_idx, sequence_idx)
+            output[cache_idx] = update[update_idx]
+    evaluator.values[op.output] = output
+
+
 @register_evaluator("Celu")
 def _eval_celu(evaluator: Evaluator, node: Node) -> None:
     if len(node.inputs) != 1 or len(node.outputs) != 1:
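In `circular` mode the write position wraps modulo the cache length along the scatter axis, so a write that starts near the end of the cache continues at the front. A toy 1-D illustration of that wrap-around:

```python
import numpy as np

cache = np.zeros(5, dtype=np.int64)
update = np.array([1, 2, 3], dtype=np.int64)
base = 4
for i, value in enumerate(update):
    cache[(base + i) % len(cache)] = value  # writes positions 4, 0, 1
print(cache)  # [2 3 0 0 1]
```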
@@ -749,8 +981,22 @@ def _eval_isinf(evaluator: Evaluator, node: Node) -> None:
     output_dtype = value_dtype(evaluator.graph, node.outputs[0], node)
     if output_dtype != ScalarType.BOOL:
         raise UnsupportedOpError("IsInf output must be bool")
+    detect_negative = int(node.attrs.get("detect_negative", 1))
+    detect_positive = int(node.attrs.get("detect_positive", 1))
+    if detect_negative not in {0, 1} or detect_positive not in {0, 1}:
+        raise UnsupportedOpError(
+            "IsInf detect_negative and detect_positive must be 0 or 1"
+        )
     x = evaluator.values[node.inputs[0]]
-    evaluator.values[node.outputs[0]] = np.isinf(x)
+    if detect_negative and detect_positive:
+        result = np.isinf(x)
+    elif detect_negative:
+        result = np.isneginf(x)
+    elif detect_positive:
+        result = np.isposinf(x)
+    else:
+        result = np.zeros(x.shape, dtype=bool)
+    evaluator.values[node.outputs[0]] = result


 @register_evaluator("IsNaN")
@@ -1184,6 +1430,49 @@ def _eval_attention(evaluator: Evaluator, node: Node) -> None:
     evaluator.values[qk_matmul_output_name] = qk_output


+@register_evaluator("RotaryEmbedding")
+def _eval_rotary_embedding(evaluator: Evaluator, node: Node) -> None:
+    op = lower_rotary_embedding(evaluator.graph, node)
+    x = evaluator.values[op.input0]
+    cos_cache = evaluator.values[op.cos_cache]
+    sin_cache = evaluator.values[op.sin_cache]
+    position_ids = (
+        evaluator.values[op.position_ids] if op.position_ids else None
+    )
+    original_shape = x.shape
+    if op.input_rank == 4:
+        x = np.transpose(x, (0, 2, 1, 3))
+    else:
+        x = x.reshape(op.batch, op.seq_len, op.num_heads, op.head_size)
+    x_rotate = x[..., : op.rotary_dim]
+    x_not_rotate = x[..., op.rotary_dim :]
+    if position_ids is not None:
+        cos_cache = cos_cache[position_ids]
+        sin_cache = sin_cache[position_ids]
+    cos_cache = np.expand_dims(cos_cache, axis=2)
+    sin_cache = np.expand_dims(sin_cache, axis=2)
+    if op.interleaved:
+        x1 = x_rotate[..., 0::2]
+        x2 = x_rotate[..., 1::2]
+    else:
+        x1, x2 = np.split(x_rotate, 2, axis=-1)
+    real = (cos_cache * x1) - (sin_cache * x2)
+    imag = (sin_cache * x1) + (cos_cache * x2)
+    if op.interleaved:
+        real = np.expand_dims(real, axis=-1)
+        imag = np.expand_dims(imag, axis=-1)
+        x_rotate_concat = np.concatenate((real, imag), axis=-1)
+        x_rotate = np.reshape(x_rotate_concat, x_rotate.shape)
+    else:
+        x_rotate = np.concatenate((real, imag), axis=-1)
+    output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
+    if op.input_rank == 4:
+        output = np.transpose(output, (0, 2, 1, 3))
+    else:
+        output = output.reshape(original_shape)
+    evaluator.values[node.outputs[0]] = output
+
+
 def _apply_lstm_activation(
     kind: int, value: np.ndarray, alpha: float, beta: float
 ) -> np.ndarray:
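At its core the evaluator applies a 2-D rotation to each `(x1, x2)` pair by the cached angle, i.e. multiplication by `e^{i*theta}` in the complex plane. A one-pair sketch showing the equivalence:

```python
import numpy as np

theta = 0.3
x1, x2 = 1.0, 2.0
cos, sin = np.cos(theta), np.sin(theta)
real = cos * x1 - sin * x2  # matches the evaluator's `real`
imag = sin * x1 + cos * x2  # matches the evaluator's `imag`
rotated = (x1 + 1j * x2) * np.exp(1j * theta)
assert np.allclose([real, imag], [rotated.real, rotated.imag])
```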
@@ -1382,6 +1671,41 @@ def _eval_quantize_linear(evaluator: Evaluator, node: Node) -> None:
         spec.output_dtype.np_dtype, copy=False
     )

+
+@register_evaluator("QLinearMatMul")
+def _eval_qlinear_matmul(evaluator: Evaluator, node: Node) -> None:
+    op = lower_qlinear_matmul(evaluator.graph, node)
+    input0 = evaluator.values[op.input0]
+    input1 = evaluator.values[op.input1]
+    input0_scale = evaluator.values[op.input0_scale]
+    input1_scale = evaluator.values[op.input1_scale]
+    output_scale = evaluator.values[op.output_scale]
+    input0_zero_point = evaluator.values[op.input0_zero_point]
+    input1_zero_point = evaluator.values[op.input1_zero_point]
+    output_zero_point = evaluator.values[op.output_zero_point]
+
+    def _scalar_value(array: np.ndarray) -> float:
+        return float(np.asarray(array).reshape(-1)[0])
+
+    def _scalar_int(array: np.ndarray) -> int:
+        return int(np.asarray(array).reshape(-1)[0])
+
+    input0_zero = _scalar_int(input0_zero_point)
+    input1_zero = _scalar_int(input1_zero_point)
+    output_zero = _scalar_int(output_zero_point)
+    scale = _scalar_value(input0_scale) * _scalar_value(
+        input1_scale
+    ) / _scalar_value(output_scale)
+    acc = _apply_matmul(
+        input0.astype(np.int32) - input0_zero,
+        input1.astype(np.int32) - input1_zero,
+    )
+    scaled = acc.astype(np.float64) * scale + output_zero
+    rounded = np.rint(scaled)
+    info = np.iinfo(op.dtype.np_dtype)
+    clipped = np.clip(rounded, info.min, info.max)
+    evaluator.values[op.output] = clipped.astype(op.dtype.np_dtype)
+
 @register_evaluator("InstanceNormalization")
 def _eval_instance_normalization(evaluator: Evaluator, node: Node) -> None:
     op = lower_instance_normalization(evaluator.graph, node)
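The evaluator requantizes in the standard way: integer matmul on zero-point-shifted operands, rescale by `s_a * s_b / s_out`, round, add the output zero point, and clip to the output type's range. A minimal sketch with made-up uint8 quantization parameters:

```python
import numpy as np

a = np.array([[130, 132]], dtype=np.uint8)
b = np.array([[2], [4]], dtype=np.uint8)
za, zb, zo = 128, 0, 128       # zero points
sa, sb, so = 0.05, 0.1, 0.02   # scales

acc = (a.astype(np.int32) - za) @ (b.astype(np.int32) - zb)  # = [[20]]
scaled = acc.astype(np.float64) * (sa * sb / so) + zo        # 20 * 0.25 + 128
out = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
print(out)  # [[133]]
```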
emx_onnx_cgen/verification.py
@@ -2,24 +2,7 @@ from __future__ import annotations

 import numpy as np

-
-def _float_uint_dtype(values: np.ndarray) -> type[np.unsignedinteger]:
-    if values.dtype == np.float16:
-        return np.uint16
-    if values.dtype == np.float32:
-        return np.uint32
-    if values.dtype == np.float64:
-        return np.uint64
-    raise ValueError(f"Unsupported floating dtype for ULP calculation: {values.dtype}")
-
-
-def _float_to_ordered_int(values: np.ndarray) -> np.ndarray:
-    uint_dtype = _float_uint_dtype(values)
-    bits = np.dtype(uint_dtype).itemsize * 8
-    sign_mask = np.array(1 << (bits - 1), dtype=uint_dtype)
-    as_uint = values.view(uint_dtype)
-    ordered = np.where(as_uint & sign_mask, ~as_uint, as_uint | sign_mask)
-    return ordered.astype(np.uint64, copy=False)
+from shared.ulp import ulp_intdiff_float


 def max_ulp_diff(actual: np.ndarray, expected: np.ndarray) -> int:
@@ -34,27 +17,14 @@ def max_ulp_diff(actual: np.ndarray, expected: np.ndarray) -> int:
         raise ValueError(f"Unsupported floating dtype for ULP calculation: {dtype}")
     actual_cast = actual.astype(dtype, copy=False)
     expected_cast = expected.astype(dtype, copy=False)
-    nan_mask = np.isnan(actual_cast) | np.isnan(expected_cast)
-    if nan_mask.any():
-        both_nan = np.isnan(actual_cast) & np.isnan(expected_cast)
-        if not np.all(both_nan):
-            uint_dtype = _float_uint_dtype(expected_cast)
-            return int(np.iinfo(uint_dtype).max)
-        actual_cast = actual_cast[~nan_mask]
-        expected_cast = expected_cast[~nan_mask]
-    if actual_cast.size == 0:
-        return 0
-    eps = np.finfo(dtype).eps
-    near_zero = (np.abs(actual_cast) < eps) & (np.abs(expected_cast) < eps)
-    if np.any(near_zero):
-        actual_cast = actual_cast.copy()
-        expected_cast = expected_cast.copy()
-        actual_cast[near_zero] = 0
-        expected_cast[near_zero] = 0
-    ordered_actual = _float_to_ordered_int(actual_cast)
-    ordered_expected = _float_to_ordered_int(expected_cast)
-    deltas = ordered_actual.astype(np.int64) - ordered_expected.astype(np.int64)
-    return int(np.max(np.abs(deltas)))
+    max_diff = 0
+    for actual_value, expected_value in np.nditer(
+        [actual_cast, expected_cast], flags=["refs_ok"]
+    ):
+        diff = ulp_intdiff_float(actual_value[()], expected_value[()])
+        if diff > max_diff:
+            max_diff = diff
+    return max_diff


 def format_success_message(max_ulp: int) -> str:
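`max_ulp_diff` now delegates the per-element distance to `ulp_intdiff_float` from the new `shared/ulp.py`, whose body is not shown in this diff. The helpers removed above suggest the usual ordered-integer construction, sketched here for float32 (a reconstruction, not the shipped code):

```python
import numpy as np

def ordered_int(values: np.ndarray) -> np.ndarray:
    # Map float32 bit patterns to integers that are monotonic in float
    # order, so adjacent floats differ by exactly 1.
    as_uint = values.view(np.uint32)
    sign = np.uint32(1 << 31)
    return np.where(as_uint & sign, ~as_uint, as_uint | sign).astype(np.int64)

a = np.array([1.0], dtype=np.float32)
b = np.nextafter(a, np.float32(2.0))  # next representable float32 above 1.0
print(int(ordered_int(b)[0] - ordered_int(a)[0]))  # 1 ULP
```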
{emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: emx-onnx-cgen
-Version: 0.3.0
+Version: 0.3.2
 Summary: emmtrix ONNX-to-C Code Generator
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
@@ -71,7 +71,7 @@ Compile an ONNX model into a C source file:
 emx-onnx-cgen compile path/to/model.onnx build/model.c
 ```

-Verify an ONNX model end-to-end against ONNX Runtime:
+Verify an ONNX model end-to-end against ONNX Runtime (default):

 ```bash
 emx-onnx-cgen verify path/to/model.onnx
@@ -93,7 +93,7 @@ Options:
 - `--model-name`: Override the generated model name (default: output file stem).
 - `--emit-testbench`: Emit a JSON-producing `main()` testbench for validation.
 - `--emit-data-file`: Emit constant data arrays into a companion `_data` C file.
-- `--large-weight-threshold`: Store weights larger than this element count in a binary file (default: `1024`).
+- `--large-weight-threshold`: Store weights larger than this element count in a binary file (default: `1048576`; set to `0` to disable).
 - `--large-temp-threshold-bytes`: Mark temporary buffers larger than this threshold as static (default: `1024`).
 - `--no-restrict-arrays`: Disable `restrict` qualifiers on generated array parameters.

@@ -111,6 +111,7 @@ Options:
 - `--large-weight-threshold`: Store weights larger than this element count in a binary file (default: `1024`).
 - `--large-temp-threshold-bytes`: Mark temporary buffers larger than this threshold as static (default: `1024`).
 - `--max-ulp`: Maximum allowed ULP distance for floating outputs (default: `100`).
+- `--runtime`: Runtime backend for verification (`onnxruntime` or `onnx-reference`, default: `onnx-reference`).

 How verification works:

@@ -119,14 +120,14 @@ How verification works:
 2. **Build and execute**: the testbench is compiled with the selected C compiler
    (`--cc`, `CC`, or a detected `cc/gcc/clang`) and executed in a temporary
    directory.
-3. **Run ONNX Runtime**: the JSON inputs from the testbench are fed to ORT using
-   the same model.
+3. **Run runtime backend**: the JSON inputs from the testbench are fed to the
+   selected runtime (`onnxruntime` or `onnx-reference`) using the same model.
 4. **Compare outputs**: floating outputs are compared by maximum ULP distance
    (see https://www.emmtrix.com/wiki/ULP_Difference_of_Float_Numbers for the
    ULP definition and algorithm); non-floating outputs must match exactly.
    Missing outputs or mismatches are treated as failures.
-5. **ORT unsupported models**: if ORT reports `NOT_IMPLEMENTED`, verification is
-   skipped with a warning (exit code 0).
+5. **ORT unsupported models**: when using `onnxruntime`, if ORT reports
+   `NOT_IMPLEMENTED`, verification is skipped with a warning (exit code 0).

 ## Output

emx_onnx_cgen-0.3.2.dist-info/RECORD
@@ -0,0 +1,107 @@
+emx_onnx_cgen/__init__.py,sha256=jUSbu1kJ0krzVTYEcph3jCprBhD7tWNtiSdL6r29KrM,221
+emx_onnx_cgen/__main__.py,sha256=iC1lLVtR6-TmpL6OxXcy3oIntExUtajn9-q627R1XyI,140
+emx_onnx_cgen/_build_info.py,sha256=A7nFhoSa2YUEYMuQb71Vehv5hhDs7VSInIe852k4khc,112
+emx_onnx_cgen/_version.py,sha256=e8NqPtZ8fggRgk3GPrqZ_U_BDV8aSULw1u_Gn9NNbnk,704
+emx_onnx_cgen/cli.py,sha256=7Y9JW-t1PLg25zOizuqyMqwsXbbG9ok99DsYeFSiOFQ,21685
+emx_onnx_cgen/compiler.py,sha256=qXKUQedaQY6A2jX-twte4qVA263T3UtCDlPjvoM5vYU,16513
+emx_onnx_cgen/dtypes.py,sha256=jRx3BBvk0qFW14bngoL1B7L_IRasyNJ4jqhpM5YhcOM,1335
+emx_onnx_cgen/errors.py,sha256=HpOv95mTgr9ZX2gYe1RtwVMbPskh7zkqjU_FgAD-uIM,363
+emx_onnx_cgen/onnx_import.py,sha256=IF7KZGfEP9H4H1fHYjobGbB_381fqD_67KtqZYs9AZ4,9168
+emx_onnx_cgen/onnxruntime_utils.py,sha256=mEsC1x00M1jyBgVBKqnKoqx6H1tdgsFFUy7rbITs3bs,308
+emx_onnx_cgen/ops.py,sha256=qpPOaqsYprlJrhCNLVBZ3XnREBRDdmkXbd1zaAkywOI,16732
+emx_onnx_cgen/testbench.py,sha256=-NbqD1aC7OXvFMLiLzd2IPObenQdHFH85cNxNSB1GeY,640
+emx_onnx_cgen/validation.py,sha256=KFdUdGjQbzTj1szCJcjxnTi8f5l6ywNgCB9abbBpTbM,2360
+emx_onnx_cgen/verification.py,sha256=IrhIMm29R2vEkW1Q8gtoQtscMGxfJRavNRSMJHBAJ5g,1041
+emx_onnx_cgen/codegen/__init__.py,sha256=H_kBdc_w_W-3qdUZJHwKBDns1AeP_Un3-46LW20yLV0,406
+emx_onnx_cgen/codegen/c_emitter.py,sha256=dS-vjjuWT0GHETbV3ipoYedvuvcJB0yGwMZgoQuJe-g,452931
+emx_onnx_cgen/codegen/emitter.py,sha256=udcsqJNr46TFHiyVv5I4wdVH8ll6Bi4VqcR1VvofbnY,92
+emx_onnx_cgen/ir/__init__.py,sha256=fD2D8qxlGoCFJb0m9v6u3XTgzSxDOhB4cfLBiCLovzg,102
+emx_onnx_cgen/ir/context.py,sha256=cM3V6G3zs6VCsABP6TnZ8vvQ7VGwOF1iKtb1hq0WO3g,3356
+emx_onnx_cgen/ir/model.py,sha256=SZ3K8t4dKUqWuXWe5ozApofXx4bdcf4p0WYCdeU-mFA,1265
+emx_onnx_cgen/ir/op_base.py,sha256=mHvp0VD55JIrwQI2MFEmSILi22kuurBX085aamcjQ0g,6160
+emx_onnx_cgen/ir/op_context.py,sha256=9CZCUNJLsV4cJsYmJqWbaDrwQd4sr-9Ot1PmPSqGAto,2103
+emx_onnx_cgen/ir/ops/__init__.py,sha256=IcllGXB4T3TCrpBq9cy3jR_edS_IJ_qXac37K_rIZcA,2440
+emx_onnx_cgen/ir/ops/elementwise.py,sha256=sZ1S6X_fagNDevN6dXHBy75g_z-WP_dHFAVmPGnmeaU,3721
+emx_onnx_cgen/ir/ops/misc.py,sha256=1ekAgV5j6Stc1Yw8e-0EPD5t8mI1YJxmyIkAn9Zr4h8,10920
+emx_onnx_cgen/ir/ops/nn.py,sha256=-4ZqDkcu7zgci3YVfMzCDzokqpZHgOYZaq_C1GclBZQ,14365
+emx_onnx_cgen/ir/ops/reduce.py,sha256=-aA4bwOMppd9pnWQwhl6hOxryh0G2xRaHqeNwQ97AdY,2756
+emx_onnx_cgen/lowering/__init__.py,sha256=AxnUfmpf5Teos1ms3zE6r0EBxxPYznGSOICDEFWH_pk,1535
+emx_onnx_cgen/lowering/adagrad.py,sha256=DuW3MeNNJjhXz1k7XI9JDwfgWr-TyD5Q-B9eAZrNecM,4797
+emx_onnx_cgen/lowering/arg_reduce.py,sha256=7dvlOItEp_Mtxj-lohI_mNRqHFZZnGCsdfx8ON0i2F0,3377
+emx_onnx_cgen/lowering/attention.py,sha256=-Il_8AQMuwQtq-2-RkVyVfnvtRJuO61Cv1PlMIypxEc,16477
+emx_onnx_cgen/lowering/average_pool.py,sha256=kcaOBPNaVMITY7gprbJSIMRrwhgIbeI3OEVxzO1xRM0,8074
+emx_onnx_cgen/lowering/batch_normalization.py,sha256=_i-vwlhuAQYqxJIezHaxeqcmISV66Y_5o929_FTtMZg,3976
+emx_onnx_cgen/lowering/cast.py,sha256=J2Tf7MprIcZjsgVLGsaccpbyvftfXfm57o--Il-8GlQ,2841
+emx_onnx_cgen/lowering/common.py,sha256=lQVBapOlo3w0ats2R2kPwftuTYMc8aAsQuLWrcQQ_pM,16783
+emx_onnx_cgen/lowering/concat.py,sha256=aY1QjCBzmyxDmfybyzRRSEPiL3hR1JwtCCXvHA7vFDE,1086
+emx_onnx_cgen/lowering/constant_of_shape.py,sha256=N01UvbVroDk08FTbBMndrLYIzI0G6M0UQuCr4oxpP40,3197
+emx_onnx_cgen/lowering/conv.py,sha256=9VFdsChsJ_AL25mhe2H482Aa2-89-S1dSJpiu7ixgQg,7298
+emx_onnx_cgen/lowering/conv_transpose.py,sha256=10K7nhQ60p0PAB3qxmeazm2tbsSS1GDeINBk7VzsH1U,11153
+emx_onnx_cgen/lowering/cumsum.py,sha256=9E0C5NtvPt6g5T4QLdIOeDkXaZNzyDklus2-qu2B7eA,4114
+emx_onnx_cgen/lowering/depth_space.py,sha256=i7INioNkofBxFlZW9y0W_qA6mp67_FAXouhKCiB9RKc,4206
+emx_onnx_cgen/lowering/dropout.py,sha256=MZ4YrB-jvUFXpIKE5kOLyrEF5uy5dh0yjJH6Rj8KlMs,1764
+emx_onnx_cgen/lowering/einsum.py,sha256=MWAgWVOzP38RSOxJABwvYU6ykD9odmhrmddXinmFs7s,6117
+emx_onnx_cgen/lowering/elementwise.py,sha256=q9X3qTll7gLp39NTTdzuLs9RBsONssw50l1hWo8wby0,12229
+emx_onnx_cgen/lowering/expand.py,sha256=GmYJZWXXcBV42hMGUgbKKbLjeCxpbcMSoG9OU1ZkFFY,5518
+emx_onnx_cgen/lowering/eye_like.py,sha256=QBiHWYZbgK4uiUYWuS7WHCMBGMSG0paNZM84OYmGb7c,1723
+emx_onnx_cgen/lowering/flatten.py,sha256=6h-TQNy9iq5hfXR9h2clUrc2eHmZP9gAb9KbCSJdV20,2131
+emx_onnx_cgen/lowering/gather.py,sha256=PCER36AjmpxzAM4wuL7En3XR1RKZCdSzjxualDCUHAI,1803
+emx_onnx_cgen/lowering/gather_elements.py,sha256=cCp2UFOjktgEfS9s9npMS_BXklBkpMpD7UhIIMhQ-_Y,2318
+emx_onnx_cgen/lowering/gather_nd.py,sha256=rmr_ijeSeCrZ_R_QPwdoHPQUCe8nE0YRSv2NjUiiFjY,3090
+emx_onnx_cgen/lowering/gemm.py,sha256=qBaZ-6FZAAMEaZ4uifo58tJI8SoBsJvkZTCg7jvq288,4579
+emx_onnx_cgen/lowering/global_max_pool.py,sha256=RMjaspdwThsHFGq_CJ2lUo5MOZc4NtmG-W5zshhc85A,2212
+emx_onnx_cgen/lowering/grid_sample.py,sha256=FFbK-jrjqFLwSUu7BfSZC9So7MeCZprGKG5N4XQUxR4,5217
+emx_onnx_cgen/lowering/group_normalization.py,sha256=Ep7toUW9sHvMHb2EwNpgayygTW-TN62ooVLdaF0z9_c,2653
+emx_onnx_cgen/lowering/hardmax.py,sha256=PKY7w_4N6qzJq_l1O3le8J-uspPPK3Ujpl6Kdmt4tOU,1950
+emx_onnx_cgen/lowering/identity.py,sha256=zzmmSz1NTiRAPIZqU81qnNQFuuSJq6EvqbUOt1Hc3gA,1848
+emx_onnx_cgen/lowering/instance_normalization.py,sha256=XrDOAo8Af7yDObtAAJ006dVCN175cWPb5Wvh61PE7xs,1939
+emx_onnx_cgen/lowering/layer_normalization.py,sha256=RjRn1sPFupB8n3RsA8O9p5vDmfmj2Q6hjMVhSFzfLkU,4518
+emx_onnx_cgen/lowering/logsoftmax.py,sha256=giFEKQKN7xxlQqV64HNvO1QQobjM-IgavWJi7DT5pJk,1884
+emx_onnx_cgen/lowering/lp_normalization.py,sha256=il1fBWan8DwZ3dlRVSJWVhMpzHDYtwjh1YJaNm6palY,1701
+emx_onnx_cgen/lowering/lp_pool.py,sha256=aG-J6xwhprMJIXTNXwA781XfbBnUD0oh9_POwwEEAe4,4862
+emx_onnx_cgen/lowering/lrn.py,sha256=rJ_7ISllYphbHKmlMv3c5IwqPl-oZrEKWux7QCdjqIQ,3359
+emx_onnx_cgen/lowering/lstm.py,sha256=RVe0qGesoK-FfWeV0vCKCkoWD32Fv_C22LnQLFLr4Tc,12294
+emx_onnx_cgen/lowering/matmul.py,sha256=CpxloKLXX7u5SofOTYUTt8vU9IkD7h25VByQbLwkGiw,4248
+emx_onnx_cgen/lowering/maxpool.py,sha256=0XoazajqrB5So-hEnR73LOSsdF7ZnguVNAc9NSjK6Q4,7483
+emx_onnx_cgen/lowering/mean_variance_normalization.py,sha256=tFeDgrocZO5Q5hNBaFl4cTFpKTPNVmRH9-FZircEffA,1864
+emx_onnx_cgen/lowering/negative_log_likelihood_loss.py,sha256=J5VfAQN2bIrt8D4_6KIGxRBk4Q9ykJvlqJftCrqy-jc,9333
+emx_onnx_cgen/lowering/non_max_suppression.py,sha256=9EeHm2aF7QBmP-s23r43VDgRvGyFWcNcI1s_jYPqln0,5749
+emx_onnx_cgen/lowering/nonzero.py,sha256=qjDlI_0s37hK-arOD-Bm_Ft9N_gTVt0X3OEqxuP1sR0,1626
+emx_onnx_cgen/lowering/one_hot.py,sha256=JGJsA35Q5hyX7nutNVJMGgTgcFxlAlolH9k4igVc2s0,4341
+emx_onnx_cgen/lowering/pad.py,sha256=Z8361NQCwypKfTnS8-0rylX6P-S8xLU6QLbahVzxrzw,10405
+emx_onnx_cgen/lowering/qlinear_matmul.py,sha256=gsV8CAB9_PhPuCGBYEvqfhby3uHQ6-4lyfDI2Xgvw0c,7899
+emx_onnx_cgen/lowering/quantize_linear.py,sha256=yJOvZbGxI8HcZ_Zl9VO49qJVfZ5FwNoDq5TjTiGzKmg,4760
+emx_onnx_cgen/lowering/range.py,sha256=yaRvLHLlWNvvg-IO590jSVPv2dWrJjPWXyysSNOj0IY,3452
+emx_onnx_cgen/lowering/reduce.py,sha256=W_wa1ev2tD8gqSTTQX5K0brwvB_x1kqf9vo8R5HHN6k,18402
+emx_onnx_cgen/lowering/registry.py,sha256=tNmnP6ZhIrKv83Q6VdfkTLSsw6P8cqch-nqSWpURYX8,2002
+emx_onnx_cgen/lowering/reshape.py,sha256=L5h-u7DbrRzuPucDFwXw-oCX8bikD9R2RBkz9lTwEBM,13441
+emx_onnx_cgen/lowering/resize.py,sha256=XCTUppSDj9-GyztBORIuK1MJMxelA3DU_NZzfsVIlgQ,14633
+emx_onnx_cgen/lowering/rms_normalization.py,sha256=pWu5u0TqHZaL3rh07MtA6eOP0zLzNCoQ84f1V0un2Iw,2525
+emx_onnx_cgen/lowering/rotary_embedding.py,sha256=IfDxuUCJqFIK8SCviYXZfdJcrgg8tjT2ofYFUP2uv8c,6068
+emx_onnx_cgen/lowering/scatter_nd.py,sha256=WuNxsMQmCTXgqen5rygpAbZIsfca537lvvFPakn0rJU,3210
+emx_onnx_cgen/lowering/shape.py,sha256=r68BQSK2ldY6ct3iPfkpJsGySmaViOGqn3Mi3qoPTTM,2224
+emx_onnx_cgen/lowering/size.py,sha256=Mfj2x0zvDrhMAcmhXI5F63dzd3w3ZT2IxfI0jMbTSuQ,1250
+emx_onnx_cgen/lowering/slice.py,sha256=rMzmQ5nwaM8oJVmChOeVHPJn0qMXYpKZCPzO_eoEj_M,14805
+emx_onnx_cgen/lowering/softmax.py,sha256=mImrc0oeFpMywsx94PdDS1aQVj7DUgqoFhjsMjLLDdk,1863
+emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py,sha256=B6h23sGBZLdpKcbtoQUhVwfLrdSJwNcbCoPoDc3rTc0,5219
+emx_onnx_cgen/lowering/split.py,sha256=w4OPi4X-xoJgmTJuCTIfp0Dm7wd2NLZZ6AJM-jUROFg,5883
+emx_onnx_cgen/lowering/squeeze.py,sha256=p9bER1Jkc8_6BGjsD3b7zhuak11eywoQhVFIvJ9Vzj0,6084
+emx_onnx_cgen/lowering/tensor_scatter.py,sha256=1Wqb9XsNNj1CEKnH3Vx45xh3QQbxHF9L90ycVbcsy44,4485
+emx_onnx_cgen/lowering/tile.py,sha256=PJva752IM55f8evZzOz12Y7PUGsQ0kC_YR86SqowWqM,3030
+emx_onnx_cgen/lowering/topk.py,sha256=Dqx7qMr4HbXhVGN-wJf_D4dPTvYMVT6S82A2M3f9Dwo,4819
+emx_onnx_cgen/lowering/transpose.py,sha256=oNFRjkH63KqnO2Q4oJengEAUEYC1M3PW12AauWwebzI,1751
+emx_onnx_cgen/lowering/trilu.py,sha256=OjJjyo2ZRcfo9UGH8Zfq4o0PR6YDeoHSj8DzMu0w318,3266
+emx_onnx_cgen/lowering/unsqueeze.py,sha256=9y-OM-oY6ln1-R6duRRemeRrwBIpX2TZs_nRtlYQMYE,5985
+emx_onnx_cgen/lowering/variadic.py,sha256=etIWA7jVqWrWH3NkNvpF5opVGgvb0ZS4iLo4L3euWDs,3287
+emx_onnx_cgen/lowering/where.py,sha256=K2RUDvLg0uTvi6Z_uTOXM5jgc3PXRj0cTZ4u58GEGko,2644
+emx_onnx_cgen/runtime/__init__.py,sha256=88xGpAs1IEBlzlWL_e9tnKUlaSRdc7pQUeVCu5LC4DY,50
+emx_onnx_cgen/runtime/evaluator.py,sha256=yqsBpAIlBky-rby7J5z7i1SvDaK6PjObxH-wQSdZ2G0,114732
+shared/__init__.py,sha256=bmP79AVZdY_1aNULJap9pm76Q41Rabrza6X-0A8lDzw,45
+shared/scalar_functions.py,sha256=CErro1Du2Ri3uqX6Dgd18DzNbxduckAvsmLJ6oHGx9A,91123
+shared/scalar_types.py,sha256=kEpsl5T-NVFxCcTzXqPJbtpvDiCgKHfz91dphLLZxZA,4912
+shared/ulp.py,sha256=DpeovCFijmP8_M7zyTZWsNyfOtJ1AjNSdxf5jGsdfJo,1856
+emx_onnx_cgen-0.3.2.dist-info/METADATA,sha256=K_7vi0Tqx4-r94xZ2WliM4PXIfApgaI8C5a1_UgIjZE,6266
+emx_onnx_cgen-0.3.2.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
+emx_onnx_cgen-0.3.2.dist-info/entry_points.txt,sha256=b7Rvmz_Bi9kWyn7QayQC_FEXiRpt4cS1RnluKh49yoo,57
+emx_onnx_cgen-0.3.2.dist-info/top_level.txt,sha256=g39fo-blEbgiVcC_GRqAnBzN234w3LXbcVdLUoItSLk,21
+emx_onnx_cgen-0.3.2.dist-info/RECORD,,
{emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.9.0)
+Generator: setuptools (80.10.1)
 Root-Is-Purelib: true
 Tag: py3-none-any

shared/scalar_functions.py
@@ -2059,6 +2059,17 @@ def _bool_from_ops(name: str) -> _GeneratedScalar:


 _SCALAR_TYPES: Dict[ScalarType, _ScalarTypeInfo] = {
+    ScalarType.F16: _ScalarTypeInfo(
+        scalar_type=ScalarType.F16,
+        c_type="_Float16",
+        prefix="ref_scalar_f16_",
+        suffix="f16",
+        is_float=True,
+        is_bool=False,
+        is_signed=True,
+        is_small_int=False,
+        bits=None,
+    ),
     ScalarType.F32: _ScalarTypeInfo(
         scalar_type=ScalarType.F32,
         c_type="float",