emx-onnx-cgen 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (94)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +50 -23
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +1844 -1568
  6. emx_onnx_cgen/codegen/emitter.py +5 -0
  7. emx_onnx_cgen/compiler.py +30 -387
  8. emx_onnx_cgen/ir/context.py +87 -0
  9. emx_onnx_cgen/ir/op_base.py +193 -0
  10. emx_onnx_cgen/ir/op_context.py +65 -0
  11. emx_onnx_cgen/ir/ops/__init__.py +130 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +146 -0
  13. emx_onnx_cgen/ir/ops/misc.py +421 -0
  14. emx_onnx_cgen/ir/ops/nn.py +580 -0
  15. emx_onnx_cgen/ir/ops/reduce.py +95 -0
  16. emx_onnx_cgen/lowering/__init__.py +79 -1
  17. emx_onnx_cgen/lowering/adagrad.py +114 -0
  18. emx_onnx_cgen/lowering/arg_reduce.py +1 -1
  19. emx_onnx_cgen/lowering/attention.py +1 -1
  20. emx_onnx_cgen/lowering/average_pool.py +1 -1
  21. emx_onnx_cgen/lowering/batch_normalization.py +1 -1
  22. emx_onnx_cgen/lowering/cast.py +1 -1
  23. emx_onnx_cgen/lowering/common.py +36 -18
  24. emx_onnx_cgen/lowering/concat.py +1 -1
  25. emx_onnx_cgen/lowering/constant_of_shape.py +1 -1
  26. emx_onnx_cgen/lowering/conv.py +1 -1
  27. emx_onnx_cgen/lowering/conv_transpose.py +1 -1
  28. emx_onnx_cgen/lowering/cumsum.py +1 -1
  29. emx_onnx_cgen/lowering/depth_space.py +1 -1
  30. emx_onnx_cgen/lowering/dropout.py +1 -1
  31. emx_onnx_cgen/lowering/einsum.py +1 -1
  32. emx_onnx_cgen/lowering/elementwise.py +152 -4
  33. emx_onnx_cgen/lowering/expand.py +1 -1
  34. emx_onnx_cgen/lowering/eye_like.py +1 -1
  35. emx_onnx_cgen/lowering/flatten.py +1 -1
  36. emx_onnx_cgen/lowering/gather.py +1 -1
  37. emx_onnx_cgen/lowering/gather_elements.py +1 -1
  38. emx_onnx_cgen/lowering/gather_nd.py +1 -1
  39. emx_onnx_cgen/lowering/gemm.py +1 -1
  40. emx_onnx_cgen/lowering/global_max_pool.py +1 -1
  41. emx_onnx_cgen/lowering/grid_sample.py +1 -1
  42. emx_onnx_cgen/lowering/group_normalization.py +1 -1
  43. emx_onnx_cgen/lowering/hardmax.py +1 -1
  44. emx_onnx_cgen/lowering/identity.py +1 -1
  45. emx_onnx_cgen/lowering/instance_normalization.py +1 -1
  46. emx_onnx_cgen/lowering/layer_normalization.py +1 -1
  47. emx_onnx_cgen/lowering/logsoftmax.py +1 -1
  48. emx_onnx_cgen/lowering/lp_normalization.py +1 -1
  49. emx_onnx_cgen/lowering/lp_pool.py +1 -1
  50. emx_onnx_cgen/lowering/lrn.py +1 -1
  51. emx_onnx_cgen/lowering/lstm.py +1 -1
  52. emx_onnx_cgen/lowering/matmul.py +1 -1
  53. emx_onnx_cgen/lowering/maxpool.py +1 -1
  54. emx_onnx_cgen/lowering/mean_variance_normalization.py +1 -1
  55. emx_onnx_cgen/lowering/negative_log_likelihood_loss.py +1 -1
  56. emx_onnx_cgen/lowering/non_max_suppression.py +157 -0
  57. emx_onnx_cgen/lowering/nonzero.py +1 -1
  58. emx_onnx_cgen/lowering/one_hot.py +1 -1
  59. emx_onnx_cgen/lowering/pad.py +1 -1
  60. emx_onnx_cgen/lowering/qlinear_matmul.py +212 -0
  61. emx_onnx_cgen/lowering/quantize_linear.py +1 -1
  62. emx_onnx_cgen/lowering/range.py +1 -1
  63. emx_onnx_cgen/lowering/reduce.py +1 -1
  64. emx_onnx_cgen/lowering/registry.py +24 -5
  65. emx_onnx_cgen/lowering/reshape.py +1 -1
  66. emx_onnx_cgen/lowering/resize.py +1 -1
  67. emx_onnx_cgen/lowering/rms_normalization.py +1 -1
  68. emx_onnx_cgen/lowering/rotary_embedding.py +165 -0
  69. emx_onnx_cgen/lowering/scatter_nd.py +1 -1
  70. emx_onnx_cgen/lowering/shape.py +6 -25
  71. emx_onnx_cgen/lowering/size.py +1 -1
  72. emx_onnx_cgen/lowering/slice.py +1 -1
  73. emx_onnx_cgen/lowering/softmax.py +1 -1
  74. emx_onnx_cgen/lowering/softmax_cross_entropy_loss.py +1 -1
  75. emx_onnx_cgen/lowering/split.py +1 -1
  76. emx_onnx_cgen/lowering/squeeze.py +1 -1
  77. emx_onnx_cgen/lowering/tensor_scatter.py +110 -0
  78. emx_onnx_cgen/lowering/tile.py +1 -1
  79. emx_onnx_cgen/lowering/topk.py +25 -7
  80. emx_onnx_cgen/lowering/transpose.py +1 -1
  81. emx_onnx_cgen/lowering/trilu.py +1 -1
  82. emx_onnx_cgen/lowering/unsqueeze.py +1 -1
  83. emx_onnx_cgen/lowering/variadic.py +1 -1
  84. emx_onnx_cgen/lowering/where.py +1 -1
  85. emx_onnx_cgen/runtime/evaluator.py +325 -1
  86. emx_onnx_cgen/verification.py +9 -39
  87. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/METADATA +8 -7
  88. emx_onnx_cgen-0.3.1.dist-info/RECORD +107 -0
  89. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/WHEEL +1 -1
  90. shared/scalar_functions.py +11 -0
  91. shared/ulp.py +17 -0
  92. emx_onnx_cgen-0.3.0.dist-info/RECORD +0 -93
  93. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/entry_points.txt +0 -0
  94. {emx_onnx_cgen-0.3.0.dist-info → emx_onnx_cgen-0.3.1.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/ir/ops/nn.py (new file)
@@ -0,0 +1,580 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from enum import Enum
+
+ from shared.scalar_functions import ScalarFunction
+ from shared.scalar_types import ScalarType
+
+ from ...errors import ShapeInferenceError
+ from ..op_base import ConvLikeOpBase, GemmLikeOpBase, MatMulLikeOpBase, RenderableOpBase
+ from ..op_context import OpContext
+
+
+ class EinsumKind(str, Enum):
+     REDUCE_ALL = "reduce_all"
+     SUM_J = "sum_j"
+     TRANSPOSE = "transpose"
+     DOT = "dot"
+     BATCH_MATMUL = "batch_matmul"
+     BATCH_DIAGONAL = "batch_diagonal"
+
+
+ def _shape_product(shape: tuple[int, ...]) -> int:
+     product = 1
+     for dim in shape:
+         if dim < 0:
+             raise ShapeInferenceError("Dynamic dims are not supported")
+         product *= dim
+     return product
+
+
+ @dataclass(frozen=True)
+ class MatMulOp(MatMulLikeOpBase):
+     input0: str
+     input1: str
+     output: str
+     input0_shape: tuple[int, ...]
+     input1_shape: tuple[int, ...]
+     output_shape: tuple[int, ...]
+     batch_shape: tuple[int, ...]
+     input0_batch_shape: tuple[int, ...]
+     input1_batch_shape: tuple[int, ...]
+     m: int
+     n: int
+     k: int
+     left_vector: bool
+     right_vector: bool
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class QLinearMatMulOp(MatMulLikeOpBase):
+     input0: str
+     input0_scale: str
+     input0_zero_point: str
+     input1: str
+     input1_scale: str
+     input1_zero_point: str
+     output_scale: str
+     output_zero_point: str
+     output: str
+     input0_shape: tuple[int, ...]
+     input1_shape: tuple[int, ...]
+     output_shape: tuple[int, ...]
+     batch_shape: tuple[int, ...]
+     input0_batch_shape: tuple[int, ...]
+     input1_batch_shape: tuple[int, ...]
+     m: int
+     n: int
+     k: int
+     left_vector: bool
+     right_vector: bool
+     input0_dtype: ScalarType
+     input1_dtype: ScalarType
+     dtype: ScalarType
+     input0_scale_dtype: ScalarType
+     input1_scale_dtype: ScalarType
+     output_scale_dtype: ScalarType
+     input0_scale_shape: tuple[int, ...]
+     input1_scale_shape: tuple[int, ...]
+     output_scale_shape: tuple[int, ...]
+     input0_zero_shape: tuple[int, ...]
+     input1_zero_shape: tuple[int, ...]
+     output_zero_shape: tuple[int, ...]
+
+ @dataclass(frozen=True)
+ class EinsumOp(MatMulLikeOpBase):
+     inputs: tuple[str, ...]
+     output: str
+     kind: EinsumKind
+     input_shapes: tuple[tuple[int, ...], ...]
+     output_shape: tuple[int, ...]
+     dtype: ScalarType
+     input_dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class GemmOp(GemmLikeOpBase):
+     input_a: str
+     input_b: str
+     input_c: str | None
+     output: str
+     m: int
+     n: int
+     k: int
+     trans_a: bool
+     trans_b: bool
+     alpha: float | int
+     beta: float | int
+     c_shape: tuple[int, ...] | None
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class AttentionOp(RenderableOpBase):
+     input_q: str
+     input_k: str
+     input_v: str
+     input_attn_mask: str | None
+     input_past_key: str | None
+     input_past_value: str | None
+     input_nonpad_kv_seqlen: str | None
+     output: str
+     output_present_key: str | None
+     output_present_value: str | None
+     output_qk_matmul: str | None
+     batch: int
+     q_heads: int
+     kv_heads: int
+     q_seq: int
+     kv_seq: int
+     total_seq: int
+     past_seq: int
+     qk_head_size: int
+     v_head_size: int
+     q_hidden_size: int | None
+     k_hidden_size: int | None
+     v_hidden_size: int | None
+     scale: float
+     is_causal: bool
+     softcap: float
+     qk_matmul_output_mode: int
+     q_rank: int
+     k_rank: int
+     v_rank: int
+     output_rank: int
+     mask_shape: tuple[int, ...] | None
+     mask_is_bool: bool
+     mask_rank: int | None
+     mask_broadcast_batch: bool
+     mask_broadcast_heads: bool
+     mask_broadcast_q_seq: bool
+     mask_q_seq: int | None
+     mask_kv_seq: int | None
+     head_group_size: int
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class RotaryEmbeddingOp(RenderableOpBase):
+     input0: str
+     cos_cache: str
+     sin_cache: str
+     position_ids: str | None
+     output: str
+     input_shape: tuple[int, ...]
+     cos_shape: tuple[int, ...]
+     sin_shape: tuple[int, ...]
+     position_ids_shape: tuple[int, ...] | None
+     dtype: ScalarType
+     position_ids_dtype: ScalarType | None
+     rotary_dim: int
+     rotary_dim_half: int
+     head_size: int
+     num_heads: int
+     seq_len: int
+     batch: int
+     input_rank: int
+     interleaved: bool
+
+ @dataclass(frozen=True)
+ class ConvOp(ConvLikeOpBase):
+     input0: str
+     weights: str
+     bias: str | None
+     output: str
+     batch: int
+     in_channels: int
+     out_channels: int
+     spatial_rank: int
+     in_spatial: tuple[int, ...]
+     out_spatial: tuple[int, ...]
+     kernel_shape: tuple[int, ...]
+     strides: tuple[int, ...]
+     pads: tuple[int, ...]
+     dilations: tuple[int, ...]
+     group: int
+     dtype: ScalarType
+
+     @property
+     def out_h(self) -> int:
+         if self.spatial_rank < 1:
+             raise ValueError("Conv output height is undefined for spatial_rank < 1")
+         return self.out_spatial[0]
+
+     @property
+     def out_w(self) -> int:
+         if self.spatial_rank < 2:
+             raise ValueError("Conv output width is undefined for spatial_rank < 2")
+         return self.out_spatial[1]
+
+ @dataclass(frozen=True)
+ class ConvTransposeOp(ConvLikeOpBase):
+     input0: str
+     weights: str
+     bias: str | None
+     output: str
+     batch: int
+     in_channels: int
+     out_channels: int
+     spatial_rank: int
+     in_spatial: tuple[int, ...]
+     out_spatial: tuple[int, ...]
+     kernel_shape: tuple[int, ...]
+     strides: tuple[int, ...]
+     pads: tuple[int, ...]
+     dilations: tuple[int, ...]
+     output_padding: tuple[int, ...]
+     group: int
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class AveragePoolOp(RenderableOpBase):
+     input0: str
+     output: str
+     batch: int
+     channels: int
+     in_h: int
+     in_w: int
+     out_h: int
+     out_w: int
+     kernel_h: int
+     kernel_w: int
+     stride_h: int
+     stride_w: int
+     pad_top: int
+     pad_left: int
+     pad_bottom: int
+     pad_right: int
+     count_include_pad: bool
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class LpPoolOp(RenderableOpBase):
+     input0: str
+     output: str
+     batch: int
+     channels: int
+     in_h: int
+     in_w: int
+     out_h: int
+     out_w: int
+     kernel_h: int
+     kernel_w: int
+     stride_h: int
+     stride_w: int
+     pad_top: int
+     pad_left: int
+     pad_bottom: int
+     pad_right: int
+     p: int
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class SoftmaxOp(RenderableOpBase):
+     input0: str
+     output: str
+     outer: int
+     axis_size: int
+     inner: int
+     axis: int
+     shape: tuple[int, ...]
+     dtype: ScalarType
+
+     def infer_shapes(self, ctx: OpContext) -> None:
+         input_shape = ctx.shape(self.input0)
+         axis = self.axis
+         if axis < 0:
+             axis += len(input_shape)
+         if axis < 0 or axis >= len(input_shape):
+             raise ShapeInferenceError(
+                 f"Softmax axis {self.axis} is out of bounds for shape {input_shape}"
+             )
+         outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+         axis_size = input_shape[axis]
+         inner = (
+             _shape_product(input_shape[axis + 1 :])
+             if axis + 1 < len(input_shape)
+             else 1
+         )
+         ctx.set_shape(self.output, input_shape)
+         ctx.set_derived(self, "outer", outer)
+         ctx.set_derived(self, "axis_size", axis_size)
+         ctx.set_derived(self, "inner", inner)
+
+ @dataclass(frozen=True)
+ class LogSoftmaxOp(RenderableOpBase):
+     input0: str
+     output: str
+     outer: int
+     axis_size: int
+     inner: int
+     axis: int
+     shape: tuple[int, ...]
+     dtype: ScalarType
+
+     def infer_shapes(self, ctx: OpContext) -> None:
+         input_shape = ctx.shape(self.input0)
+         axis = self.axis
+         if axis < 0:
+             axis += len(input_shape)
+         if axis < 0 or axis >= len(input_shape):
+             raise ShapeInferenceError(
+                 f"LogSoftmax axis {self.axis} is out of bounds for shape {input_shape}"
+             )
+         outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+         axis_size = input_shape[axis]
+         inner = (
+             _shape_product(input_shape[axis + 1 :])
+             if axis + 1 < len(input_shape)
+             else 1
+         )
+         ctx.set_shape(self.output, input_shape)
+         ctx.set_derived(self, "outer", outer)
+         ctx.set_derived(self, "axis_size", axis_size)
+         ctx.set_derived(self, "inner", inner)
+
+ @dataclass(frozen=True)
+ class HardmaxOp(RenderableOpBase):
+     input0: str
+     output: str
+     outer: int
+     axis_size: int
+     inner: int
+     axis: int
+     shape: tuple[int, ...]
+     dtype: ScalarType
+
+     def infer_shapes(self, ctx: OpContext) -> None:
+         input_shape = ctx.shape(self.input0)
+         axis = self.axis
+         if axis < 0:
+             axis += len(input_shape)
+         if axis < 0 or axis >= len(input_shape):
+             raise ShapeInferenceError(
+                 f"Hardmax axis {self.axis} is out of bounds for shape {input_shape}"
+             )
+         outer = _shape_product(input_shape[:axis]) if axis > 0 else 1
+         axis_size = input_shape[axis]
+         inner = (
+             _shape_product(input_shape[axis + 1 :])
+             if axis + 1 < len(input_shape)
+             else 1
+         )
+         ctx.set_shape(self.output, input_shape)
+         ctx.set_derived(self, "outer", outer)
+         ctx.set_derived(self, "axis_size", axis_size)
+         ctx.set_derived(self, "inner", inner)
+
+ @dataclass(frozen=True)
+ class NegativeLogLikelihoodLossOp(RenderableOpBase):
+     input0: str
+     target: str
+     weight: str | None
+     output: str
+     input_shape: tuple[int, ...]
+     target_shape: tuple[int, ...]
+     output_shape: tuple[int, ...]
+     n: int
+     c: int
+     d: int
+     reduction: str
+     ignore_index: int
+     input_dtype: ScalarType
+     weight_dtype: ScalarType | None
+     weight_shape: tuple[int, ...] | None
+     dtype: ScalarType
+     target_dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class SoftmaxCrossEntropyLossOp(RenderableOpBase):
+     input0: str
+     target: str
+     weight: str | None
+     output: str
+     log_prob: str | None
+     input_shape: tuple[int, ...]
+     target_shape: tuple[int, ...]
+     output_shape: tuple[int, ...]
+     log_prob_shape: tuple[int, ...] | None
+     n: int
+     c: int
+     d: int
+     reduction: str
+     ignore_index: int | None
+     input_dtype: ScalarType
+     weight_dtype: ScalarType | None
+     weight_shape: tuple[int, ...] | None
+     dtype: ScalarType
+     target_dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class BatchNormOp(RenderableOpBase):
+     input0: str
+     scale: str
+     bias: str
+     mean: str
+     variance: str
+     output: str
+     shape: tuple[int, ...]
+     channels: int
+     epsilon: float
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class LpNormalizationOp(RenderableOpBase):
+     input0: str
+     output: str
+     shape: tuple[int, ...]
+     axis: int
+     p: int
+     outer: int
+     axis_size: int
+     inner: int
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class InstanceNormalizationOp(RenderableOpBase):
+     input0: str
+     scale: str
+     bias: str
+     output: str
+     shape: tuple[int, ...]
+     channels: int
+     spatial_size: int
+     epsilon: float
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class GroupNormalizationOp(RenderableOpBase):
+     input0: str
+     scale: str
+     bias: str
+     output: str
+     shape: tuple[int, ...]
+     channels: int
+     num_groups: int
+     group_size: int
+     spatial_size: int
+     epsilon: float
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class LayerNormalizationOp(RenderableOpBase):
+     input0: str
+     scale: str
+     bias: str | None
+     output: str
+     mean_output: str | None
+     invstd_output: str | None
+     shape: tuple[int, ...]
+     normalized_shape: tuple[int, ...]
+     scale_shape: tuple[int, ...]
+     bias_shape: tuple[int, ...] | None
+     outer: int
+     inner: int
+     axis: int
+     epsilon: float
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class MeanVarianceNormalizationOp(RenderableOpBase):
+     input0: str
+     output: str
+     shape: tuple[int, ...]
+     axes: tuple[int, ...]
+     non_axes: tuple[int, ...]
+     reduce_count: int
+     epsilon: float
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class RMSNormalizationOp(RenderableOpBase):
+     input0: str
+     scale: str
+     output: str
+     shape: tuple[int, ...]
+     normalized_shape: tuple[int, ...]
+     scale_shape: tuple[int, ...]
+     outer: int
+     inner: int
+     axis: int
+     epsilon: float
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class LrnOp(RenderableOpBase):
+     input0: str
+     output: str
+     shape: tuple[int, ...]
+     channels: int
+     size: int
+     half: int
+     alpha: float
+     beta: float
+     bias: float
+     dtype: ScalarType
+
+ @dataclass(frozen=True)
+ class LstmOp(RenderableOpBase):
+     input_x: str
+     input_w: str
+     input_r: str
+     input_b: str | None
+     input_sequence_lens: str | None
+     input_initial_h: str | None
+     input_initial_c: str | None
+     input_p: str | None
+     output_y: str | None
+     output_y_h: str | None
+     output_y_c: str | None
+     seq_length: int
+     batch_size: int
+     input_size: int
+     hidden_size: int
+     num_directions: int
+     direction: str
+     layout: int
+     input_forget: int
+     clip: float | None
+     activation_kinds: tuple[int, ...]
+     activation_alphas: tuple[float, ...]
+     activation_betas: tuple[float, ...]
+     dtype: ScalarType
+     sequence_lens_dtype: ScalarType | None
+
+ @dataclass(frozen=True)
+ class AdagradOp(RenderableOpBase):
+     rate: str
+     timestep: str
+     inputs: tuple[str, ...]
+     gradients: tuple[str, ...]
+     accumulators: tuple[str, ...]
+     outputs: tuple[str, ...]
+     accumulator_outputs: tuple[str, ...]
+     rate_shape: tuple[int, ...]
+     timestep_shape: tuple[int, ...]
+     tensor_shapes: tuple[tuple[int, ...], ...]
+     output_shapes: tuple[tuple[int, ...], ...]
+     dtype: ScalarType
+     rate_dtype: ScalarType
+     timestep_dtype: ScalarType
+     norm_coefficient: float
+     epsilon: float
+     decay_factor: float
+
+ @dataclass(frozen=True)
+ class MaxPoolOp(RenderableOpBase):
+     input0: str
+     output: str
+     indices: str | None
+     batch: int
+     channels: int
+     spatial_rank: int
+     in_spatial: tuple[int, ...]
+     out_spatial: tuple[int, ...]
+     kernel_shape: tuple[int, ...]
+     strides: tuple[int, ...]
+     pads: tuple[int, ...]
+     dilations: tuple[int, ...]
+     ceil_mode: bool
+     storage_order: int
+     dtype: ScalarType
+     indices_dtype: ScalarType | None
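
The softmax-family ops above (SoftmaxOp, LogSoftmaxOp, HardmaxOp) share one shape-inference pattern: a negative axis is normalized, then the input shape is split around it into flat (outer, axis_size, inner) extents, presumably so the C emitter can walk tensors of any rank with plain nested loops. A minimal standalone sketch of that derivation, mirroring the infer_shapes bodies above (the function name and the math.prod shortcut are illustrative, not part of the wheel):

import math

def softmax_factorization(shape: tuple[int, ...], axis: int) -> tuple[int, int, int]:
    # Mirrors SoftmaxOp.infer_shapes: normalize a negative axis, then split
    # the shape into flat extents before, at, and after that axis.
    if axis < 0:
        axis += len(shape)
    if axis < 0 or axis >= len(shape):
        raise ValueError(f"axis {axis} out of bounds for shape {shape}")
    outer = math.prod(shape[:axis])      # dims before the axis (empty product is 1)
    axis_size = shape[axis]              # extent of the normalized axis
    inner = math.prod(shape[axis + 1:])  # dims after the axis (empty product is 1)
    return outer, axis_size, inner

# A (2, 3, 4) tensor normalized over axis=-2 factors as outer=2, axis_size=3, inner=4.
assert softmax_factorization((2, 3, 4), -2) == (2, 3, 4)

Unlike the sketch, the real code routes the products through _shape_product, which raises ShapeInferenceError on negative (dynamic) dims rather than silently multiplying them.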
emx_onnx_cgen/ir/ops/reduce.py (new file)
@@ -0,0 +1,95 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ from shared.scalar_types import ScalarType
+
+ from ..op_base import ReduceOpBase
+ from ..op_context import OpContext
+
+
+ @dataclass(frozen=True)
+ class ReduceOp(ReduceOpBase):
+     input0: str
+     output: str
+     input_shape: tuple[int, ...]
+     output_shape: tuple[int, ...]
+     axes: tuple[int, ...]
+     axes_input: str | None
+     axes_input_shape: tuple[int, ...] | None
+     axes_input_dtype: ScalarType | None
+     keepdims: bool
+     noop_with_empty_axes: bool
+     reduce_kind: str
+     reduce_count: int | None
+     dtype: ScalarType
+
+     def infer_types(self, ctx: OpContext) -> None:
+         ctx.dtype(self.output)
+
+     def infer_shapes(self, ctx: OpContext) -> None:
+         input_shape = ctx.shape(self.input0)
+         if self.axes_input is None:
+             axes = self.normalize_axes(self.axes, len(input_shape))
+             output_shape = self.reduced_shape(
+                 input_shape, axes, keepdims=self.keepdims
+             )
+         else:
+             axes = self.axes
+             output_shape = ctx.shape(self.output)
+         ctx.set_shape(self.output, output_shape)
+         ctx.set_derived(self, "axes", axes)
+
+
+ @dataclass(frozen=True)
+ class ArgReduceOp(ReduceOpBase):
+     input0: str
+     output: str
+     input_shape: tuple[int, ...]
+     output_shape: tuple[int, ...]
+     axis: int
+     keepdims: bool
+     select_last_index: bool
+     reduce_kind: str
+     input_dtype: ScalarType
+     output_dtype: ScalarType
+
+     def infer_types(self, ctx: OpContext) -> None:
+         ctx.dtype(self.input0)
+         ctx.dtype(self.output)
+
+     def infer_shapes(self, ctx: OpContext) -> None:
+         input_shape = ctx.shape(self.input0)
+         axes = self.normalize_axes((self.axis,), len(input_shape))
+         output_shape = self.reduced_shape(
+             input_shape, axes, keepdims=self.keepdims
+         )
+         ctx.set_shape(self.output, output_shape)
+         ctx.set_derived(self, "axis", axes[0])
+
+
+ @dataclass(frozen=True)
+ class TopKOp(ReduceOpBase):
+     input0: str
+     output_values: str
+     output_indices: str
+     input_shape: tuple[int, ...]
+     output_shape: tuple[int, ...]
+     axis: int
+     k: int
+     largest: bool
+     sorted: bool
+     input_dtype: ScalarType
+     output_values_dtype: ScalarType
+     output_indices_dtype: ScalarType
+
+     def infer_types(self, ctx: OpContext) -> None:
+         ctx.dtype(self.input0)
+         ctx.dtype(self.output_values)
+         ctx.dtype(self.output_indices)
+
+     def infer_shapes(self, ctx: OpContext) -> None:
+         input_shape = ctx.shape(self.input0)
+         output_shape = ctx.shape(self.output_values)
+         ctx.set_shape(self.output_values, output_shape)
+         ctx.set_shape(self.output_indices, ctx.shape(self.output_indices))
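
ReduceOp and ArgReduceOp above delegate to normalize_axes and reduced_shape on the new ReduceOpBase (defined in emx_onnx_cgen/ir/op_base.py, which this diff does not reproduce). Judging from the call sites, these presumably implement the standard ONNX reduction rule; a hedged sketch as free functions (hypothetical names and bodies, not the package's API):

def normalize_axes(axes: tuple[int, ...], rank: int) -> tuple[int, ...]:
    # Map negative axes into [0, rank) and deduplicate, keeping ascending order.
    return tuple(sorted({axis + rank if axis < 0 else axis for axis in axes}))

def reduced_shape(shape: tuple[int, ...], axes: tuple[int, ...], *, keepdims: bool) -> tuple[int, ...]:
    # Reduced axes collapse to 1 when keepdims is set; otherwise they are dropped.
    if keepdims:
        return tuple(1 if i in axes else dim for i, dim in enumerate(shape))
    return tuple(dim for i, dim in enumerate(shape) if i not in axes)

# A reduction over axes (1, -1) of a (2, 3, 4) tensor:
assert reduced_shape((2, 3, 4), normalize_axes((1, -1), 3), keepdims=True) == (2, 1, 1)
assert reduced_shape((2, 3, 4), normalize_axes((1, -1), 3), keepdims=False) == (2,)

Note that when the reduction axes arrive as a runtime tensor (axes_input is not None), ReduceOp.infer_shapes skips this computation entirely and reuses the output shape already recorded in the context.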