emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,277 @@
1
+ {# Resize kernel template: for every output element, maps the output coordinate back to an input coordinate per coordinate_transformation_mode, then samples the input. #}
+ {# mode "nearest" picks one clamped input element; any other mode accumulates a weighted sum over a per-dimension window (2 taps when mode is "linear", otherwise the 4-tap formula built from cubic_coeff_a). #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const int64_t input_shape[{{ rank }}] = { {% for dim in input_shape %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
3
+ const int64_t output_shape[{{ rank }}] EMX_UNUSED = { {% for dim in output_shape %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
4
+ double scales[{{ rank }}] EMX_UNUSED;
+ {# roi holds rank start values followed by rank end values. With roi_input it is copied at run time (only the listed axes when roi_axes is set, all others defaulting to 0.0 / 1.0); otherwise it is the constant 0.0 / 1.0 table. #}
5
+ {% if roi_input %}
6
+ double roi[{{ rank * 2 }}] EMX_UNUSED;
7
+ {% if roi_axes %}
8
+ for (idx_t r = 0; r < {{ rank }}; ++r) {
9
+ roi[r] = 0.0;
10
+ roi[r + {{ rank }}] = 1.0;
11
+ }
12
+ {% for axis in axes %}
13
+ roi[{{ axis }}] = (double){{ roi_input }}[{{ roi_axis_map[loop.index0] }}];
14
+ roi[{{ axis + rank }}] = (double){{ roi_input }}[{{ roi_axis_map[loop.index0] + (axes | length) }}];
15
+ {% endfor %}
16
+ {% else %}
17
+ for (idx_t r = 0; r < {{ rank * 2 }}; ++r) {
18
+ roi[r] = (double){{ roi_input }}[r];
19
+ }
20
+ {% endif %}
21
+ {% else %}
22
+ const double roi[{{ rank * 2 }}] EMX_UNUSED = {
23
+ {% for _ in range(rank) %} 0.0{% if not loop.last %},{% endif %}
24
+ {% endfor %}{% if rank > 0 %},{% endif %}
25
+ {% for _ in range(rank) %} 1.0{% if not loop.last %},{% endif %}
26
+ {% endfor %}
27
+ };
28
+ {% endif %}
+ {# scales is filled from one of three sources, in priority order: the run-time scales_input, the run-time sizes_input (size / input extent), or the compile-time scales list. #}
29
+ {% if scales_input %}
30
+ {% if scales_axes %}
31
+ for (idx_t r = 0; r < {{ rank }}; ++r) {
32
+ scales[r] = 1.0;
33
+ }
34
+ {% for axis in axes %}
35
+ scales[{{ axis }}] = (double){{ scales_input }}[{{ scales_axis_map[loop.index0] }}];
36
+ {% endfor %}
37
+ {% else %}
38
+ {% for axis in range(rank) %}
39
+ scales[{{ axis }}] = (double){{ scales_input }}[{{ axis }}];
40
+ {% endfor %}
41
+ {% endif %}
42
+ {% elif sizes_input %}
+ {# Non-"stretch" policies apply one common scale to every listed axis: the smallest per-axis scale for "not_larger", the largest for any other policy value. #}
43
+ {% if keep_aspect_ratio_policy != "stretch" %}
44
+ {% if keep_aspect_ratio_policy == "not_larger" %}
45
+ double scale_value = INFINITY;
46
+ {% else %}
47
+ double scale_value = 0.0;
48
+ {% endif %}
49
+ {% for axis in axes %}
50
+ {
51
+ const double size_axis = (double){{ sizes_input }}[{{ sizes_axis_map[loop.index0] }}];
52
+ const double axis_scale = size_axis / (double)input_shape[{{ axis }}];
53
+ {% if keep_aspect_ratio_policy == "not_larger" %}
54
+ if (axis_scale < scale_value) {
55
+ scale_value = axis_scale;
56
+ }
57
+ {% else %}
58
+ if (axis_scale > scale_value) {
59
+ scale_value = axis_scale;
60
+ }
61
+ {% endif %}
62
+ }
63
+ {% endfor %}
64
+ for (idx_t r = 0; r < {{ rank }}; ++r) {
65
+ scales[r] = 1.0;
66
+ }
67
+ {% for axis in axes %}
68
+ scales[{{ axis }}] = scale_value;
69
+ {% endfor %}
70
+ {% else %}
71
+ for (idx_t r = 0; r < {{ rank }}; ++r) {
72
+ scales[r] = 1.0;
73
+ }
74
+ {% for axis in axes %}
75
+ scales[{{ axis }}] = (double){{ sizes_input }}[{{ sizes_axis_map[loop.index0] }}] / (double)input_shape[{{ axis }}];
76
+ {% endfor %}
77
+ {% endif %}
78
+ {% else %}
79
+ {% for axis in range(rank) %}
80
+ scales[{{ axis }}] = {{ scales[axis] }};
81
+ {% endfor %}
82
+ {% endif %}
+ {# One nested C loop per output dimension; the matching closing braces are emitted at the end of the template. #}
83
+ {% for dim in output_shape %}
84
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
85
+ {% endfor %}
86
+ int use_extrapolation = 0;
+ {# Per-dimension source coordinate x_orig. The final else branch (no other mode matched) uses the (x + 0.5) / scale - 0.5 mapping. Only "tf_crop_and_resize" ever sets use_extrapolation. #}
87
+ {% for dim in range(rank) %}
88
+ double x_orig{{ dim }};
89
+ {% if coordinate_transformation_mode == "align_corners" %}
90
+ if (output_shape[{{ dim }}] == 1) {
91
+ x_orig{{ dim }} = 0.0;
92
+ } else {
93
+ x_orig{{ dim }} = (double){{ loop_vars[dim] }} * (input_shape[{{ dim }}] - 1) / (double)(output_shape[{{ dim }}] - 1);
94
+ }
95
+ {% elif coordinate_transformation_mode == "asymmetric" %}
96
+ x_orig{{ dim }} = (double){{ loop_vars[dim] }} / scales[{{ dim }}];
97
+ {% elif coordinate_transformation_mode == "tf_crop_and_resize" %}
98
+ {
99
+ const double roi_start = roi[{{ dim }}];
100
+ const double roi_end = roi[{{ dim + rank }}];
101
+ if (output_shape[{{ dim }}] == 1) {
102
+ x_orig{{ dim }} = (roi_end - roi_start) * (input_shape[{{ dim }}] - 1) / 2.0;
103
+ } else {
104
+ x_orig{{ dim }} = (double){{ loop_vars[dim] }} * (roi_end - roi_start) * (input_shape[{{ dim }}] - 1)
105
+ / (double)(output_shape[{{ dim }}] - 1);
106
+ }
107
+ x_orig{{ dim }} += roi_start * (input_shape[{{ dim }}] - 1);
108
+ if (x_orig{{ dim }} < 0.0 || x_orig{{ dim }} > (double)(input_shape[{{ dim }}] - 1)) {
109
+ use_extrapolation = 1;
110
+ }
111
+ }
112
+ {% elif coordinate_transformation_mode == "pytorch_half_pixel" %}
113
+ if (output_shape[{{ dim }}] == 1) {
114
+ x_orig{{ dim }} = -0.5;
115
+ } else {
116
+ x_orig{{ dim }} = ((double){{ loop_vars[dim] }} + 0.5) / scales[{{ dim }}] - 0.5;
117
+ }
118
+ {% elif coordinate_transformation_mode == "half_pixel_symmetric" %}
119
+ {
120
+ const double output_width = scales[{{ dim }}] * (double)input_shape[{{ dim }}];
121
+ const double adjustment = (double)output_shape[{{ dim }}] / output_width;
122
+ const double center = (double)input_shape[{{ dim }}] / 2.0;
123
+ const double offset = center * (1.0 - adjustment);
124
+ x_orig{{ dim }} = offset + ((double){{ loop_vars[dim] }} + 0.5) / scales[{{ dim }}] - 0.5;
125
+ }
126
+ {% else %}
127
+ x_orig{{ dim }} = ((double){{ loop_vars[dim] }} + 0.5) / scales[{{ dim }}] - 0.5;
128
+ {% endif %}
129
+ {% endfor %}
+ {# A coordinate flagged as out of range writes extrapolation_value instead of sampling the input. #}
130
+ if (use_extrapolation) {
131
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ c_type }}){{ extrapolation_value }};
132
+ } else {
+ {# Nearest: round x_orig per nearest_mode (the final else branch takes the ceiling), then clamp the index into [0, input_shape - 1]. #}
133
+ {% if mode == "nearest" %}
134
+ {% for dim in range(rank) %}
135
+ const double x_val{{ dim }} = x_orig{{ dim }};
136
+ const double x_floor{{ dim }} EMX_UNUSED = floor(x_val{{ dim }});
137
+ const double x_ceil{{ dim }} EMX_UNUSED = ceil(x_val{{ dim }});
138
+ int idx{{ dim }};
139
+ {% if nearest_mode == "round_prefer_floor" %}
140
+ idx{{ dim }} = (x_val{{ dim }} - x_floor{{ dim }} <= 0.5) ? (int)x_floor{{ dim }} : (int)x_ceil{{ dim }};
141
+ {% elif nearest_mode == "round_prefer_ceil" %}
142
+ idx{{ dim }} = (x_val{{ dim }} - x_floor{{ dim }} < 0.5) ? (int)x_floor{{ dim }} : (int)x_ceil{{ dim }};
143
+ {% elif nearest_mode == "floor" %}
144
+ idx{{ dim }} = (int)x_floor{{ dim }};
145
+ {% else %}
146
+ idx{{ dim }} = (int)x_ceil{{ dim }};
147
+ {% endif %}
148
+ if (idx{{ dim }} < 0) {
149
+ idx{{ dim }} = 0;
150
+ } else if (idx{{ dim }} >= input_shape[{{ dim }}]) {
151
+ idx{{ dim }} = (int)input_shape[{{ dim }}] - 1;
152
+ }
153
+ {% endfor %}
154
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for dim in range(rank) %}[idx{{ dim }}]{% endfor %};
155
+ {% else %}
+ {# Interpolating modes: per dimension, build a window of count taps (indices idxN_values, weights coeffN). When x_orig is integral, ratio is forced to 1.0 and the window start is not advanced by one. #}
156
+ double acc = 0.0;
157
+ {% for dim in range(rank) %}
158
+ double x_floor{{ dim }} = floor(x_orig{{ dim }});
159
+ double ratio{{ dim }} = x_orig{{ dim }} - x_floor{{ dim }};
160
+ const int is_integer{{ dim }} = ratio{{ dim }} == 0.0;
161
+ if (is_integer{{ dim }}) {
162
+ ratio{{ dim }} = 1.0;
163
+ }
164
+ int count{{ dim }};
+ {# With antialias the tap count is computed at run time from the scale clamped to at most 1.0; without it the count is a fixed 2 (linear) or 4. #}
165
+ {% if antialias %}
166
+ double scale_clamped{{ dim }} = scales[{{ dim }}] < 1.0 ? scales[{{ dim }}] : 1.0;
167
+ {% if mode == "linear" %}
168
+ int coeff_start{{ dim }} = (int)floor(-1.0 / scale_clamped{{ dim }}) + 1;
169
+ count{{ dim }} = 2 - 2 * coeff_start{{ dim }};
170
+ {% else %}
171
+ int coeff_start{{ dim }} = (int)floor(-2.0 / scale_clamped{{ dim }}) + 1;
172
+ int coeff_end{{ dim }} = 2 - coeff_start{{ dim }};
173
+ count{{ dim }} = coeff_end{{ dim }} - coeff_start{{ dim }};
174
+ {% endif %}
175
+ {% else %}
176
+ {% if mode == "linear" %}
177
+ count{{ dim }} = 2;
178
+ {% else %}
179
+ count{{ dim }} = 4;
180
+ {% endif %}
181
+ {% endif %}
182
+ int start_index{{ dim }} = (int)x_floor{{ dim }} - (count{{ dim }} / 2);
183
+ if (!is_integer{{ dim }}) {
184
+ start_index{{ dim }} += 1;
185
+ }
+ {# These two arrays are sized by the run-time variable count, i.e. they are C variable-length arrays. #}
186
+ int idx{{ dim }}_values[count{{ dim }}];
187
+ double coeff{{ dim }}[count{{ dim }}];
188
+ {% if antialias %}
189
+ double coeff_sum{{ dim }} = 0.0;
190
+ for (int c = 0; c < count{{ dim }}; ++c) {
191
+ idx{{ dim }}_values[c] = start_index{{ dim }} + c;
192
+ {% if mode == "linear" %}
193
+ double arg = (coeff_start{{ dim }} + c - ratio{{ dim }}) * scale_clamped{{ dim }};
194
+ double coeff = 1.0 - fabs(arg);
195
+ if (coeff < 0.0) {
196
+ coeff = 0.0;
197
+ }
198
+ {% else %}
199
+ double x = scale_clamped{{ dim }} * (coeff_start{{ dim }} + c - ratio{{ dim }});
200
+ double ax = fabs(x);
201
+ double coeff;
202
+ if (ax <= 1.0) {
203
+ coeff = ({{ cubic_coeff_a }} + 2.0) * ax * ax * ax - ({{ cubic_coeff_a }} + 3.0) * ax * ax + 1.0;
204
+ } else if (ax < 2.0) {
205
+ coeff = {{ cubic_coeff_a }} * ax * ax * ax - 5.0 * {{ cubic_coeff_a }} * ax * ax + 8.0 * {{ cubic_coeff_a }} * ax - 4.0 * {{ cubic_coeff_a }};
206
+ } else {
207
+ coeff = 0.0;
208
+ }
209
+ {% endif %}
210
+ coeff{{ dim }}[c] = coeff;
211
+ coeff_sum{{ dim }} += coeff;
212
+ }
213
+ if (coeff_sum{{ dim }} > 0.0) {
214
+ for (int c = 0; c < count{{ dim }}; ++c) {
215
+ coeff{{ dim }}[c] /= coeff_sum{{ dim }};
216
+ }
217
+ }
218
+ {% else %}
219
+ for (int c = 0; c < count{{ dim }}; ++c) {
220
+ idx{{ dim }}_values[c] = start_index{{ dim }} + c;
221
+ }
222
+ {% if mode == "linear" %}
223
+ coeff{{ dim }}[0] = 1.0 - ratio{{ dim }};
224
+ coeff{{ dim }}[1] = ratio{{ dim }};
225
+ {% else %}
226
+ {
227
+ const double t = 1.0 - ratio{{ dim }};
228
+ coeff{{ dim }}[0] = (({{ cubic_coeff_a }} * (ratio{{ dim }} + 1.0) - 5.0 * {{ cubic_coeff_a }}) * (ratio{{ dim }} + 1.0) + 8.0 * {{ cubic_coeff_a }}) * (ratio{{ dim }} + 1.0) - 4.0 * {{ cubic_coeff_a }};
229
+ coeff{{ dim }}[1] = (({{ cubic_coeff_a }} + 2.0) * ratio{{ dim }} - ({{ cubic_coeff_a }} + 3.0)) * ratio{{ dim }} * ratio{{ dim }} + 1.0;
230
+ coeff{{ dim }}[2] = (({{ cubic_coeff_a }} + 2.0) * t - ({{ cubic_coeff_a }} + 3.0)) * t * t + 1.0;
231
+ coeff{{ dim }}[3] = (({{ cubic_coeff_a }} * (t + 1.0) - 5.0 * {{ cubic_coeff_a }}) * (t + 1.0) + 8.0 * {{ cubic_coeff_a }}) * (t + 1.0) - 4.0 * {{ cubic_coeff_a }};
232
+ }
233
+ {% endif %}
234
+ {% endif %}
+ {# exclude_outside: zero the weight of every tap whose index falls outside the input, then renormalise the remaining weights when their sum is positive. #}
235
+ {% if exclude_outside %}
236
+ {
237
+ double coeff_sum{{ dim }} = 0.0;
238
+ for (int c = 0; c < count{{ dim }}; ++c) {
239
+ if (idx{{ dim }}_values[c] < 0 || idx{{ dim }}_values[c] >= input_shape[{{ dim }}]) {
240
+ coeff{{ dim }}[c] = 0.0;
241
+ }
242
+ coeff_sum{{ dim }} += coeff{{ dim }}[c];
243
+ }
244
+ if (coeff_sum{{ dim }} > 0.0) {
245
+ for (int c = 0; c < count{{ dim }}; ++c) {
246
+ coeff{{ dim }}[c] /= coeff_sum{{ dim }};
247
+ }
248
+ }
249
+ }
250
+ {% endif %}
251
+ {% endfor %}
+ {# Accumulate over the cross product of all per-dimension windows: the weight is the product of the per-dimension coefficients, and each tap index is clamped into the input range before the read. #}
252
+ {% for dim in range(rank) %}
253
+ for (int n{{ dim }} = 0; n{{ dim }} < count{{ dim }}; ++n{{ dim }}) {
254
+ {% endfor %}
255
+ double weight = 1.0;
256
+ {% for dim in range(rank) %}
257
+ weight *= coeff{{ dim }}[n{{ dim }}];
258
+ {% endfor %}
259
+ {% for dim in range(rank) %}
260
+ int idx{{ dim }} = idx{{ dim }}_values[n{{ dim }}];
261
+ if (idx{{ dim }} < 0) {
262
+ idx{{ dim }} = 0;
263
+ } else if (idx{{ dim }} >= input_shape[{{ dim }}]) {
264
+ idx{{ dim }} = (int)input_shape[{{ dim }}] - 1;
265
+ }
266
+ {% endfor %}
267
+ acc += weight * (double){{ input0 }}{% for dim in range(rank) %}[idx{{ dim }}]{% endfor %};
268
+ {% for dim in range(rank) %}
269
+ }
270
+ {% endfor %}
271
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ c_type }})acc;
272
+ {% endif %}
273
+ }
274
+ {% for _ in output_shape %}
275
+ }
276
+ {% endfor %}
277
+ }
@@ -0,0 +1,28 @@
1
+ {# RMS-normalization kernel template: for each position of the prefix_shape loops, sums the squares of the input over the norm_shape loops, divides by inner, adds epsilon_literal and takes sqrt_fn; every element is then divided by that value and multiplied by scale. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in prefix_shape %}
3
+ for (idx_t {{ prefix_loop_vars[loop.index0] }} = 0; {{ prefix_loop_vars[loop.index0] }} < {{ dim }}; ++{{ prefix_loop_vars[loop.index0] }}) {
4
+ {% endfor %}
+ {# First pass over the normalized dimensions: accumulate the sum of squares. #}
5
+ {{ c_type }} sum = {{ zero_literal }};
6
+ {% for dim in norm_shape %}
7
+ for (idx_t {{ norm_loop_vars[loop.index0] }} = 0; {{ norm_loop_vars[loop.index0] }} < {{ dim }}; ++{{ norm_loop_vars[loop.index0] }}) {
8
+ {% endfor %}
9
+ {{ c_type }} value = {{ input0 }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %};
10
+ sum += value * value;
11
+ {% for _ in norm_shape %}
12
+ }
13
+ {% endfor %}
+ {# inner is supplied by the generator; presumably the element count of norm_shape - TODO confirm against the emitter. #}
14
+ {{ c_type }} mean_square = sum / {{ inner }};
15
+ {{ c_type }} denom = {{ sqrt_fn }}(mean_square + {{ epsilon_literal }});
+ {# Second pass: write input / denom * scale; scale is indexed with scale_index_vars rather than the full loop-variable list. #}
16
+ {% for dim in norm_shape %}
17
+ for (idx_t {{ norm_loop_vars[loop.index0] }} = 0; {{ norm_loop_vars[loop.index0] }} < {{ dim }}; ++{{ norm_loop_vars[loop.index0] }}) {
18
+ {% endfor %}
19
+ {{ c_type }} value = {{ input0 }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %} / denom;
20
+ value = value * {{ scale }}{% for var in scale_index_vars %}[{{ var }}]{% endfor %};
21
+ {{ output }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %} = value;
22
+ {% for _ in norm_shape %}
23
+ }
24
+ {% endfor %}
25
+ {% for _ in prefix_shape %}
26
+ }
27
+ {% endfor %}
28
+ }
@@ -0,0 +1,66 @@
1
+ {# Rotary-embedding kernel template: rotates element pairs (idx1, idx2) by (cos_val, sin_val) and copies elements from rotary_dim up to head_size through unchanged. #}
+ {# Two layouts are emitted: input_rank == 3 indexes [b][s][h * head_size + i]; every other rank uses the [b][h][s][i] form. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% if input_rank == 3 %}
3
+ for (idx_t b = 0; b < {{ batch }}; ++b) {
4
+ for (idx_t s = 0; s < {{ seq_len }}; ++s) {
5
+ for (idx_t h = 0; h < {{ num_heads }}; ++h) {
6
+ idx_t head_offset = h * {{ head_size }};
7
+ for (idx_t d = 0; d < {{ rotary_dim_half }}; ++d) {
+ {# Pairing: interleaved pairs adjacent elements (2d, 2d + 1); otherwise element d is paired with d + rotary_dim_half. #}
8
+ {% if interleaved %}
9
+ idx_t idx1 = 2 * d;
10
+ idx_t idx2 = 2 * d + 1;
11
+ {% else %}
12
+ idx_t idx1 = d;
13
+ idx_t idx2 = d + {{ rotary_dim_half }};
14
+ {% endif %}
+ {# With position ids the caches are indexed [pos][d]; without them they are indexed [b][s][d]. #}
15
+ {% if has_position_ids %}
16
+ idx_t pos = (idx_t){{ position_ids }}[b][s];
17
+ {{ c_type }} cos_val = {{ cos_cache }}[pos][d];
18
+ {{ c_type }} sin_val = {{ sin_cache }}[pos][d];
19
+ {% else %}
20
+ {{ c_type }} cos_val = {{ cos_cache }}[b][s][d];
21
+ {{ c_type }} sin_val = {{ sin_cache }}[b][s][d];
22
+ {% endif %}
23
+ {{ c_type }} x1 = {{ input0 }}[b][s][head_offset + idx1];
24
+ {{ c_type }} x2 = {{ input0 }}[b][s][head_offset + idx2];
25
+ {{ output }}[b][s][head_offset + idx1] = (cos_val * x1) - (sin_val * x2);
26
+ {{ output }}[b][s][head_offset + idx2] = (sin_val * x1) + (cos_val * x2);
27
+ }
+ {# Pass-through of the non-rotated tail of each head. #}
28
+ for (idx_t d = {{ rotary_dim }}; d < {{ head_size }}; ++d) {
29
+ {{ output }}[b][s][head_offset + d] = {{ input0 }}[b][s][head_offset + d];
30
+ }
31
+ }
32
+ }
33
+ }
34
+ {% else %}
+ {# Same computation for the [b][h][s][i] layout; note the loop order here is batch, head, sequence. #}
35
+ for (idx_t b = 0; b < {{ batch }}; ++b) {
36
+ for (idx_t h = 0; h < {{ num_heads }}; ++h) {
37
+ for (idx_t s = 0; s < {{ seq_len }}; ++s) {
38
+ for (idx_t d = 0; d < {{ rotary_dim_half }}; ++d) {
39
+ {% if interleaved %}
40
+ idx_t idx1 = 2 * d;
41
+ idx_t idx2 = 2 * d + 1;
42
+ {% else %}
43
+ idx_t idx1 = d;
44
+ idx_t idx2 = d + {{ rotary_dim_half }};
45
+ {% endif %}
46
+ {% if has_position_ids %}
47
+ idx_t pos = (idx_t){{ position_ids }}[b][s];
48
+ {{ c_type }} cos_val = {{ cos_cache }}[pos][d];
49
+ {{ c_type }} sin_val = {{ sin_cache }}[pos][d];
50
+ {% else %}
51
+ {{ c_type }} cos_val = {{ cos_cache }}[b][s][d];
52
+ {{ c_type }} sin_val = {{ sin_cache }}[b][s][d];
53
+ {% endif %}
54
+ {{ c_type }} x1 = {{ input0 }}[b][h][s][idx1];
55
+ {{ c_type }} x2 = {{ input0 }}[b][h][s][idx2];
56
+ {{ output }}[b][h][s][idx1] = (cos_val * x1) - (sin_val * x2);
57
+ {{ output }}[b][h][s][idx2] = (sin_val * x1) + (cos_val * x2);
58
+ }
59
+ for (idx_t d = {{ rotary_dim }}; d < {{ head_size }}; ++d) {
60
+ {{ output }}[b][h][s][d] = {{ input0 }}[b][h][s][d];
61
+ }
62
+ }
63
+ }
64
+ }
65
+ {% endif %}
66
+ }
@@ -0,0 +1,52 @@
1
+ {# ScatterND kernel template: first copies data into output element by element, then walks every index tuple and applies the matching update to output according to reduction. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in output_shape %}
3
+ for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ {{ output }}{% for var in output_loop_vars %}[{{ var }}]{% endfor %} = {{ data }}{% for var in output_loop_vars %}[{{ var }}]{% endfor %};
6
+ {% for _ in output_shape %}
7
+ }
8
+ {% endfor %}
9
+
+ {# Loops over the leading dimensions of the indices tensor; emitted only when indices_prefix_shape is non-empty. #}
10
+ {% if indices_prefix_shape %}
11
+ {% for dim in indices_prefix_shape %}
12
+ for (idx_t {{ indices_prefix_loop_vars[loop.index0] }} = 0; {{ indices_prefix_loop_vars[loop.index0] }} < {{ dim }}; ++{{ indices_prefix_loop_vars[loop.index0] }}) {
13
+ {% endfor %}
14
+ {% endif %}
+ {# Read the index_depth components of the index tuple; a negative component is wrapped by adding the extent of the data dimension it addresses. #}
15
+ {% for idx in range(index_depth) %}
16
+ idx_t index{{ idx }} = {{ indices }}{% for var in indices_prefix_loop_vars %}[{{ var }}]{% endfor %}[{{ idx }}];
17
+ if (index{{ idx }} < 0) {
18
+ index{{ idx }} += {{ data_shape[idx] }};
19
+ }
20
+ {% endfor %}
21
+ {% if tail_shape %}
22
+ {% for dim in tail_shape %}
23
+ for (idx_t {{ tail_loop_vars[loop.index0] }} = 0; {{ tail_loop_vars[loop.index0] }} < {{ dim }}; ++{{ tail_loop_vars[loop.index0] }}) {
24
+ {% endfor %}
25
+ {% endif %}
+ {# updates_index_expr and output_index_expr are complete C expressions supplied by the generator. A reduction value outside the five handled here emits no store at all. #}
26
+ {{ c_type }} update_val = {{ updates_index_expr }};
27
+ {% if reduction == "none" %}
28
+ {{ output_index_expr }} = update_val;
29
+ {% elif reduction == "add" %}
30
+ {{ output_index_expr }} += update_val;
31
+ {% elif reduction == "mul" %}
32
+ {{ output_index_expr }} *= update_val;
33
+ {% elif reduction == "max" %}
34
+ if (update_val > {{ output_index_expr }}) {
35
+ {{ output_index_expr }} = update_val;
36
+ }
37
+ {% elif reduction == "min" %}
38
+ if (update_val < {{ output_index_expr }}) {
39
+ {{ output_index_expr }} = update_val;
40
+ }
41
+ {% endif %}
42
+ {% if tail_shape %}
43
+ {% for _ in tail_shape %}
44
+ }
45
+ {% endfor %}
46
+ {% endif %}
47
+ {% if indices_prefix_shape %}
48
+ {% for _ in indices_prefix_shape %}
49
+ }
50
+ {% endfor %}
51
+ {% endif %}
52
+ }
@@ -0,0 +1,6 @@
1
+ {# Shape kernel template: writes the generation-time values list into output, one assignment per entry. The input is never read; the (void) cast only marks it as intentionally unused. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ (void){{ input0 }};
3
+ {% for value in values %}
4
+ {{ output }}[{{ loop.index0 }}] = {{ value }};
5
+ {% endfor %}
6
+ }
@@ -0,0 +1,4 @@
1
+ {# Size kernel template: stores the single generation-time value into output[0]. The input is never read; the (void) cast only marks it as intentionally unused. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ (void){{ input0 }};
3
+ {{ output }}[0] = {{ value }};
4
+ }
@@ -0,0 +1,9 @@
1
+ {# Static slice kernel template: one nested loop per output dimension; each output element is copied from the input at input_indices, a list of index expressions supplied by the generator (presumably start + loop_var * step - TODO confirm against the emitter). #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in output_shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for idx in input_indices %}[{{ idx }}]{% endfor %};
6
+ {% for _ in output_shape %}
7
+ }
8
+ {% endfor %}
9
+ }
@@ -0,0 +1,70 @@
1
+ {# Dynamic slice kernel template: starts, ends and optionally axes and steps are read at run time, normalised into per-axis start_indices / step_values, and the output is then gathered as input[start + i * step] along every axis. #}
+ {# The output loop bounds come from the generation-time output_shape; the run-time end values are clamped but do not affect those bounds. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const idx_t input_rank = {{ input_rank }};
3
+ const idx_t input_dims[] = { {% for dim in input_shape %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
4
+ idx_t start_indices[{{ input_rank }}];
5
+ idx_t step_values[{{ input_rank }}];
+ {# Axes that are not named by the slice keep start 0 and step 1. #}
6
+ for (idx_t idx = 0; idx < input_rank; ++idx) {
7
+ start_indices[idx] = 0;
8
+ step_values[idx] = 1;
9
+ }
10
+ for (idx_t i = 0; i < {{ starts_len }}; ++i) {
+ {# Without an axes input, entry i applies to axis i. A negative axis is wrapped by adding the rank. #}
11
+ {% if axes_input %}
12
+ idx_t axis_value = {{ axes_input }}[i];
13
+ {% else %}
14
+ idx_t axis_value = i;
15
+ {% endif %}
16
+ if (axis_value < 0) {
17
+ axis_value += input_rank;
18
+ }
19
+ idx_t axis = axis_value;
20
+ idx_t dim = input_dims[axis];
21
+ idx_t start_value = {{ starts_input }}[i];
22
+ idx_t step_value = 1;
23
+ {% if steps_input %}
24
+ step_value = {{ steps_input }}[i];
25
+ {% endif %}
26
+ idx_t end_value = {{ ends_input }}[i];
+ {# Negative start / end count from the end of the axis. #}
27
+ if (start_value < 0) {
28
+ start_value += dim;
29
+ }
30
+ if (end_value < 0) {
31
+ end_value += dim;
32
+ }
+ {# Clamping: a positive step limits start and end to [0, dim]; any other step limits them to [-1, dim - 1]. #}
33
+ if (step_value > 0) {
34
+ if (start_value < 0) {
35
+ start_value = 0;
36
+ }
37
+ if (start_value > dim) {
38
+ start_value = dim;
39
+ }
40
+ if (end_value < 0) {
41
+ end_value = 0;
42
+ }
43
+ if (end_value > dim) {
44
+ end_value = dim;
45
+ }
46
+ } else {
47
+ if (start_value < 0) {
48
+ start_value = -1;
49
+ }
50
+ if (start_value >= dim) {
51
+ start_value = dim - 1;
52
+ }
53
+ if (end_value < -1) {
54
+ end_value = -1;
55
+ }
56
+ if (end_value >= dim) {
57
+ end_value = dim - 1;
58
+ }
59
+ }
60
+ start_indices[axis] = start_value;
61
+ step_values[axis] = step_value;
62
+ }
63
+ {% for dim in output_shape %}
64
+ for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
65
+ {% endfor %}
66
+ {{ output }}{% for var in output_loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for var in output_loop_vars %}[start_indices[{{ loop.index0 }}] + ({{ var }} * step_values[{{ loop.index0 }}])]{% endfor %};
67
+ {% for _ in output_shape %}
68
+ }
69
+ {% endfor %}
70
+ }
@@ -0,0 +1,105 @@
1
+ {# Softmax cross-entropy loss kernel template: for each (n_idx, d_idx) position, computes the log-softmax over the c classes and takes the negative log-probability of the target class as the loss. #}
+ {# All tensors are accessed through flat pointers: class c_idx of a position lives at n_idx * c * d + c_idx * d + d_idx, and the target / per-position loss at n_idx * d + d_idx. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
3
+ const {{ target_c_type }} *target_flat = (const {{ target_c_type }} *){{ target }};
4
+ {{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
5
+ {% if log_prob %}
6
+ {{ c_type }} *log_prob_flat = ({{ c_type }} *){{ log_prob }};
7
+ {% endif %}
8
+ const idx_t n = {{ n }};
9
+ const idx_t c = {{ c }};
10
+ const idx_t d = {{ d }};
+ {# Accumulators exist only for reduced outputs; weight_sum is needed only for a mean taken over weighted or ignore-filtered samples. #}
11
+ {% if reduction != "none" %}
12
+ {{ acc_type }} loss_sum = {{ acc_zero_literal }};
13
+ {% if reduction == "mean" and (weight or use_ignore_index) %}
14
+ {{ acc_type }} weight_sum = {{ acc_zero_literal }};
15
+ {% endif %}
16
+ {% endif %}
17
+ for (idx_t n_idx = 0; n_idx < n; ++n_idx) {
18
+ for (idx_t d_idx = 0; d_idx < d; ++d_idx) {
19
+ idx_t target_index = n_idx * d + d_idx;
20
+ {{ target_c_type }} target_value = target_flat[target_index];
21
+ {% if use_ignore_index %}
22
+ bool ignored = ((int64_t)target_value == {{ ignore_index }});
23
+ {% endif %}
24
+ idx_t class_index = (idx_t)target_value;
25
+ idx_t base = (n_idx * c * d) + d_idx;
+ {# Log-sum-exp over the class axis: the maximum is subtracted before exponentiating (the usual stabilisation), so the log-probability is input - max_value - logsum. #}
26
+ {{ acc_type }} max_value = ({{ acc_type }})input_flat[base];
27
+ for (idx_t c_idx = 1; c_idx < c; ++c_idx) {
28
+ {{ acc_type }} value = ({{ acc_type }})input_flat[base + c_idx * d];
29
+ max_value = {{ max_fn }}(max_value, value);
30
+ }
31
+ {{ acc_type }} sum = {{ acc_zero_literal }};
32
+ for (idx_t c_idx = 0; c_idx < c; ++c_idx) {
33
+ {{ acc_type }} value = ({{ acc_type }})input_flat[base + c_idx * d] - max_value;
34
+ sum += {{ acc_exp_fn }}(value);
35
+ }
36
+ {{ acc_type }} logsum = {{ acc_log_fn }}(sum);
+ {# With a log_prob output every class log-probability is written (ignored positions included) and the loss is picked up when the loop reaches the target class; otherwise only the target class entry is read, and only for non-ignored positions. #}
37
+ {% if log_prob %}
38
+ {{ acc_type }} loss_value = {{ acc_zero_literal }};
39
+ for (idx_t c_idx = 0; c_idx < c; ++c_idx) {
40
+ {{ acc_type }} log_prob_value = ({{ acc_type }})input_flat[base + c_idx * d] - max_value - logsum;
41
+ log_prob_flat[base + c_idx * d] = ({{ c_type }})log_prob_value;
42
+ if (c_idx == class_index) {
43
+ loss_value = -log_prob_value;
44
+ }
45
+ }
46
+ {% if use_ignore_index %}
47
+ if (ignored) {
48
+ {% if reduction == "none" %}
49
+ output_flat[target_index] = {{ zero_literal }};
50
+ {% endif %}
51
+ }
52
+ {% endif %}
53
+ {% else %}
54
+ {{ acc_type }} loss_value = {{ acc_zero_literal }};
55
+ {% if use_ignore_index %}
56
+ if (ignored) {
57
+ {% if reduction == "none" %}
58
+ output_flat[target_index] = {{ zero_literal }};
59
+ {% endif %}
60
+ } else {
61
+ {% endif %}
62
+ {{ acc_type }} log_prob_value = ({{ acc_type }})input_flat[base + class_index * d] - max_value - logsum;
63
+ loss_value = -log_prob_value;
64
+ {% if use_ignore_index %}
65
+ }
66
+ {% endif %}
67
+ {% endif %}
+ {# Weighting, the per-position store and accumulation are all skipped for ignored positions, so the weight table is never indexed with an ignored target value. #}
68
+ {% if use_ignore_index %}
69
+ if (!ignored) {
70
+ {% endif %}
71
+ {% if weight %}
72
+ {{ acc_type }} sample_weight = {{ weight }}[class_index];
73
+ loss_value *= sample_weight;
74
+ {% endif %}
75
+ {% if reduction == "none" %}
76
+ output_flat[target_index] = loss_value;
77
+ {% else %}
78
+ loss_sum += loss_value;
79
+ {% if reduction == "mean" %}
80
+ {% if weight %}
81
+ weight_sum += sample_weight;
82
+ {% elif use_ignore_index %}
83
+ weight_sum += {{ acc_one_literal }};
84
+ {% endif %}
85
+ {% endif %}
86
+ {% endif %}
87
+ {% if use_ignore_index %}
88
+ }
89
+ {% endif %}
90
+ }
91
+ }
+ {# Final reduction: a mean divides by the accumulated weight (writing zero when that weight is zero) or, with neither weights nor ignore_index, by n * d; a sum stores loss_sum directly. #}
92
+ {% if reduction == "mean" %}
93
+ {% if weight or use_ignore_index %}
94
+ if (weight_sum == {{ acc_zero_literal }}) {
95
+ output_flat[0] = {{ zero_literal }};
96
+ } else {
97
+ output_flat[0] = loss_sum / weight_sum;
98
+ }
99
+ {% else %}
100
+ output_flat[0] = loss_sum / ({{ n }} * {{ d }});
101
+ {% endif %}
102
+ {% elif reduction == "sum" %}
103
+ output_flat[0] = loss_sum;
104
+ {% endif %}
105
+ }
@@ -0,0 +1,26 @@
1
+ {# Softmax kernel template: treats the tensor as a flat outer x axis_size x inner array and normalises along the axis_size dimension; element axis_idx of a slice lives at outer_idx * axis_size * inner + axis_idx * inner + inner_idx. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
3
+ {{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
4
+ const idx_t outer = {{ outer }};
5
+ const idx_t axis_size = {{ axis_size }};
6
+ const idx_t inner = {{ inner }};
7
+ for (idx_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
8
+ for (idx_t inner_idx = 0; inner_idx < inner; ++inner_idx) {
9
+ idx_t base = (outer_idx * axis_size * inner) + inner_idx;
+ {# Pass 1: find the slice maximum, which is subtracted before exponentiating (the usual softmax stabilisation). #}
10
+ {{ c_type }} max_value = input_flat[base];
11
+ for (idx_t axis_idx = 1; axis_idx < axis_size; ++axis_idx) {
12
+ {{ c_type }} value = input_flat[base + axis_idx * inner];
13
+ max_value = {{ max_fn }}(max_value, value);
14
+ }
+ {# Pass 2: store the exponentials in the output while accumulating their sum. #}
15
+ {{ c_type }} sum = 0;
16
+ for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
17
+ {{ c_type }} value = {{ exp_fn }}(input_flat[base + axis_idx * inner] - max_value);
18
+ output_flat[base + axis_idx * inner] = value;
19
+ sum += value;
20
+ }
+ {# Pass 3: divide the stored exponentials by the sum, in place. #}
21
+ for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
22
+ output_flat[base + axis_idx * inner] /= sum;
23
+ }
24
+ }
25
+ }
26
+ }
@@ -0,0 +1,22 @@
1
+ {# Space-to-depth kernel template: walks the output in (n, c_out, h_out, w_out) order with a running flat output_index and copies each element from the flat input index ((n * in_channels + c_in) * in_h + h_in) * in_w + w_in. #}
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
3
+ {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
4
+ idx_t output_index = 0;
5
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
6
+ for (idx_t c_out = 0; c_out < {{ out_channels }}; ++c_out) {
+ {# Decompose the output channel: the remainder by in_channels is the source channel; the quotient splits into the row and column offset inside a blocksize x blocksize block. #}
7
+ idx_t c_in = c_out % {{ in_channels }};
8
+ idx_t temp = c_out / {{ in_channels }};
9
+ idx_t offset_h = temp / {{ blocksize }};
10
+ idx_t offset_w = temp % {{ blocksize }};
11
+ for (idx_t h_out = 0; h_out < {{ out_h }}; ++h_out) {
12
+ idx_t h_in = h_out * {{ blocksize }} + offset_h;
13
+ for (idx_t w_out = 0; w_out < {{ out_w }}; ++w_out) {
14
+ idx_t w_in = w_out * {{ blocksize }} + offset_w;
15
+ idx_t input_index = ((n * {{ in_channels }} + c_in) * {{ in_h }} + h_in) * {{ in_w }} + w_in;
16
+ output_data[output_index] = input_data[input_index];
17
+ output_index++;
18
+ }
19
+ }
20
+ }
21
+ }
22
+ }