emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.2.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.2.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/codegen/emitter.py ADDED
@@ -0,0 +1,5 @@
+ from __future__ import annotations
+
+ from .c_emitter import CEmitter
+
+ __all__ = ["CEmitter"]
emx_onnx_cgen/compiler.py CHANGED
@@ -12,161 +12,22 @@ from shared.scalar_types import ScalarType
 
  from .onnxruntime_utils import make_deterministic_session_options
  from .codegen.c_emitter import (
- AttentionOp,
- AveragePoolOp,
- BatchNormOp,
- LpNormalizationOp,
- InstanceNormalizationOp,
- GroupNormalizationOp,
- LayerNormalizationOp,
- MeanVarianceNormalizationOp,
- RMSNormalizationOp,
- BinaryOp,
- MultiInputBinaryOp,
- CastOp,
- ClipOp,
  CEmitter,
  ConstTensor,
- ConvOp,
- ConvTransposeOp,
- ConcatOp,
- ConstantOfShapeOp,
- CumSumOp,
- GemmOp,
- GatherOp,
- GatherElementsOp,
- GatherNDOp,
- ScatterNDOp,
- ExpandOp,
- RangeOp,
- OneHotOp,
- LpPoolOp,
- QuantizeLinearOp,
- LrnOp,
- LstmOp,
- LogSoftmaxOp,
- HardmaxOp,
- NegativeLogLikelihoodLossOp,
- NonZeroOp,
- NodeInfo,
- PadOp,
- SplitOp,
- SoftmaxCrossEntropyLossOp,
  LoweredModel,
  ModelHeader,
- MatMulOp,
- MaxPoolOp,
- ReduceOp,
- ArgReduceOp,
- ReshapeOp,
- ResizeOp,
- GridSampleOp,
- HardmaxOp,
- SoftmaxOp,
- ShapeOp,
- SliceOp,
- TransposeOp,
- UnaryOp,
- WhereOp,
+ NodeInfo,
  )
  from .dtypes import dtype_info
  from .errors import CodegenError, ShapeInferenceError, UnsupportedOpError
+ from .ir.context import GraphContext
  from .ir.model import Graph, TensorType, Value
- from .lowering.attention import AttentionSpec, resolve_attention_spec
- from .lowering.average_pool import (
- lower_average_pool,
- lower_global_average_pool,
- )
- from .lowering import global_max_pool as _global_max_pool # noqa: F401
- from .lowering.batch_normalization import lower_batch_normalization
- from .lowering.cast import lower_cast
- from .lowering.concat import lower_concat
- from .lowering.common import (
- ensure_supported_dtype,
- node_dtype,
- shape_product,
- value_dtype,
- value_shape,
- )
- from .lowering.conv import ConvSpec, resolve_conv_spec
- from .lowering import conv_transpose as _conv_transpose # noqa: F401
- from .lowering.constant_of_shape import lower_constant_of_shape
- from .lowering.dropout import lower_dropout
- from .lowering import cumsum as _cumsum # noqa: F401
- from .lowering import einsum as _einsum # noqa: F401
- from .lowering.flatten import lower_flatten
- from .lowering.gather import lower_gather
- from .lowering.gather_elements import lower_gather_elements
- from .lowering.gather_nd import lower_gather_nd
- from .lowering import scatter_nd as _scatter_nd # noqa: F401
- from .lowering.gemm import resolve_gemm_spec, validate_gemm_bias_shape
- from .lowering.lrn import LrnSpec, resolve_lrn_spec
- from .lowering.logsoftmax import lower_logsoftmax
- from .lowering import hardmax as _hardmax # noqa: F401
- from .lowering import group_normalization as _group_normalization # noqa: F401
- from .lowering import instance_normalization as _instance_normalization # noqa: F401
- from .lowering import layer_normalization as _layer_normalization # noqa: F401
- from .lowering import lp_normalization as _lp_normalization # noqa: F401
- from .lowering import lp_pool as _lp_pool # noqa: F401
- from .lowering import mean_variance_normalization as _mean_variance_normalization # noqa: F401
- from .lowering.negative_log_likelihood_loss import (
- lower_negative_log_likelihood_loss,
- )
- from .lowering import nonzero as _nonzero # noqa: F401
- from .lowering.expand import lower_expand
- from .lowering.range import lower_range
- from .lowering import one_hot as _one_hot # noqa: F401
- from .lowering.split import lower_split
- from .lowering.softmax_cross_entropy_loss import (
- lower_softmax_cross_entropy_loss,
- )
- from .lowering.matmul import lower_matmul
- from .lowering.maxpool import MaxPoolSpec, resolve_maxpool_spec
- from .lowering import pad as _pad # noqa: F401
- from .lowering.reduce import (
- REDUCE_KIND_BY_OP,
- REDUCE_OUTPUTS_FLOAT_ONLY,
- )
- from .lowering import arg_reduce as _arg_reduce # noqa: F401
- from .lowering import topk as _topk # noqa: F401
- from .lowering.reshape import lower_reshape
- from .lowering.resize import lower_resize
- from .lowering.grid_sample import lower_grid_sample
- from .lowering import quantize_linear as _quantize_linear # noqa: F401
- from .lowering.slice import lower_slice
- from .lowering.squeeze import lower_squeeze
- from .lowering import depth_space as _depth_space # noqa: F401
- from .lowering import eye_like as _eye_like # noqa: F401
- from .lowering import identity as _identity # noqa: F401
- from .lowering import tile as _tile # noqa: F401
- from .lowering import trilu as _trilu # noqa: F401
- from .lowering.shape import lower_shape
- from .lowering.size import lower_size
- from .lowering.softmax import lower_softmax
- from .lowering.transpose import lower_transpose
- from .lowering.unsqueeze import lower_unsqueeze
- from .lowering.where import lower_where
- from .lowering.elementwise import (
- lower_celu,
- lower_clip,
- lower_isinf,
- lower_isnan,
- lower_shrink,
- lower_swish,
- )
- from .lowering import variadic as _variadic # noqa: F401
- from .lowering import rms_normalization as _rms_normalization # noqa: F401
- from .lowering.registry import get_lowering_registry, resolve_dispatch
+ from .ir.op_base import OpBase
+ from .ir.op_context import OpContext
+ from .lowering import load_lowering_registry
+ from .lowering.common import ensure_supported_dtype, shape_product, value_dtype
+ from .lowering.registry import get_lowering_registry
  from .onnx_import import import_onnx
- from .ops import (
- BINARY_OP_TYPES,
- COMPARE_FUNCTIONS,
- UNARY_OP_TYPES,
- binary_op_symbol,
- unary_op_symbol,
- validate_unary_attrs,
- )
- from shared.scalar_functions import ScalarFunction, ScalarFunctionError
  from .runtime.evaluator import Evaluator
 
 
@@ -181,7 +42,7 @@ class CompilerOptions:
  testbench_inputs: Mapping[str, np.ndarray] | None = None
  truncate_weights_after: int | None = None
  large_temp_threshold_bytes: int = 1024
- large_weight_threshold: int = 1024
+ large_weight_threshold: int = 1024 * 1024
 
 
  def _onnx_elem_type(dtype: np.dtype) -> int:
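The only change in this hunk is the default for large_weight_threshold, which grows from 1024 to 1024 * 1024 (1 KiB to 1 MiB, assuming the unit is bytes as for large_temp_threshold_bytes). A minimal, illustrative way to pin it explicitly; the field names come from the hunk above, while the keyword-only construction and the Compiler(options) call shape are assumptions:

from emx_onnx_cgen.compiler import Compiler, CompilerOptions

options = CompilerOptions(
    model_name="model",                  # field referenced as options.model_name elsewhere in compiler.py
    large_temp_threshold_bytes=1024,     # unchanged default
    large_weight_threshold=1024 * 1024,  # new 0.3.2 default; 0.3.0 shipped 1024
)
compiler = Compiler(options)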
@@ -203,6 +64,7 @@ class Compiler:
  large_temp_threshold_bytes=options.large_temp_threshold_bytes,
  large_weight_threshold=options.large_weight_threshold,
  )
+ load_lowering_registry()
 
  def compile(self, model: onnx.ModelProto) -> str:
  graph = import_onnx(model)
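The compiler now populates the lowering registry once at construction time instead of importing every lowering module by hand at the top of compiler.py. The contents of emx_onnx_cgen/lowering/__init__.py and registry.py are not shown in this diff, so the following is only a sketch of the register-on-import pattern that load_lowering_registry, get_lowering_registry, and the per-op lowering(ctx, node) call further down imply; the decorator name, registry type, and module list are assumptions:

from __future__ import annotations

from collections.abc import Callable
from importlib import import_module

# Hypothetical module-level registry: op_type -> lowering(ctx, node) callable.
_REGISTRY: dict[str, Callable] = {}


def register_lowering(op_type: str):
    """Assumed decorator each lowering module would use to register itself."""

    def decorator(fn: Callable) -> Callable:
        _REGISTRY[op_type] = fn
        return fn

    return decorator


def get_lowering_registry() -> dict[str, Callable]:
    return _REGISTRY


def load_lowering_registry() -> None:
    # Importing the lowering submodules runs their register_lowering calls;
    # the module list here is an illustrative subset of the package contents.
    for name in ("elementwise", "conv", "matmul", "reduce"):
        import_module(f"emx_onnx_cgen.lowering.{name}")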
@@ -297,7 +159,8 @@ class Compiler:
  return collect(graph.inputs), collect(graph.outputs)
 
  def _lower_model(self, model: onnx.ModelProto, graph: Graph) -> LoweredModel:
- constants = _lowered_constants(graph)
+ ctx = GraphContext(graph)
+ constants = _lowered_constants(ctx)
  self._validate_graph(graph)
  (
  input_names,
@@ -307,7 +170,14 @@ class Compiler:
  output_shapes,
  output_dtypes,
  ) = self._collect_io_specs(graph)
- ops, node_infos = self._lower_nodes(graph)
+ ops, node_infos = self._lower_nodes(ctx)
+ op_ctx = OpContext(ctx)
+ for op in ops:
+ op.validate(op_ctx)
+ for op in ops:
+ op.infer_types(op_ctx)
+ for op in ops:
+ op.infer_shapes(op_ctx)
  header = self._build_header(model, graph)
  return LoweredModel(
  name=self._options.model_name,
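After lowering, every op now runs through three whole-graph passes, validate, infer_types, and infer_shapes, each taking the shared OpContext. The new emx_onnx_cgen/ir/op_base.py (+193 lines) is not included in this diff, so the following is only a sketch of the interface those calls imply; whether OpContext forwards the GraphContext set_dtype/set_shape helpers shown later in this diff, and the AddLikeOp example itself, are assumptions:

from __future__ import annotations


class OpBase:
    """Sketch of the base interface implied by the three passes above."""

    def validate(self, ctx) -> None:
        """Check arity, attributes, and dtypes; raise UnsupportedOpError on violations."""

    def infer_types(self, ctx) -> None:
        """Record output dtypes, e.g. via ctx.set_dtype(name, dtype)."""

    def infer_shapes(self, ctx) -> None:
        """Record output shapes, e.g. via ctx.set_shape(name, shape)."""


class AddLikeOp(OpBase):
    """Illustrative elementwise op; field names and broadcast rule are assumptions."""

    def __init__(self, input0: str, input1: str, output: str) -> None:
        self.input0, self.input1, self.output = input0, input1, output

    def infer_types(self, ctx) -> None:
        ctx.set_dtype(self.output, ctx.dtype(self.input0))

    def infer_shapes(self, ctx) -> None:
        a, b = ctx.shape(self.input0), ctx.shape(self.input1)
        # Simplified broadcast: assumes equal rank with matching or size-1 dims.
        ctx.set_shape(self.output, tuple(max(x, y) for x, y in zip(a, b)))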
@@ -321,6 +191,7 @@ class Compiler:
  ops=tuple(ops),
  node_infos=tuple(node_infos),
  header=header,
+ op_context=op_ctx,
  )
 
  def _resolve_testbench_inputs(
@@ -478,122 +349,16 @@ class Compiler:
  )
 
  def _lower_nodes(
- self, graph: Graph
- ) -> tuple[
- list[
- BinaryOp
- | MultiInputBinaryOp
- | UnaryOp
- | ClipOp
- | CastOp
- | QuantizeLinearOp
- | MatMulOp
- | GemmOp
- | AttentionOp
- | ConvOp
- | ConvTransposeOp
- | AveragePoolOp
- | LpPoolOp
- | BatchNormOp
- | LpNormalizationOp
- | InstanceNormalizationOp
- | GroupNormalizationOp
- | LayerNormalizationOp
- | MeanVarianceNormalizationOp
- | RMSNormalizationOp
- | LrnOp
- | LstmOp
- | SoftmaxOp
- | LogSoftmaxOp
- | HardmaxOp
- | NegativeLogLikelihoodLossOp
- | SoftmaxCrossEntropyLossOp
- | MaxPoolOp
- | ConcatOp
- | GatherElementsOp
- | GatherOp
- | GatherNDOp
- | ScatterNDOp
- | TransposeOp
- | ConstantOfShapeOp
- | ReshapeOp
- | SliceOp
- | ResizeOp
- | GridSampleOp
- | ReduceOp
- | ArgReduceOp
- | ShapeOp
- | PadOp
- | NonZeroOp
- | ExpandOp
- | CumSumOp
- | RangeOp
- | OneHotOp
- | SplitOp
- ],
- list[NodeInfo],
- ]:
- ops: list[
- BinaryOp
- | MultiInputBinaryOp
- | UnaryOp
- | ClipOp
- | CastOp
- | QuantizeLinearOp
- | MatMulOp
- | GemmOp
- | AttentionOp
- | ConvOp
- | ConvTransposeOp
- | AveragePoolOp
- | LpPoolOp
- | BatchNormOp
- | LpNormalizationOp
- | InstanceNormalizationOp
- | GroupNormalizationOp
- | LayerNormalizationOp
- | MeanVarianceNormalizationOp
- | RMSNormalizationOp
- | LrnOp
- | LstmOp
- | SoftmaxOp
- | LogSoftmaxOp
- | HardmaxOp
- | NegativeLogLikelihoodLossOp
- | SoftmaxCrossEntropyLossOp
- | MaxPoolOp
- | ConcatOp
- | GatherElementsOp
- | GatherOp
- | GatherNDOp
- | TransposeOp
- | ConstantOfShapeOp
- | ReshapeOp
- | SliceOp
- | ResizeOp
- | ReduceOp
- | ArgReduceOp
- | ShapeOp
- | PadOp
- | NonZeroOp
- | ExpandOp
- | CumSumOp
- | RangeOp
- | OneHotOp
- | SplitOp
- | WhereOp
- ] = []
+ self, ctx: GraphContext
+ ) -> tuple[list[OpBase], list[NodeInfo]]:
+ ops: list[OpBase] = []
  node_infos: list[NodeInfo] = []
- for node in graph.nodes:
- lowering = resolve_dispatch(
- node.op_type,
- get_lowering_registry(),
- binary_types=BINARY_OP_TYPES,
- unary_types=UNARY_OP_TYPES,
- binary_fallback=lambda: _lower_binary_unary,
- unary_fallback=lambda: _lower_binary_unary,
- )
- ops.append(lowering(graph, node))
+ registry = get_lowering_registry()
+ for node in ctx.nodes:
+ lowering = registry.get(node.op_type)
+ if lowering is None:
+ raise UnsupportedOpError(f"Unsupported op {node.op_type}")
+ ops.append(lowering(ctx, node))
  node_infos.append(
  NodeInfo(
  op_type=node.op_type,
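With this change, an op type with no registered lowering fails fast with UnsupportedOpError instead of falling back to the generic binary/unary path that is removed further down in this diff. A sketch of the dispatch contract implied above: the registry maps an ONNX op_type to a callable taking the shared GraphContext and the Node and returning an OpBase instance. IdentityOp and its field names below are placeholders; the real op classes live in the new emx_onnx_cgen/ir/ops/ modules, which are not shown here:

from dataclasses import dataclass

from emx_onnx_cgen.errors import UnsupportedOpError


@dataclass
class IdentityOp:
    """Placeholder standing in for a real OpBase subclass from emx_onnx_cgen/ir/ops/."""

    input0: str
    output: str


def lower_identity(ctx, node):
    # Dispatch contract implied above: registry.get(node.op_type) yields this
    # callable, which receives (GraphContext, Node) and returns the lowered op.
    if len(node.inputs) != 1 or len(node.outputs) != 1:
        raise UnsupportedOpError("Identity must have 1 input and 1 output")
    return IdentityOp(input0=node.inputs[0], output=node.outputs[0])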
@@ -644,7 +409,7 @@ class Compiler:
  return evaluator.run(feeds)
 
 
- def _lowered_constants(graph: Graph) -> tuple[ConstTensor, ...]:
+ def _lowered_constants(graph: Graph | GraphContext) -> tuple[ConstTensor, ...]:
  constants: list[ConstTensor] = []
  for initializer in graph.initializers:
  dtype = ensure_supported_dtype(initializer.type.dtype)
@@ -660,125 +425,3 @@ def _lowered_constants(graph: Graph) -> tuple[ConstTensor, ...]:
  )
  )
  return tuple(constants)
-
-
- def _lower_binary_unary(graph: Graph, node: Node) -> BinaryOp | UnaryOp:
- if node.op_type == "BitShift":
- if len(node.inputs) != 2 or len(node.outputs) != 1:
- raise UnsupportedOpError("BitShift must have 2 inputs and 1 output")
- direction_attr = node.attrs.get("direction", "LEFT")
- if isinstance(direction_attr, bytes):
- direction = direction_attr.decode()
- else:
- direction = str(direction_attr)
- if direction not in {"LEFT", "RIGHT"}:
- raise UnsupportedOpError(
- "BitShift direction must be LEFT or RIGHT"
- )
- op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
- if not op_dtype.is_integer:
- raise UnsupportedOpError("BitShift expects integer inputs")
- function = (
- ScalarFunction.BITWISE_LEFT_SHIFT
- if direction == "LEFT"
- else ScalarFunction.BITWISE_RIGHT_SHIFT
- )
- op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
- if op_spec is None:
- raise UnsupportedOpError("Unsupported op BitShift")
- input0_shape = value_shape(graph, node.inputs[0], node)
- input1_shape = value_shape(graph, node.inputs[1], node)
- output_shape = value_shape(graph, node.outputs[0], node)
- return BinaryOp(
- input0=node.inputs[0],
- input1=node.inputs[1],
- output=node.outputs[0],
- function=function,
- operator_kind=op_spec.kind,
- input0_shape=input0_shape,
- input1_shape=input1_shape,
- shape=output_shape,
- dtype=op_dtype,
- input_dtype=op_dtype,
- )
- if node.op_type == "Mod":
- fmod = int(node.attrs.get("fmod", 0))
- if fmod not in {0, 1}:
- raise UnsupportedOpError("Mod only supports fmod=0 or fmod=1")
- function = (
- ScalarFunction.FMOD if fmod == 1 else ScalarFunction.REMAINDER
- )
- else:
- try:
- function = ScalarFunction.from_onnx_op(node.op_type)
- except ScalarFunctionError as exc:
- raise UnsupportedOpError(
- f"Unsupported op {node.op_type}"
- ) from exc
- validate_unary_attrs(node.op_type, node.attrs)
- if function in COMPARE_FUNCTIONS:
- input_dtype = node_dtype(graph, node, *node.inputs)
- output_dtype = value_dtype(graph, node.outputs[0], node)
- op_spec = binary_op_symbol(function, node.attrs, dtype=input_dtype)
- if op_spec is None:
- raise UnsupportedOpError(f"Unsupported op {node.op_type}")
- if len(node.inputs) != 2 or len(node.outputs) != 1:
- raise UnsupportedOpError(
- f"{node.op_type} must have 2 inputs and 1 output"
- )
- if output_dtype != ScalarType.BOOL:
- raise UnsupportedOpError(
- f"{node.op_type} expects bool output, got {output_dtype.onnx_name}"
- )
- input0_shape = value_shape(graph, node.inputs[0], node)
- input1_shape = value_shape(graph, node.inputs[1], node)
- output_shape = value_shape(graph, node.outputs[0], node)
- return BinaryOp(
- input0=node.inputs[0],
- input1=node.inputs[1],
- output=node.outputs[0],
- function=function,
- operator_kind=op_spec.kind,
- input0_shape=input0_shape,
- input1_shape=input1_shape,
- shape=output_shape,
- dtype=output_dtype,
- input_dtype=input_dtype,
- )
- op_dtype = node_dtype(graph, node, *node.inputs, *node.outputs)
- op_spec = binary_op_symbol(function, node.attrs, dtype=op_dtype)
- unary_symbol = unary_op_symbol(function, dtype=op_dtype)
- if op_spec is None and unary_symbol is None:
- raise UnsupportedOpError(f"Unsupported op {node.op_type}")
- if op_spec is not None:
- if len(node.inputs) != 2 or len(node.outputs) != 1:
- raise UnsupportedOpError(
- f"{node.op_type} must have 2 inputs and 1 output"
- )
- input0_shape = value_shape(graph, node.inputs[0], node)
- input1_shape = value_shape(graph, node.inputs[1], node)
- output_shape = value_shape(graph, node.outputs[0], node)
- return BinaryOp(
- input0=node.inputs[0],
- input1=node.inputs[1],
- output=node.outputs[0],
- function=function,
- operator_kind=op_spec.kind,
- input0_shape=input0_shape,
- input1_shape=input1_shape,
- shape=output_shape,
- dtype=op_dtype,
- input_dtype=op_dtype,
- )
- if len(node.inputs) != 1 or len(node.outputs) != 1:
- raise UnsupportedOpError(f"{node.op_type} must have 1 input and 1 output")
- output_shape = value_shape(graph, node.outputs[0], node)
- return UnaryOp(
- input0=node.inputs[0],
- output=node.outputs[0],
- function=function,
- shape=output_shape,
- dtype=op_dtype,
- input_dtype=op_dtype,
- params=(),
- )
emx_onnx_cgen/ir/context.py ADDED
@@ -0,0 +1,87 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+
+ from ..errors import ShapeInferenceError, UnsupportedOpError
+ from .model import Graph, Initializer, Node, Value
+ from shared.scalar_types import ScalarType
+
+
+ @dataclass
+ class GraphContext:
+ graph: Graph
+ _dtype_cache: dict[str, ScalarType] = field(default_factory=dict)
+ _shape_cache: dict[str, tuple[int, ...]] = field(default_factory=dict)
+ _initializer_cache: dict[str, Initializer] = field(default_factory=dict)
+ _producer_cache: dict[str, Node] = field(default_factory=dict)
+
+ def find_value(self, name: str) -> Value:
+ return self.graph.find_value(name)
+
+ def dtype(self, name: str, node: Node | None = None) -> ScalarType:
+ if name in self._dtype_cache:
+ return self._dtype_cache[name]
+ try:
+ value = self.graph.find_value(name)
+ except KeyError as exc:
+ op_type = node.op_type if node is not None else "unknown"
+ raise ShapeInferenceError(
+ f"Missing dtype for value '{name}' in op {op_type}. "
+ "Hint: run ONNX shape inference or export with static shapes."
+ ) from exc
+ dtype = value.type.dtype
+ if not isinstance(dtype, ScalarType):
+ raise UnsupportedOpError(f"Unsupported dtype {dtype}")
+ self._dtype_cache[name] = dtype
+ return dtype
+
+ def set_dtype(self, name: str, dtype: ScalarType) -> None:
+ self._dtype_cache[name] = dtype
+
+ def shape(self, name: str, node: Node | None = None) -> tuple[int, ...]:
+ if name in self._shape_cache:
+ return self._shape_cache[name]
+ try:
+ value = self.graph.find_value(name)
+ except KeyError as exc:
+ op_type = node.op_type if node is not None else "unknown"
+ raise ShapeInferenceError(
+ f"Missing shape for value '{name}' in op {op_type}. "
+ "Hint: run ONNX shape inference or export with static shapes."
+ ) from exc
+ self._shape_cache[name] = value.type.shape
+ return value.type.shape
+
+ def set_shape(self, name: str, shape: tuple[int, ...]) -> None:
+ self._shape_cache[name] = shape
+
+ def initializer(self, name: str) -> Initializer | None:
+ if name in self._initializer_cache:
+ return self._initializer_cache[name]
+ for initializer in self.graph.initializers:
+ if initializer.name == name:
+ self._initializer_cache[name] = initializer
+ return initializer
+ return None
+
+ def producer(self, output_name: str) -> Node | None:
+ if output_name in self._producer_cache:
+ return self._producer_cache[output_name]
+ for node in self.graph.nodes:
+ if output_name in node.outputs:
+ self._producer_cache[output_name] = node
+ return node
+ return None
+
+ def opset_version(self, domain: str = "") -> int | None:
+ if domain in {"", "ai.onnx"}:
+ domains = {"", "ai.onnx"}
+ else:
+ domains = {domain}
+ for opset_domain, version in self.graph.opset_imports:
+ if opset_domain in domains:
+ return int(version)
+ return None
+
+ def __getattr__(self, name: str):
+ return getattr(self.graph, name)
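For orientation, a small usage sketch of the GraphContext class added above. It only uses methods visible in this hunk; the value names ("input_0", "fc.weight", "relu_out") are illustrative, and the Graph instance is assumed to come from the existing onnx_import path:

from emx_onnx_cgen.ir.context import GraphContext
from emx_onnx_cgen.ir.model import Graph


def describe(graph: Graph) -> None:
    ctx = GraphContext(graph)
    # Cached lookups: the first call walks the graph, later calls hit the dict caches.
    dtype = ctx.dtype("input_0")           # ScalarType; ShapeInferenceError if the value is unknown
    shape = ctx.shape("input_0")           # tuple[int, ...]
    weight = ctx.initializer("fc.weight")  # Initializer | None
    producer = ctx.producer("relu_out")    # Node that writes this value, or None
    opset = ctx.opset_version()            # version for the default "" / "ai.onnx" domain
    # Anything not defined on GraphContext falls through to the wrapped Graph via __getattr__:
    op_types = {node.op_type for node in ctx.nodes}
    print(dtype, shape, weight, producer, opset, sorted(op_types))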