emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of emx-onnx-cgen might be problematic.

Files changed (137):
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/ir/ops/elementwise.py +138 -22

@@ -6,7 +6,13 @@ from shared.scalar_functions import ScalarFunction
 from shared.scalar_types import ScalarType

 from ...ops import COMPARE_FUNCTIONS, OperatorKind, binary_op_symbol
-from ..op_base import ElementwiseOpBase, VariadicLikeOpBase
+from ...errors import ShapeInferenceError, UnsupportedOpError
+from ..op_base import (
+    BroadcastingOpBase,
+    ElementwiseOpBase,
+    RenderableOpBase,
+    VariadicLikeOpBase,
+)
 from ..op_context import OpContext

@@ -17,11 +23,6 @@ class BinaryOp(ElementwiseOpBase):
     output: str
     function: ScalarFunction
     operator_kind: OperatorKind
-    input0_shape: tuple[int, ...]
-    input1_shape: tuple[int, ...]
-    shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType

     def _elementwise_inputs(self) -> tuple[str, ...]:
         return (self.input0, self.input1)
@@ -32,6 +33,107 @@ class BinaryOp(ElementwiseOpBase):
     def _elementwise_compare(self) -> bool:
         return self.function in COMPARE_FUNCTIONS

+    def infer_shapes(self, ctx: OpContext) -> None:
+        if self.function != ScalarFunction.PRELU:
+            return super().infer_shapes(ctx)
+        input_shape = ctx.shape(self.input0)
+        slope_shape = ctx.shape(self.input1)
+        output_name = self.output
+        if BroadcastingOpBase.unidirectional_broadcastable(
+            slope_shape, input_shape
+        ):
+            ctx.set_shape(output_name, input_shape)
+            return None
+        channel_axis = BroadcastingOpBase.prelu_channel_axis(
+            input_shape, slope_shape
+        )
+        if channel_axis is not None:
+            ctx.set_shape(output_name, input_shape)
+            ctx.set_derived(self, "prelu_slope_axis", channel_axis)
+            return None
+        raise ShapeInferenceError(
+            "Broadcasting mismatch for shapes: "
+            + ", ".join(str(shape) for shape in (input_shape, slope_shape))
+        )
+
+
+_POW_BASE_DTYPES = {
+    ScalarType.F16,
+    ScalarType.F32,
+    ScalarType.F64,
+    ScalarType.I32,
+    ScalarType.I64,
+}
+_POW_EXPONENT_DTYPES = {
+    ScalarType.F16,
+    ScalarType.F32,
+    ScalarType.F64,
+    ScalarType.I8,
+    ScalarType.I16,
+    ScalarType.I32,
+    ScalarType.I64,
+    ScalarType.U8,
+    ScalarType.U16,
+    ScalarType.U32,
+    ScalarType.U64,
+}
+
+
+@dataclass(frozen=True)
+class PowOp(BinaryOp):
+    def validate(self, ctx: OpContext) -> None:
+        base_dtype = ctx.dtype(self.input0)
+        exponent_dtype = ctx.dtype(self.input1)
+        if base_dtype not in _POW_BASE_DTYPES:
+            raise UnsupportedOpError(
+                "Pow base dtype must be one of "
+                f"{', '.join(dtype.onnx_name for dtype in sorted(_POW_BASE_DTYPES, key=str))}, "
+                f"got {base_dtype.onnx_name}"
+            )
+        if exponent_dtype not in _POW_EXPONENT_DTYPES:
+            raise UnsupportedOpError(
+                "Pow exponent dtype must be one of "
+                f"{', '.join(dtype.onnx_name for dtype in sorted(_POW_EXPONENT_DTYPES, key=str))}, "
+                f"got {exponent_dtype.onnx_name}"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            return None
+        if output_dtype != base_dtype:
+            raise UnsupportedOpError(
+                "Pow expects output dtype "
+                f"{base_dtype.onnx_name}, got {output_dtype.onnx_name}"
+            )
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        base_dtype = ctx.dtype(self.input0)
+        exponent_dtype = ctx.dtype(self.input1)
+        if base_dtype not in _POW_BASE_DTYPES:
+            raise UnsupportedOpError(
+                "Pow base dtype must be one of "
+                f"{', '.join(dtype.onnx_name for dtype in sorted(_POW_BASE_DTYPES, key=str))}, "
+                f"got {base_dtype.onnx_name}"
+            )
+        if exponent_dtype not in _POW_EXPONENT_DTYPES:
+            raise UnsupportedOpError(
+                "Pow exponent dtype must be one of "
+                f"{', '.join(dtype.onnx_name for dtype in sorted(_POW_EXPONENT_DTYPES, key=str))}, "
+                f"got {exponent_dtype.onnx_name}"
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, base_dtype)
+            return None
+        if output_dtype != base_dtype:
+            raise UnsupportedOpError(
+                "Pow expects output dtype "
+                f"{base_dtype.onnx_name}, got {output_dtype.onnx_name}"
+            )
+        return None
+

 @dataclass(frozen=True)
 class VariadicOp(VariadicLikeOpBase):
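
The PReLU branch above delegates to two BroadcastingOpBase helpers whose bodies are not part of this hunk. As a reading aid, here is a minimal sketch of the rule `unidirectional_broadcastable` presumably checks, assuming the standard ONNX unidirectional broadcasting convention (this stand-in is an assumption, not code from the diff):

# Hypothetical stand-in for BroadcastingOpBase.unidirectional_broadcastable,
# assuming ONNX's unidirectional broadcasting rule: right-align the shapes,
# then every source dimension must equal the target dimension or be 1.
def unidirectional_broadcastable(
    from_shape: tuple[int, ...], to_shape: tuple[int, ...]
) -> bool:
    if len(from_shape) > len(to_shape):
        return False
    return all(
        src == dst or src == 1
        for src, dst in zip(reversed(from_shape), reversed(to_shape))
    )

# A (3, 1, 5) slope broadcasts into a (2, 3, 4, 5) input; a bare (3,) slope
# does not, which is the case the prelu_channel_axis fallback above handles.
assert unidirectional_broadcastable((3, 1, 5), (2, 3, 4, 5))
assert not unidirectional_broadcastable((3,), (2, 3, 4, 5))
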
@@ -80,11 +182,6 @@ class WhereOp(ElementwiseOpBase):
     input_x: str
     input_y: str
     output: str
-    condition_shape: tuple[int, ...]
-    x_shape: tuple[int, ...]
-    y_shape: tuple[int, ...]
-    output_shape: tuple[int, ...]
-    dtype: ScalarType

     def _elementwise_inputs(self) -> tuple[str, ...]:
         return (self.condition, self.input_x, self.input_y)
@@ -101,9 +198,6 @@ class UnaryOp(ElementwiseOpBase):
     input0: str
     output: str
     function: ScalarFunction
-    shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType
     params: tuple[float, ...] = ()

     def _elementwise_inputs(self) -> tuple[str, ...]:
@@ -126,11 +220,8 @@ class ClipOp(ElementwiseOpBase):
     input_min: str | None
     input_max: str | None
     output: str
-    input_shape: tuple[int, ...]
-    min_shape: tuple[int, ...] | None
-    max_shape: tuple[int, ...] | None
-    output_shape: tuple[int, ...]
-    dtype: ScalarType
+    min_value: float | None = None
+    max_value: float | None = None

     def _elementwise_inputs(self) -> tuple[str, ...]:
         inputs = [self.input0]
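
ClipOp now carries folded scalar bounds (min_value/max_value) instead of shape and dtype fields, which the context supplies during inference. For reference, a NumPy sketch of Clip semantics with optional bounds (a reading aid, not the generated C kernel):

import numpy as np

# Clip with optional bounds: a missing bound leaves that side unclamped.
def clip_ref(x: np.ndarray, min_value: float | None = None,
             max_value: float | None = None) -> np.ndarray:
    lo = -np.inf if min_value is None else min_value
    hi = np.inf if max_value is None else max_value
    return np.clip(x, lo, hi)
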
@@ -152,9 +243,6 @@ class ClipOp(ElementwiseOpBase):
 class IdentityOp(ElementwiseOpBase):
     input0: str
     output: str
-    shape: tuple[int, ...]
-    dtype: ScalarType
-    input_dtype: ScalarType

     def _elementwise_inputs(self) -> tuple[str, ...]:
         return (self.input0,)
@@ -165,3 +253,31 @@ class IdentityOp(ElementwiseOpBase):
     def validate(self, ctx: OpContext) -> None:
         super().validate(ctx)
         return None
+
+
+@dataclass(frozen=True)
+class QLinearMulOp(RenderableOpBase):
+    input0: str
+    input0_scale: str
+    input0_zero_point: str
+    input1: str
+    input1_scale: str
+    input1_zero_point: str
+    output_scale: str
+    output_zero_point: str
+    output: str
+    input0_shape: tuple[int, ...]
+    input1_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    input0_dtype: ScalarType
+    input1_dtype: ScalarType
+    dtype: ScalarType
+    input0_scale_dtype: ScalarType
+    input1_scale_dtype: ScalarType
+    output_scale_dtype: ScalarType
+    input0_scale_shape: tuple[int, ...]
+    input1_scale_shape: tuple[int, ...]
+    output_scale_shape: tuple[int, ...]
+    input0_zero_shape: tuple[int, ...]
+    input1_zero_shape: tuple[int, ...]
+    output_zero_shape: tuple[int, ...]
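
QLinearMulOp pairs with the new qlinear_mul.py lowering and qlinear_mul_op.c.j2 template from the file list. As a reference for the contract its fields describe, a NumPy sketch assuming standard ONNX QLinearMul semantics (dequantize both operands, multiply in float, requantize with the output scale/zero point, saturate):

import numpy as np

# Reference-only sketch of QLinearMul semantics, not the emitted C.
def qlinear_mul_ref(x0, s0, zp0, x1, s1, zp1, s_out, zp_out,
                    out_dtype=np.uint8):
    a = (x0.astype(np.float32) - zp0) * s0   # dequantize input 0
    b = (x1.astype(np.float32) - zp1) * s1   # dequantize input 1
    q = np.rint(a * b / s_out) + zp_out      # requantize the product
    info = np.iinfo(out_dtype)
    return np.clip(q, info.min, info.max).astype(out_dtype)  # saturate
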
emx_onnx_cgen/ir/ops/misc.py +95 -0

@@ -53,6 +53,20 @@ class QuantizeLinearOp(RenderableOpBase):
     input_dtype: ScalarType
     scale_dtype: ScalarType

+
+@dataclass(frozen=True)
+class DequantizeLinearOp(RenderableOpBase):
+    input0: str
+    scale: str
+    zero_point: str | None
+    output: str
+    input_shape: tuple[int, ...]
+    axis: int | None
+    block_size: int | None
+    dtype: ScalarType
+    input_dtype: ScalarType
+    scale_dtype: ScalarType
+
 @dataclass(frozen=True)
 class ConcatOp(RenderableOpBase):
     inputs: tuple[str, ...]
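
DequantizeLinearOp likewise pairs with dequantize_linear.py and dequantize_linear_op.c.j2 above. In the simple per-tensor case the semantics reduce to y = (x - zero_point) * scale; the `axis` and `block_size` fields cover per-axis and blocked quantization, omitted in this sketch:

import numpy as np

# Per-tensor DequantizeLinear reference; axis/block_size variants omitted.
def dequantize_linear_ref(x: np.ndarray, scale: float,
                          zero_point: int = 0) -> np.ndarray:
    return (x.astype(np.int32) - zero_point).astype(np.float32) * scale
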
@@ -184,6 +198,16 @@ class EyeLikeOp(RenderableOpBase):
     dtype: ScalarType
     input_dtype: ScalarType

+@dataclass(frozen=True)
+class BernoulliOp(RenderableOpBase):
+    input0: str
+    output: str
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    input_dtype: ScalarType
+    dtype: ScalarType
+    seed: int | None
+
 @dataclass(frozen=True)
 class TriluOp(RenderableOpBase):
     input0: str
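
BernoulliOp (paired with bernoulli.py and bernoulli_op.c.j2 from the file list) treats each input element as a probability, with `seed` pinning the generator when set. A reference sketch of the ONNX semantics, using float32 output for simplicity:

import numpy as np

# Each output element is a draw from Bernoulli(p) with p = input element.
def bernoulli_ref(probs: np.ndarray, seed: int | None = None) -> np.ndarray:
    rng = np.random.default_rng(seed)
    return (rng.random(probs.shape) < probs).astype(np.float32)
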
@@ -355,6 +379,51 @@ class SizeOp(RenderableOpBase):
     dtype: ScalarType
     input_dtype: ScalarType

+
+@dataclass(frozen=True)
+class OptionalHasElementOp(RenderableOpBase):
+    input0: str
+    output: str
+
+    def validate(self, ctx: OpContext) -> None:
+        value = ctx.graph.find_value(self.input0)
+        if not value.type.is_optional:
+            raise UnsupportedOpError(
+                f"{self.kind} expects optional input, got non-optional tensor."
+            )
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            return None
+        if output_dtype != ScalarType.BOOL:
+            raise UnsupportedOpError(
+                f"{self.kind} expects bool output, got {output_dtype.onnx_name}"
+            )
+        return None
+
+    def infer_types(self, ctx: OpContext) -> None:
+        ctx.dtype(self.input0)
+        try:
+            output_dtype = ctx.dtype(self.output)
+        except ShapeInferenceError:
+            ctx.set_dtype(self.output, ScalarType.BOOL)
+            return None
+        if output_dtype != ScalarType.BOOL:
+            raise UnsupportedOpError(
+                f"{self.kind} expects bool output, got {output_dtype.onnx_name}"
+            )
+
+    def infer_shapes(self, ctx: OpContext) -> None:
+        try:
+            output_shape = ctx.shape(self.output)
+        except ShapeInferenceError:
+            ctx.set_shape(self.output, ())
+            return None
+        if output_shape not in {(), (1,)}:
+            raise UnsupportedOpError(
+                f"{self.kind} expects scalar output, got shape {output_shape}"
+            )
+
 @dataclass(frozen=True)
 class NonZeroOp(RenderableOpBase):
     input0: str
@@ -477,6 +546,15 @@ class RangeOp(RenderableOpBase):
     dtype: ScalarType
     input_dtype: ScalarType

+@dataclass(frozen=True)
+class HammingWindowOp(RenderableOpBase):
+    size: str
+    output: str
+    output_shape: tuple[int, ...]
+    periodic: bool
+    dtype: ScalarType
+    input_dtype: ScalarType
+
 @dataclass(frozen=True)
 class OneHotOp(RenderableOpBase):
     indices: str
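
HammingWindowOp is generated through hamming_window.py and hamming_window_op.c.j2 above. Assuming the ONNX definition (alpha = 25/46, beta = 21/46, with the divisor switching between size and size - 1 for periodic versus symmetric windows), a reference sketch:

import numpy as np

# w[n] = alpha - beta * cos(2*pi*n / N); N = size if periodic else size - 1.
def hamming_window_ref(size: int, periodic: bool = True) -> np.ndarray:
    n = np.arange(size, dtype=np.float32)
    denom = float(size if periodic else size - 1)
    return (25 / 46) - (21 / 46) * np.cos(2 * np.pi * n / denom)
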
@@ -492,6 +570,23 @@ class OneHotOp(RenderableOpBase):
     indices_dtype: ScalarType
     depth_dtype: ScalarType

+@dataclass(frozen=True)
+class TfIdfVectorizerOp(RenderableOpBase):
+    input0: str
+    output: str
+    input_shape: tuple[int, ...]
+    output_shape: tuple[int, ...]
+    input_dtype: ScalarType
+    output_dtype: ScalarType
+    min_gram_length: int
+    max_gram_length: int
+    max_skip_count: int
+    mode: str
+    ngram_counts: tuple[int, ...]
+    ngram_indexes: tuple[int, ...]
+    pool_int64s: tuple[int, ...]
+    weights: tuple[float, ...] | None
+
 @dataclass(frozen=True)
 class SplitOp(RenderableOpBase):
     input0: str
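
TfIdfVectorizerOp carries the full ONNX attribute set (n-gram pool, output indexes, optional IDF weights). As a deliberately simplified sketch of "TF" mode restricted to unigrams — the real lowering in tfidf_vectorizer.py also handles longer n-grams with skips and the IDF/TFIDF modes:

# Simplified TF-mode reference for unigrams only: each pool entry is one
# token id, and ngram_indexes maps pool position -> output coordinate.
def tfidf_tf_unigrams_ref(
    tokens: list[int],
    pool_int64s: tuple[int, ...],
    ngram_indexes: tuple[int, ...],
    output_size: int,
) -> list[float]:
    out = [0.0] * output_size
    for pool_pos, token in enumerate(pool_int64s):
        out[ngram_indexes[pool_pos]] += float(tokens.count(token))
    return out
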