emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic. See the package's registry page for more details.

Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,10 @@
1
+ {# Elementwise DequantizeLinear kernel: each output element is (input - zero) * scale, evaluated in compute_type and then cast to output_c_type. One nested loop is emitted per entry of shape, using loop_vars as the induction variables. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ {# input_expr, zero_expr and scale_expr are index expressions built by the caller; presumably zero/scale are per-tensor or per-axis - TODO confirm against the lowering. #}{{ compute_type }} dequantized = (({{ compute_type }}){{ input_expr }} - ({{ compute_type }}){{ zero_expr }}) * ({{ compute_type }}){{ scale_expr }};
6
+ {{ output_expr }} = ({{ output_c_type }})dequantized;
7
+ {% for _ in shape %}
8
+ }
9
+ {% endfor %}
10
+ }
@@ -0,0 +1,55 @@
1
+ {# Einsum kernel with one code path per kind: reduce_all sums every input element into a single output; sum_j sums over reduce_loop_var for each value of output_loop_vars[0]; transpose and batch_diagonal copy input_expr to output_expr for every output position; dot computes a single inner product of input0_expr and input1_expr; batch_matmul computes that inner product for every output position. Sums accumulate in acc_type starting from zero_literal. The "for indent in range(...)" loops only emit indentation spaces into the generated C. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% if kind == "reduce_all" %}
3
+ {{ acc_type }} acc = {{ zero_literal }};
4
+ {% for var in input_loop_vars %}
5
+ {% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ var }} = 0; {{ var }} < {{ input_loop_bounds[loop.index0] }}; ++{{ var }}) {
6
+ {% endfor %}
7
+ {% for indent in range(input_loop_vars | length) %} {% endfor %}acc += {{ input_expr }};
8
+ {# Closing braces: the indentation here counts up from 0, so braces are indented in the reverse order of their loops - cosmetic only. #}{% for _ in input_loop_vars | reverse %}
9
+ {% for indent in range(loop.index0) %} {% endfor %}}
10
+ {% endfor %}
11
+ {{ output_expr }} = acc;
12
+ {% elif kind == "sum_j" %}
13
+ for (idx_t {{ output_loop_vars[0] }} = 0; {{ output_loop_vars[0] }} < {{ output_loop_bounds[0] }}; ++{{ output_loop_vars[0] }}) {
14
+ {{ acc_type }} acc = {{ zero_literal }};
15
+ for (idx_t {{ reduce_loop_var }} = 0; {{ reduce_loop_var }} < {{ reduce_loop_bound }}; ++{{ reduce_loop_var }}) {
16
+ acc += {{ input_expr }};
17
+ }
18
+ {{ output_expr }} = acc;
19
+ }
20
+ {% elif kind == "transpose" %}
21
+ {% for var in output_loop_vars %}
22
+ {% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ var }} = 0; {{ var }} < {{ output_loop_bounds[loop.index0] }}; ++{{ var }}) {
23
+ {% endfor %}
24
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ output_expr }} = {{ input_expr }};
25
+ {% for _ in output_loop_vars | reverse %}
26
+ {% for indent in range(loop.index0) %} {% endfor %}}
27
+ {% endfor %}
28
+ {% elif kind == "dot" %}
29
+ {{ acc_type }} acc = {{ zero_literal }};
30
+ for (idx_t {{ reduce_loop_var }} = 0; {{ reduce_loop_var }} < {{ reduce_loop_bound }}; ++{{ reduce_loop_var }}) {
31
+ acc += {{ input0_expr }} * {{ input1_expr }};
32
+ }
33
+ {{ output_expr }} = acc;
34
+ {% elif kind == "batch_matmul" %}
35
+ {% for var in output_loop_vars %}
36
+ {% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ var }} = 0; {{ var }} < {{ output_loop_bounds[loop.index0] }}; ++{{ var }}) {
37
+ {% endfor %}
38
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ acc_type }} acc = {{ zero_literal }};
39
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}for (idx_t {{ reduce_loop_var }} = 0; {{ reduce_loop_var }} < {{ reduce_loop_bound }}; ++{{ reduce_loop_var }}) {
40
+ {% for indent in range(output_loop_vars | length + 1) %} {% endfor %}acc += {{ input0_expr }} * {{ input1_expr }};
41
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}}
42
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ output_expr }} = acc;
43
+ {% for _ in output_loop_vars | reverse %}
44
+ {% for indent in range(loop.index0) %} {% endfor %}}
45
+ {% endfor %}
46
+ {% elif kind == "batch_diagonal" %}
47
+ {% for var in output_loop_vars %}
48
+ {% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ var }} = 0; {{ var }} < {{ output_loop_bounds[loop.index0] }}; ++{{ var }}) {
49
+ {% endfor %}
50
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ output_expr }} = {{ input_expr }};
51
+ {% for _ in output_loop_vars | reverse %}
52
+ {% for indent in range(loop.index0) %} {% endfor %}}
53
+ {% endfor %}
54
+ {% endif %}
55
+ }
@@ -0,0 +1,14 @@
1
+ {# Expand kernel: walks the output in row-major order (output_index advances once per innermost iteration) and copies input_data[input_index_expr]. input_index_expr is built by the caller from loop_vars; presumably it applies broadcasting (size-1 input dimensions map to index 0) - TODO confirm. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
3
+ {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
4
+ idx_t output_index = 0;
5
+ {% for dim in output_shape %}
6
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
7
+ {% endfor %}
8
+ idx_t input_index = {{ input_index_expr }};
9
+ output_data[output_index] = input_data[input_index];
10
+ output_index++;
11
+ {% for _ in output_shape %}
12
+ }
13
+ {% endfor %}
14
+ }
@@ -0,0 +1,27 @@
1
+ {# EyeLike kernel. input0 only supplies the shape, so it is explicitly ignored. The flat output (batch_size x rows x cols) is zero-filled, then one_literal is written on the k-th diagonal of every batch matrix: k >= 0 shifts the diagonal right by k columns, k < 0 shifts it down by -k rows. Assumes idx_t is a signed type; otherwise the k >= 0 test is always true - TODO confirm. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ (void){{ input0 }};
3
+ {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
4
+ idx_t total = (idx_t){{ batch_size }} * {{ rows }} * {{ cols }};
5
+ for (idx_t index = 0; index < total; ++index) {
6
+ output_data[index] = {{ zero_literal }};
7
+ }
8
+ idx_t k = {{ k }};
9
+ idx_t rows = {{ rows }};
10
+ idx_t cols = {{ cols }};
11
+ idx_t row_start = k >= 0 ? 0 : -k;
12
+ idx_t col_start = k >= 0 ? k : 0;
13
+ {# The diagonal lies completely outside the matrix: only the zero fill applies. #}if (row_start >= rows || col_start >= cols) {
14
+ return;
15
+ }
16
+ idx_t max_rows = rows - row_start;
17
+ idx_t max_cols = cols - col_start;
18
+ idx_t diag_len = max_rows < max_cols ? max_rows : max_cols;
19
+ for (idx_t batch = 0; batch < {{ batch_size }}; ++batch) {
20
+ idx_t base = batch * rows * cols;
21
+ for (idx_t diag = 0; diag < diag_len; ++diag) {
22
+ idx_t row = row_start + diag;
23
+ idx_t col = col_start + diag;
24
+ output_data[base + row * cols + col] = {{ one_literal }};
25
+ }
26
+ }
27
+ }
@@ -0,0 +1,13 @@
1
+ {# GatherElements kernel: for each output position, read the index stored at the same position of indices, wrap a negative index by adding axis_dim, and copy the element of data addressed by data_indices. data_indices is built by the caller; presumably it substitutes gather_index on the gather axis - TODO confirm. No bounds check is emitted after the wrap. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in output_shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ idx_t gather_index = {{ indices }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
6
+ if (gather_index < 0) {
7
+ gather_index += {{ axis_dim }};
8
+ }
9
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ data }}{% for idx in data_indices %}[{{ idx }}]{% endfor %};
10
+ {% for _ in output_shape %}
11
+ }
12
+ {% endfor %}
13
+ }
@@ -0,0 +1,29 @@
1
+ {# GatherND kernel: loops over the leading (prefix) dimensions of indices, loads index_depth coordinates from the last indices axis, wraps each negative coordinate by adding data_shape[batch_dims + idx], then copies the remaining tail slice of data. The actual addressing is in output_index_expr / data_index_expr, which the caller builds. No bounds check is emitted after the wrap. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% if indices_prefix_shape %}
3
+ {% for dim in indices_prefix_shape %}
4
+ for (idx_t {{ indices_prefix_loop_vars[loop.index0] }} = 0; {{ indices_prefix_loop_vars[loop.index0] }} < {{ dim }}; ++{{ indices_prefix_loop_vars[loop.index0] }}) {
5
+ {% endfor %}
6
+ {% endif %}
7
+ {% for idx in range(index_depth) %}
8
+ idx_t index{{ idx }} = {{ indices }}{% for var in indices_prefix_loop_vars %}[{{ var }}]{% endfor %}[{{ idx }}];
9
+ if (index{{ idx }} < 0) {
10
+ index{{ idx }} += {{ data_shape[batch_dims + idx] }};
11
+ }
12
+ {% endfor %}
13
+ {% if tail_shape %}
14
+ {% for dim in tail_shape %}
15
+ for (idx_t {{ tail_loop_vars[loop.index0] }} = 0; {{ tail_loop_vars[loop.index0] }} < {{ dim }}; ++{{ tail_loop_vars[loop.index0] }}) {
16
+ {% endfor %}
17
+ {% endif %}
18
+ {{ output_index_expr }} = {{ data_index_expr }};
19
+ {% if tail_shape %}
20
+ {% for _ in tail_shape %}
21
+ }
22
+ {% endfor %}
23
+ {% endif %}
24
+ {% if indices_prefix_shape %}
25
+ {% for _ in indices_prefix_shape %}
26
+ }
27
+ {% endfor %}
28
+ {% endif %}
29
+ }
@@ -0,0 +1,13 @@
1
+ {# Gather kernel: for each output position, read the index addressed by indices_indices (loop variables chosen by the caller), wrap a negative index by adding axis_dim, then copy the element of data addressed by data_indices, which presumably places gather_index on the gather axis - TODO confirm. No bounds check is emitted after the wrap. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in output_shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ idx_t gather_index = {{ indices }}{% for var in indices_indices %}[{{ var }}]{% endfor %};
6
+ if (gather_index < 0) {
7
+ gather_index += {{ axis_dim }};
8
+ }
9
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ data }}{% for idx in data_indices %}[{{ idx }}]{% endfor %};
10
+ {% for _ in output_shape %}
11
+ }
12
+ {% endfor %}
13
+ }
@@ -0,0 +1,35 @@
1
+ {# Gemm kernel: output[i][j] = alpha * sum over k of A'[i][k] * B'[k][j], plus beta * C when input_c is given. A' and B' are A and B read transposed when trans_a / trans_b are set. Products accumulate in acc_type; operands are loaded as c_type. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ for (idx_t i = 0; i < {{ m }}; ++i) {
3
+ for (idx_t j = 0; j < {{ n }}; ++j) {
4
+ {{ acc_type }} acc = {{ zero_literal }};
5
+ for (idx_t k = 0; k < {{ k }}; ++k) {
6
+ {% if trans_a %}
7
+ const {{ c_type }} a_val = {{ input_a }}[k][i];
8
+ {% else %}
9
+ const {{ c_type }} a_val = {{ input_a }}[i][k];
10
+ {% endif %}
11
+ {% if trans_b %}
12
+ const {{ c_type }} b_val = {{ input_b }}[j][k];
13
+ {% else %}
14
+ const {{ c_type }} b_val = {{ input_b }}[k][j];
15
+ {% endif %}
16
+ acc += a_val * b_val;
17
+ }
18
+ {% if input_c %}
19
+ {# C broadcast: any C dimension of size 1 is always read at index 0; any other rank (presumably a scalar C) is read as C[0]. #}{% if c_rank == 2 %}
20
+ idx_t c_i = {% if c_dim0 == 1 %}0{% else %}i{% endif %};
21
+ idx_t c_j = {% if c_dim1 == 1 %}0{% else %}j{% endif %};
22
+ const {{ c_type }} bias = {{ input_c }}[c_i][c_j];
23
+ {% elif c_rank == 1 %}
24
+ idx_t c_j = {% if c_dim1 == 1 %}0{% else %}j{% endif %};
25
+ const {{ c_type }} bias = {{ input_c }}[c_j];
26
+ {% else %}
27
+ const {{ c_type }} bias = {{ input_c }}[0];
28
+ {% endif %}
29
+ {{ output }}[i][j] = acc * {{ alpha_literal }} + bias * {{ beta_literal }};
30
+ {% else %}
31
+ {{ output }}[i][j] = acc * {{ alpha_literal }};
32
+ {% endif %}
33
+ }
34
+ }
35
+ }
@@ -0,0 +1,184 @@
1
+ {# Reflects value back into [x_min, x_max] by mirroring at the borders: n counts the whole range-widths crossed, and the parity of n selects which border the remainder r is measured from. A zero-width range collapses to x_min. #}static inline double {{ op_name }}_reflect(double value, double x_min, double x_max) {
2
+ const double range = x_max - x_min;
3
+ if (range == 0.0) {
4
+ return x_min;
5
+ }
6
+ if (value < x_min) {
7
+ const double dx = x_min - value;
8
+ const int n = (int)(dx / range);
9
+ const double r = dx - (double)n * range;
10
+ return (n % 2 == 0) ? (x_min + r) : (x_max - r);
11
+ }
12
+ if (value > x_max) {
13
+ const double dx = value - x_max;
14
+ const int n = (int)(dx / range);
15
+ const double r = dx - (double)n * range;
16
+ return (n % 2 == 0) ? (x_max - r) : (x_min + r);
17
+ }
18
+ return value;
19
+ }
20
+ {% if mode == "cubic" %}
21
+
22
+ {# Cubic convolution weights (alpha = -0.75) for the four taps at floor(coord) - 1 .. floor(coord) + 2, given the fractional offset x. #}static inline void {{ op_name }}_cubic_coeffs(double x, double coeffs[4]) {
23
+ const double alpha = -0.75;
24
+ const double abs_x = fabs(x);
25
+ const double inv_x = 1.0 - abs_x;
26
+ const double span = 2.0 - abs_x;
27
+ coeffs[0] = ((alpha * (abs_x + 1.0) - 5.0 * alpha) * (abs_x + 1.0) + 8.0 * alpha) * (abs_x + 1.0) - 4.0 * alpha;
28
+ coeffs[1] = ((alpha + 2.0) * abs_x - (alpha + 3.0)) * abs_x * abs_x + 1.0;
29
+ coeffs[2] = ((alpha + 2.0) * inv_x - (alpha + 3.0)) * inv_x * inv_x + 1.0;
30
+ coeffs[3] = ((alpha * span - 5.0 * alpha) * span + 8.0 * alpha) * span - 4.0 * alpha;
31
+ }
32
+ {% endif %}
33
+
34
+ {# GridSample kernel. For every output position the grid supplies normalized coordinates stored last-axis-first (grid component spatial_rank - 1 - dim feeds input dim). They are mapped to input pixel space, optionally clamped or reflected (padding_mode border / reflection), then sampled with nearest, linear or cubic interpolation. With padding_mode zeros, out-of-range taps contribute zero. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
35
+ const int input_spatial[{{ spatial_rank }}] = { {% for dim in input_spatial %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
36
+ {% if padding_mode != "zeros" %}
37
+ const double border_min[{{ spatial_rank }}] = { {% for dim in input_spatial %}{% if align_corners %}0.0{% else %}-0.5{% endif %}{% if not loop.last %}, {% endif %}{% endfor %} };
38
+ const double border_max[{{ spatial_rank }}] = { {% for dim in input_spatial %}{% if align_corners %}{{ dim - 1 }}.0{% else %}{{ dim - 0.5 }}{% endif %}{% if not loop.last %}, {% endif %}{% endfor %} };
39
+ {% endif %}
40
+ {% set n_var = output_loop_vars[0] %}
41
+ {% set c_var = output_loop_vars[1] %}
42
+ {% set spatial_vars = output_loop_vars[2:] %}
43
+ for (idx_t {{ n_var }} = 0; {{ n_var }} < {{ output_shape[0] }}; ++{{ n_var }}) {
44
+ for (idx_t {{ c_var }} = 0; {{ c_var }} < {{ output_shape[1] }}; ++{{ c_var }}) {
45
+ {% for index in range(spatial_rank) %}
46
+ for (idx_t {{ spatial_vars[index] }} = 0; {{ spatial_vars[index] }} < {{ output_spatial[index] }}; ++{{ spatial_vars[index] }}) {
47
+ {% endfor %}
48
+ double coords[{{ spatial_rank }}];
49
+ {% for dim in range(spatial_rank) %}
50
+ const double grid_{{ dim }} = (double){{ grid }}[{{ n_var }}]{% for var in spatial_vars %}[{{ var }}]{% endfor %}[{{ spatial_rank - 1 - dim }}];
51
+ {% if align_corners %}
52
+ {# Nearest mode rounds right after denormalization, before any border/reflection padding, so the padding acts on the integer pixel position. This is the order used by the ONNX reference implementation and onnxruntime; rounding after reflection selects a different pixel. #}coords[{{ dim }}] = {% if mode == "nearest" %}nearbyint({% endif %}(grid_{{ dim }} + 1.0) * (double)(input_spatial[{{ dim }}] - 1) / 2.0{% if mode == "nearest" %}){% endif %};
53
+ {% else %}
54
+ coords[{{ dim }}] = {% if mode == "nearest" %}nearbyint({% endif %}((grid_{{ dim }} + 1.0) * (double)input_spatial[{{ dim }}] - 1.0) / 2.0{% if mode == "nearest" %}){% endif %};
55
+ {% endif %}
56
+ {% endfor %}
57
+ {# Coordinate-level padding: only coordinates outside [border_min, border_max] are clamped (border) or mirrored (reflection). #}{% if padding_mode != "zeros" %}
58
+ {% for dim in range(spatial_rank) %}
59
+ if (coords[{{ dim }}] < border_min[{{ dim }}] || coords[{{ dim }}] > border_max[{{ dim }}]) {
60
+ {% if padding_mode == "border" %}
61
+ if (coords[{{ dim }}] < 0.0) {
62
+ coords[{{ dim }}] = 0.0;
63
+ } else if (coords[{{ dim }}] > (double)(input_spatial[{{ dim }}] - 1)) {
64
+ coords[{{ dim }}] = (double)(input_spatial[{{ dim }}] - 1);
65
+ }
66
+ {% else %}
67
+ coords[{{ dim }}] = {{ op_name }}_reflect(coords[{{ dim }}], border_min[{{ dim }}], border_max[{{ dim }}]);
68
+ {% endif %}
69
+ }
70
+ {% endfor %}
71
+ {% endif %}
72
+ {% if mode == "nearest" %}
73
+ int in_bounds = 1;
74
+ {% for dim in range(spatial_rank) %}
75
+ {# coords are already integral in nearest mode; nearbyint here only guards the int conversion. #}int idx{{ dim }} = (int)nearbyint(coords[{{ dim }}]);
76
+ {% if padding_mode == "zeros" %}
77
+ if (idx{{ dim }} < 0 || idx{{ dim }} >= input_spatial[{{ dim }}]) {
78
+ in_bounds = 0;
79
+ }
80
+ {% elif padding_mode == "border" %}
81
+ if (idx{{ dim }} < 0) {
82
+ idx{{ dim }} = 0;
83
+ } else if (idx{{ dim }} >= input_spatial[{{ dim }}]) {
84
+ idx{{ dim }} = input_spatial[{{ dim }}] - 1;
85
+ }
86
+ {% else %}
87
+ idx{{ dim }} = (int){{ op_name }}_reflect((double)idx{{ dim }}, border_min[{{ dim }}], border_max[{{ dim }}]);
88
+ {% endif %}
89
+ {% endfor %}
90
+ if (!in_bounds) {
91
+ {{ output }}[{{ n_var }}][{{ c_var }}]{% for var in spatial_vars %}[{{ var }}]{% endfor %} = ({{ c_type }})0;
92
+ } else {
93
+ {{ output }}[{{ n_var }}][{{ c_var }}]{% for var in spatial_vars %}[{{ var }}]{% endfor %} = {{ input0 }}[{{ n_var }}][{{ c_var }}]{% for dim in range(spatial_rank) %}[idx{{ dim }}]{% endfor %};
94
+ }
95
+ {# Linear mode: weighted sum over the 2^spatial_rank corner taps listed in linear_offsets (0 = floor, 1 = floor + 1). #}{% elif mode == "linear" %}
96
+ int base[{{ spatial_rank }}];
97
+ int upper[{{ spatial_rank }}];
98
+ double w0[{{ spatial_rank }}];
99
+ double w1[{{ spatial_rank }}];
100
+ {% for dim in range(spatial_rank) %}
101
+ base[{{ dim }}] = (int)floor(coords[{{ dim }}]);
102
+ upper[{{ dim }}] = base[{{ dim }}] + 1;
103
+ w1[{{ dim }}] = coords[{{ dim }}] - (double)base[{{ dim }}];
104
+ w0[{{ dim }}] = 1.0 - w1[{{ dim }}];
105
+ {% endfor %}
106
+ double acc = 0.0;
107
+ {% for offsets in linear_offsets %}
108
+ {
109
+ double weight = 1.0;
110
+ {% for dim in range(spatial_rank) %}
111
+ int idx{{ dim }} = {% if offsets[dim] == 0 %}base[{{ dim }}]{% else %}upper[{{ dim }}]{% endif %};
112
+ weight *= {% if offsets[dim] == 0 %}w0[{{ dim }}]{% else %}w1[{{ dim }}]{% endif %};
113
+ {% endfor %}
114
+ int in_bounds = 1;
115
+ {% for dim in range(spatial_rank) %}
116
+ {% if padding_mode == "zeros" %}
117
+ if (idx{{ dim }} < 0 || idx{{ dim }} >= input_spatial[{{ dim }}]) {
118
+ in_bounds = 0;
119
+ }
120
+ {% elif padding_mode == "border" %}
121
+ if (idx{{ dim }} < 0) {
122
+ idx{{ dim }} = 0;
123
+ } else if (idx{{ dim }} >= input_spatial[{{ dim }}]) {
124
+ idx{{ dim }} = input_spatial[{{ dim }}] - 1;
125
+ }
126
+ {% else %}
127
+ idx{{ dim }} = (int){{ op_name }}_reflect((double)idx{{ dim }}, border_min[{{ dim }}], border_max[{{ dim }}]);
128
+ {% endif %}
129
+ {% endfor %}
130
+ if (in_bounds) {
131
+ acc += weight * (double){{ input0 }}[{{ n_var }}][{{ c_var }}]{% for dim in range(spatial_rank) %}[idx{{ dim }}]{% endfor %};
132
+ }
133
+ }
134
+ {% endfor %}
135
+ {{ output }}[{{ n_var }}][{{ c_var }}]{% for var in spatial_vars %}[{{ var }}]{% endfor %} = ({{ c_type }})acc;
136
+ {# Cubic mode: weighted sum over the 4^spatial_rank taps listed in cubic_offsets, each tap padded individually. #}{% else %}
137
+ int base[{{ spatial_rank }}];
138
+ int idxs[{{ spatial_rank }}][4];
139
+ double coeffs[{{ spatial_rank }}][4];
140
+ {% for dim in range(spatial_rank) %}
141
+ base[{{ dim }}] = (int)floor(coords[{{ dim }}]);
142
+ idxs[{{ dim }}][0] = base[{{ dim }}] - 1;
143
+ idxs[{{ dim }}][1] = base[{{ dim }}];
144
+ idxs[{{ dim }}][2] = base[{{ dim }}] + 1;
145
+ idxs[{{ dim }}][3] = base[{{ dim }}] + 2;
146
+ {{ op_name }}_cubic_coeffs(coords[{{ dim }}] - (double)base[{{ dim }}], coeffs[{{ dim }}]);
147
+ {% endfor %}
148
+ double acc = 0.0;
149
+ {% for offsets in cubic_offsets %}
150
+ {
151
+ double weight = 1.0;
152
+ {% for dim in range(spatial_rank) %}
153
+ int idx{{ dim }} = idxs[{{ dim }}][{{ offsets[dim] }}];
154
+ weight *= coeffs[{{ dim }}][{{ offsets[dim] }}];
155
+ {% endfor %}
156
+ int in_bounds = 1;
157
+ {% for dim in range(spatial_rank) %}
158
+ {% if padding_mode == "zeros" %}
159
+ if (idx{{ dim }} < 0 || idx{{ dim }} >= input_spatial[{{ dim }}]) {
160
+ in_bounds = 0;
161
+ }
162
+ {% elif padding_mode == "border" %}
163
+ if (idx{{ dim }} < 0) {
164
+ idx{{ dim }} = 0;
165
+ } else if (idx{{ dim }} >= input_spatial[{{ dim }}]) {
166
+ idx{{ dim }} = input_spatial[{{ dim }}] - 1;
167
+ }
168
+ {% else %}
169
+ idx{{ dim }} = (int){{ op_name }}_reflect((double)idx{{ dim }}, border_min[{{ dim }}], border_max[{{ dim }}]);
170
+ {% endif %}
171
+ {% endfor %}
172
+ if (in_bounds) {
173
+ acc += weight * (double){{ input0 }}[{{ n_var }}][{{ c_var }}]{% for dim in range(spatial_rank) %}[idx{{ dim }}]{% endfor %};
174
+ }
175
+ }
176
+ {% endfor %}
177
+ {{ output }}[{{ n_var }}][{{ c_var }}]{% for var in spatial_vars %}[{{ var }}]{% endfor %} = ({{ c_type }})acc;
178
+ {% endif %}
179
+ {% for index in range(spatial_rank) %}
180
+ }
181
+ {% endfor %}
182
+ }
183
+ }
184
+ }
@@ -0,0 +1,46 @@
1
+ {# GroupNormalization kernel. For each batch entry (loop_vars[0]) and group g, channels g*group_size through g*group_size + group_size - 1 are normalized together. Mean and population variance are accumulated in c_type over those channels and all spatial positions. The output is (x - mean) / sqrt(var + epsilon) * scale[c] + bias[c], with scale and bias indexed per channel. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in shape[:1] %}
3
+ for (idx_t {{ loop_vars[0] }} = 0; {{ loop_vars[0] }} < {{ dim }}; ++{{ loop_vars[0] }}) {
4
+ {% endfor %}
5
+ for (idx_t g = 0; g < {{ num_groups }}; ++g) {
6
+ {{ c_type }} sum = {{ zero_literal }};
7
+ for (idx_t c_in_group = 0; c_in_group < {{ group_size }}; ++c_in_group) {
8
+ idx_t c = g * {{ group_size }} + c_in_group;
9
+ {% for dim in shape[2:] %}
10
+ for (idx_t {{ loop_vars[loop.index0 + 2] }} = 0; {{ loop_vars[loop.index0 + 2] }} < {{ dim }}; ++{{ loop_vars[loop.index0 + 2] }}) {
11
+ {% endfor %}
12
+ sum += {{ input0 }}[{{ loop_vars[0] }}][c]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %};
13
+ {% for _ in shape[2:] %}
14
+ }
15
+ {% endfor %}
16
+ }
17
+ {# Divide by the element count of the group: group_size channels times spatial_size positions (spatial_size is presumably the product of shape[2:] - TODO confirm). #}{{ c_type }} mean = sum / ({{ group_size }} * {{ spatial_size }});
18
+ {{ c_type }} var = {{ zero_literal }};
19
+ for (idx_t c_in_group = 0; c_in_group < {{ group_size }}; ++c_in_group) {
20
+ idx_t c = g * {{ group_size }} + c_in_group;
21
+ {% for dim in shape[2:] %}
22
+ for (idx_t {{ loop_vars[loop.index0 + 2] }} = 0; {{ loop_vars[loop.index0 + 2] }} < {{ dim }}; ++{{ loop_vars[loop.index0 + 2] }}) {
23
+ {% endfor %}
24
+ {{ c_type }} diff = {{ input0 }}[{{ loop_vars[0] }}][c]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %} - mean;
25
+ var += diff * diff;
26
+ {% for _ in shape[2:] %}
27
+ }
28
+ {% endfor %}
29
+ }
30
+ {{ c_type }} denom = {{ sqrt_fn }}(var / ({{ group_size }} * {{ spatial_size }}) + {{ epsilon_literal }});
31
+ for (idx_t c_in_group = 0; c_in_group < {{ group_size }}; ++c_in_group) {
32
+ idx_t c = g * {{ group_size }} + c_in_group;
33
+ {% for dim in shape[2:] %}
34
+ for (idx_t {{ loop_vars[loop.index0 + 2] }} = 0; {{ loop_vars[loop.index0 + 2] }} < {{ dim }}; ++{{ loop_vars[loop.index0 + 2] }}) {
35
+ {% endfor %}
36
+ {{ output }}[{{ loop_vars[0] }}][c]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %} =
37
+ ({{ input0 }}[{{ loop_vars[0] }}][c]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %} - mean) / denom * {{ scale }}[c] + {{ bias }}[c];
38
+ {% for _ in shape[2:] %}
39
+ }
40
+ {% endfor %}
41
+ }
42
+ }
43
+ {% for _ in shape[:1] %}
44
+ }
45
+ {% endfor %}
46
+ }
@@ -0,0 +1,152 @@
1
+ {# GRU kernel. Each direction is emitted as its own scope. H_prev starts from initial_h (or zero). Then, for every batch entry, the first seq_limit time steps are processed in forward or reverse order; seq_limit comes from sequence_lens when given, clamped to [0, seq_length]. The z and r gates use act_f and the candidate uses act_g; H_new = (1 - z) * candidate + z * H_prev. Y is zero-filled past seq_limit, and Y_h receives the final hidden state. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dir in range(num_directions) %}
3
+ {
4
+ const int dir = {{ dir }};
5
+ {# Direction 1 of a bidirectional GRU, and every direction of a reverse GRU, walks time backwards. #}const int reverse = {% if direction == "reverse" %}1{% elif direction == "bidirectional" and dir == 1 %}1{% else %}0{% endif %};
6
+ {% set act_f = activation_functions[dir * 2] %}
7
+ {% set act_g = activation_functions[dir * 2 + 1] %}
8
+ {{ c_type }} H_prev[{{ batch_size }}][{{ hidden_size }}];
9
+ for (int b = 0; b < {{ batch_size }}; ++b) {
10
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
11
+ {% if input_initial_h %}
12
+ {% if layout == 0 %}
13
+ H_prev[b][h] = {{ input_initial_h }}[dir][b][h];
14
+ {% else %}
15
+ H_prev[b][h] = {{ input_initial_h }}[b][dir][h];
16
+ {% endif %}
17
+ {% else %}
18
+ H_prev[b][h] = {{ zero_literal }};
19
+ {% endif %}
20
+ }
21
+ }
22
+ for (int b = 0; b < {{ batch_size }}; ++b) {
23
+ int seq_limit = {{ seq_length }};
24
+ {% if input_sequence_lens %}
25
+ seq_limit = (int){{ input_sequence_lens }}[b];
26
+ if (seq_limit < 0) {
27
+ seq_limit = 0;
28
+ }
29
+ if (seq_limit > {{ seq_length }}) {
30
+ seq_limit = {{ seq_length }};
31
+ }
32
+ {% endif %}
33
+ for (int step = 0; step < seq_limit; ++step) {
34
+ {# t is the real time index. step only counts iterations, so in reverse direction t runs from seq_limit - 1 down to 0. #}int t = reverse ? (seq_limit - 1 - step) : step;
35
+ {{ c_type }} Z_gate[{{ hidden_size }}];
36
+ {{ c_type }} R_gate[{{ hidden_size }}];
37
+ {{ c_type }} H_next[{{ hidden_size }}];
38
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
39
+ {{ c_type }} gate_z = {{ zero_literal }};
40
+ {{ c_type }} gate_r = {{ zero_literal }};
41
+ for (int i = 0; i < {{ input_size }}; ++i) {
42
+ {% if layout == 0 %}
43
+ {{ c_type }} x_val = {{ input_x }}[t][b][i];
44
+ {% else %}
45
+ {{ c_type }} x_val = {{ input_x }}[b][t][i];
46
+ {% endif %}
47
+ gate_z += x_val * {{ input_w }}[dir][h][i];
48
+ gate_r += x_val * {{ input_w }}[dir][{{ hidden_size }} + h][i];
49
+ }
50
+ for (int i = 0; i < {{ hidden_size }}; ++i) {
51
+ {{ c_type }} h_val = H_prev[b][i];
52
+ gate_z += h_val * {{ input_r }}[dir][h][i];
53
+ gate_r += h_val * {{ input_r }}[dir][{{ hidden_size }} + h][i];
54
+ }
55
+ {% if input_b %}
56
+ {# Bias offsets used here: z at 0 and 3*H, r at H and 4*H, candidate at 2*H and 5*H (presumably the ONNX [Wb_zrh, Rb_zrh] layout). #}gate_z += {{ input_b }}[dir][h] + {{ input_b }}[dir][{{ hidden_size }} * 3 + h];
57
+ gate_r += {{ input_b }}[dir][{{ hidden_size }} + h] + {{ input_b }}[dir][{{ hidden_size }} * 4 + h];
58
+ {% endif %}
59
+ {% if use_clip %}
60
+ if (gate_z > {{ clip_literal }}) {
61
+ gate_z = {{ clip_literal }};
62
+ } else if (gate_z < -{{ clip_literal }}) {
63
+ gate_z = -{{ clip_literal }};
64
+ }
65
+ if (gate_r > {{ clip_literal }}) {
66
+ gate_r = {{ clip_literal }};
67
+ } else if (gate_r < -{{ clip_literal }}) {
68
+ gate_r = -{{ clip_literal }};
69
+ }
70
+ {% endif %}
71
+ Z_gate[h] = {{ act_f }}(gate_z);
72
+ R_gate[h] = {{ act_f }}(gate_r);
73
+ }
74
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
75
+ {{ c_type }} gate_h = {{ zero_literal }};
76
+ for (int i = 0; i < {{ input_size }}; ++i) {
77
+ {% if layout == 0 %}
78
+ {{ c_type }} x_val = {{ input_x }}[t][b][i];
79
+ {% else %}
80
+ {{ c_type }} x_val = {{ input_x }}[b][t][i];
81
+ {% endif %}
82
+ gate_h += x_val * {{ input_w }}[dir][{{ hidden_size }} * 2 + h][i];
83
+ }
84
+ {# linear_before_reset: the reset gate scales (H_prev * R_h + Rb_h). Otherwise the reset gate scales H_prev before the recurrent product. #}{% if linear_before_reset %}
85
+ {{ c_type }} recur = {{ zero_literal }};
86
+ for (int i = 0; i < {{ hidden_size }}; ++i) {
87
+ recur += H_prev[b][i] * {{ input_r }}[dir][{{ hidden_size }} * 2 + h][i];
88
+ }
89
+ {% if input_b %}
90
+ recur += {{ input_b }}[dir][{{ hidden_size }} * 5 + h];
91
+ {% endif %}
92
+ gate_h += R_gate[h] * recur;
93
+ {% if input_b %}
94
+ gate_h += {{ input_b }}[dir][{{ hidden_size }} * 2 + h];
95
+ {% endif %}
96
+ {% else %}
97
+ for (int i = 0; i < {{ hidden_size }}; ++i) {
98
+ gate_h += (R_gate[i] * H_prev[b][i]) * {{ input_r }}[dir][{{ hidden_size }} * 2 + h][i];
99
+ }
100
+ {% if input_b %}
101
+ gate_h += {{ input_b }}[dir][{{ hidden_size }} * 2 + h] + {{ input_b }}[dir][{{ hidden_size }} * 5 + h];
102
+ {% endif %}
103
+ {% endif %}
104
+ {% if use_clip %}
105
+ if (gate_h > {{ clip_literal }}) {
106
+ gate_h = {{ clip_literal }};
107
+ } else if (gate_h < -{{ clip_literal }}) {
108
+ gate_h = -{{ clip_literal }};
109
+ }
110
+ {% endif %}
111
+ {{ c_type }} h_candidate = {{ act_g }}(gate_h);
112
+ {{ c_type }} h_prev = H_prev[b][h];
113
+ {{ c_type }} h_new = ({{ one_literal }} - Z_gate[h]) * h_candidate + Z_gate[h] * h_prev;
114
+ H_next[h] = h_new;
115
+ {% if output_y %}
116
+ {% if layout == 0 %}
117
+ {# Y is indexed by the real time index t, not the loop counter step, so reverse-direction outputs land at their original sequence positions. For forward direction t == step, so nothing changes there. #}{{ output_y }}[t][dir][b][h] = h_new;
118
+ {% else %}
119
+ {{ output_y }}[b][t][dir][h] = h_new;
120
+ {% endif %}
121
+ {% endif %}
122
+ }
123
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
124
+ H_prev[b][h] = H_next[h];
125
+ }
126
+ }
127
+ {% if output_y %}
128
+ {# Time positions seq_limit through seq_length - 1 are never computed in either direction; zero their Y entries. #}for (int step = seq_limit; step < {{ seq_length }}; ++step) {
129
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
130
+ {% if layout == 0 %}
131
+ {{ output_y }}[step][dir][b][h] = {{ zero_literal }};
132
+ {% else %}
133
+ {{ output_y }}[b][step][dir][h] = {{ zero_literal }};
134
+ {% endif %}
135
+ }
136
+ }
137
+ {% endif %}
138
+ }
139
+ {% if output_y_h %}
140
+ for (int b = 0; b < {{ batch_size }}; ++b) {
141
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
142
+ {% if layout == 0 %}
143
+ {{ output_y_h }}[dir][b][h] = H_prev[b][h];
144
+ {% else %}
145
+ {{ output_y_h }}[b][dir][h] = H_prev[b][h];
146
+ {% endif %}
147
+ }
148
+ }
149
+ {% endif %}
150
+ }
151
+ {% endfor %}
152
+ }
@@ -0,0 +1,12 @@
1
+ {# HammingWindow kernel: value[idx] = alpha - beta * cos(2 * pi * idx / denom), with alpha = 25/46 and beta = 1 - alpha. denom is the runtime size for a periodic window and size - 1 otherwise. The number of emitted values is the compile-time length. Note: a symmetric window of size 1 gives denom == 0, so idx 0 evaluates 0.0 / 0.0. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const double size_value = (double){{ size }}[0];
3
+ const double denom = {{ periodic_literal }} ? size_value : (size_value - 1.0);
4
+ const double alpha = 25.0 / 46.0;
5
+ const double beta = 1.0 - alpha;
6
+ const double pi = 3.14159265358979323846;
7
+ for (idx_t idx = 0; idx < {{ length }}; ++idx) {
8
+ const double phase = (2.0 * pi * (double)idx) / denom;
9
+ const double value = alpha - beta * cos(phase);
10
+ {{ output }}[idx] = ({{ c_type }})value;
11
+ }
12
+ }
@@ -0,0 +1,24 @@
1
+ {# Hardmax kernel over a flat outer x axis_size x inner view. For each (outer, inner) pair, find the position of the maximum along the axis, write one_literal there and zero_literal everywhere else. input_flat[base] is read unconditionally, so axis_size must be at least 1. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
3
+ {{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
4
+ const idx_t outer = {{ outer }};
5
+ const idx_t axis_size = {{ axis_size }};
6
+ const idx_t inner = {{ inner }};
7
+ for (idx_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
8
+ for (idx_t inner_idx = 0; inner_idx < inner; ++inner_idx) {
9
+ idx_t base = (outer_idx * axis_size * inner) + inner_idx;
10
+ {{ c_type }} max_value = input_flat[base];
11
+ idx_t max_index = 0;
12
+ for (idx_t axis_idx = 1; axis_idx < axis_size; ++axis_idx) {
13
+ {{ c_type }} value = input_flat[base + axis_idx * inner];
14
+ const {{ c_type }} prev_max = max_value;
15
+ max_value = {{ max_fn }}(max_value, value);
16
+ {# Strict comparison: on ties the earliest index along the axis is kept. #}max_index = (value > prev_max) ? axis_idx : max_index;
17
+ }
18
+ for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
19
+ output_flat[base + axis_idx * inner] =
20
+ axis_idx == max_index ? {{ one_literal }} : {{ zero_literal }};
21
+ }
22
+ }
23
+ }
24
+ }
@@ -0,0 +1,9 @@
1
+ {# Identity kernel: copies input0 to output element by element, with one nested loop per entry of shape. #}static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
6
+ {% for _ in shape %}
7
+ }
8
+ {% endfor %}
9
+ }