emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
{# InstanceNormalization: for each (batch, channel) pair, normalize the spatial
   elements to zero mean / unit variance, then apply per-channel scale and bias.
   shape[:2] are the batch/channel dims; shape[2:] are the spatial dims.
   Uses population variance (divide by spatial_size), matching ONNX. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{# Outer loops over batch (loop_vars[0]) and channel (loop_vars[1]). #}
{% for dim in shape[:2] %}
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
{% endfor %}
{# Pass 1: sum over the spatial dims to get the mean. #}
{{ c_type }} sum = {{ zero_literal }};
{% for dim in shape[2:] %}
for (idx_t {{ loop_vars[loop.index0 + 2] }} = 0; {{ loop_vars[loop.index0 + 2] }} < {{ dim }}; ++{{ loop_vars[loop.index0 + 2] }}) {
{% endfor %}
sum += {{ input0 }}[{{ loop_vars[0] }}][{{ loop_vars[1] }}]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %};
{% for _ in shape[2:] %}
}
{% endfor %}
{{ c_type }} mean = sum / {{ spatial_size }};
{# Pass 2: accumulate squared deviations for the variance. #}
{{ c_type }} var = {{ zero_literal }};
{% for dim in shape[2:] %}
for (idx_t {{ loop_vars[loop.index0 + 2] }} = 0; {{ loop_vars[loop.index0 + 2] }} < {{ dim }}; ++{{ loop_vars[loop.index0 + 2] }}) {
{% endfor %}
{{ c_type }} diff = {{ input0 }}[{{ loop_vars[0] }}][{{ loop_vars[1] }}]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %} - mean;
var += diff * diff;
{% for _ in shape[2:] %}
}
{% endfor %}
{# denom = sqrt(var + epsilon); epsilon keeps the division well-defined. #}
{{ c_type }} denom = {{ sqrt_fn }}(var / {{ spatial_size }} + {{ epsilon_literal }});
{# Pass 3: write (x - mean) / denom * scale[c] + bias[c]. #}
{% for dim in shape[2:] %}
for (idx_t {{ loop_vars[loop.index0 + 2] }} = 0; {{ loop_vars[loop.index0 + 2] }} < {{ dim }}; ++{{ loop_vars[loop.index0 + 2] }}) {
{% endfor %}
{{ output }}[{{ loop_vars[0] }}][{{ loop_vars[1] }}]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %} =
({{ input0 }}[{{ loop_vars[0] }}][{{ loop_vars[1] }}]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %} - mean) / denom * {{ scale }}[{{ loop_vars[1] }}] + {{ bias }}[{{ loop_vars[1] }}];
{% for _ in shape[2:] %}
}
{% endfor %}
{% for _ in shape[:2] %}
}
{% endfor %}
}
{# LayerNormalization: normalize over the trailing `norm_shape` dims for every
   index combination of the leading `prefix_shape` dims. All accumulation is
   done in acc_type (may be wider than the tensor type). When use_kahan is set,
   both the mean and variance sums use Kahan compensated summation for
   deterministic, low-error reduction. Optionally emits the per-row mean and
   inverse-stddev tensors (mean_output / invstd_output). #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dim in prefix_shape %}
for (idx_t {{ prefix_loop_vars[loop.index0] }} = 0; {{ prefix_loop_vars[loop.index0] }} < {{ dim }}; ++{{ prefix_loop_vars[loop.index0] }}) {
{% endfor %}
{# Pass 1: sum of the normalized axis elements. #}
{{ acc_type }} sum = {{ acc_zero_literal }};
{% if use_kahan %}
{{ acc_type }} sum_comp = {{ acc_zero_literal }};
{% endif %}
{% for dim in norm_shape %}
for (idx_t {{ norm_loop_vars[loop.index0] }} = 0; {{ norm_loop_vars[loop.index0] }} < {{ dim }}; ++{{ norm_loop_vars[loop.index0] }}) {
{% endfor %}
{% if use_kahan %}
{# Kahan step: carry the low-order bits lost by sum += value in sum_comp. #}
{{ acc_type }} kahan_value = ({{ acc_type }}){{ input0 }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %};
{{ acc_type }} kahan_y = kahan_value - sum_comp;
{{ acc_type }} kahan_t = sum + kahan_y;
sum_comp = (kahan_t - sum) - kahan_y;
sum = kahan_t;
{% else %}
sum += ({{ acc_type }}){{ input0 }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %};
{% endif %}
{% for _ in norm_shape %}
}
{% endfor %}
{# inner = product of norm_shape dims (element count per normalized row). #}
{{ acc_type }} mean = sum / {{ inner }};
{# Pass 2: sum of squared deviations (population variance). #}
{{ acc_type }} var = {{ acc_zero_literal }};
{% if use_kahan %}
{{ acc_type }} var_comp = {{ acc_zero_literal }};
{% endif %}
{% for dim in norm_shape %}
for (idx_t {{ norm_loop_vars[loop.index0] }} = 0; {{ norm_loop_vars[loop.index0] }} < {{ dim }}; ++{{ norm_loop_vars[loop.index0] }}) {
{% endfor %}
{{ acc_type }} diff = ({{ acc_type }}){{ input0 }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %} - mean;
{% if use_kahan %}
{{ acc_type }} kahan_value = diff * diff;
{{ acc_type }} kahan_y = kahan_value - var_comp;
{{ acc_type }} kahan_t = var + kahan_y;
var_comp = (kahan_t - var) - kahan_y;
var = kahan_t;
{% else %}
var += diff * diff;
{% endif %}
{% for _ in norm_shape %}
}
{% endfor %}
var = var / {{ inner }};
{{ acc_type }} inv_std = {{ acc_one_literal }} / {{ acc_sqrt_fn }}(var + {{ acc_epsilon_literal }});
{# Optional auxiliary outputs (ONNX Mean / InvStdDev), indexed by the prefix dims. #}
{% if mean_output %}
{{ mean_output }}{% for var in mean_index_vars %}[{{ var }}]{% endfor %} = mean;
{% endif %}
{% if invstd_output %}
{{ invstd_output }}{% for var in mean_index_vars %}[{{ var }}]{% endfor %} = inv_std;
{% endif %}
{# Pass 3: normalized value, scaled and (optionally) biased. #}
{% for dim in norm_shape %}
for (idx_t {{ norm_loop_vars[loop.index0] }} = 0; {{ norm_loop_vars[loop.index0] }} < {{ dim }}; ++{{ norm_loop_vars[loop.index0] }}) {
{% endfor %}
{{ acc_type }} value = (({{ acc_type }}){{ input0 }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %} - mean) * inv_std;
value = value * {{ scale }}{% for var in scale_index_vars %}[{{ var }}]{% endfor %}{% if bias %} + {{ bias }}{% for var in bias_index_vars %}[{{ var }}]{% endfor %}{% endif %};
{{ output }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %} = value;
{% for _ in norm_shape %}
}
{% endfor %}
{% for _ in prefix_shape %}
}
{% endfor %}
}
{# LogSoftmax along one axis of a flattened (outer, axis_size, inner) view.
   Uses the numerically stable max-shifted form:
   out = (x - max) - log(sum(exp(x - max))). #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{# Flat aliasing of the (possibly multi-dim) arrays; layout is row-major. #}
const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
{{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
const idx_t outer = {{ outer }};
const idx_t axis_size = {{ axis_size }};
const idx_t inner = {{ inner }};
for (idx_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
for (idx_t inner_idx = 0; inner_idx < inner; ++inner_idx) {
{# base + axis_idx * inner walks the softmax axis for this (outer, inner) slot. #}
idx_t base = (outer_idx * axis_size * inner) + inner_idx;
{# Pass 1: running maximum for the shift (prevents exp overflow). #}
{{ c_type }} max_value = input_flat[base];
for (idx_t axis_idx = 1; axis_idx < axis_size; ++axis_idx) {
{{ c_type }} value = input_flat[base + axis_idx * inner];
max_value = {{ max_fn }}(max_value, value);
}
{# Pass 2: sum of shifted exponentials. #}
{{ c_type }} sum = 0;
for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
{{ c_type }} value = {{ exp_fn }}(input_flat[base + axis_idx * inner] - max_value);
sum += value;
}
{{ c_type }} logsum = {{ log_fn }}(sum);
{# Pass 3: emit shifted input minus log of the sum. #}
for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
{{ c_type }} value = input_flat[base + axis_idx * inner] - max_value;
output_flat[base + axis_idx * inner] = value - logsum;
}
}
}
}
{# LpNormalization (p = 1 or 2) along one axis of a flattened
   (outer, axis_size, inner) view: divide each element by the Lp norm of its
   axis slice. The p==1 / p==2 branches are resolved at template-render time. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
{{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
const idx_t outer = {{ outer }};
const idx_t axis_size = {{ axis_size }};
const idx_t inner = {{ inner }};
for (idx_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
for (idx_t inner_idx = 0; inner_idx < inner; ++inner_idx) {
idx_t base = (outer_idx * axis_size * inner) + inner_idx;
{# Accumulate |x| for L1 or x*x for L2 over the axis. #}
{{ c_type }} acc = {{ zero_literal }};
for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
{{ c_type }} value = input_flat[base + axis_idx * inner];
{% if p == 1 %}
acc += {{ abs_fn }}(value);
{% else %}
acc += value * value;
{% endif %}
}
{% if p == 2 %}
acc = {{ sqrt_fn }}(acc);
{% endif %}
{# NOTE(review): an all-zero slice gives acc == 0 and a division by zero
   (inf/nan in float) — ONNX leaves this case unspecified; confirm the
   verifier tolerates it. #}
for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
output_flat[base + axis_idx * inner] = input_flat[base + axis_idx * inner] / acc;
}
}
}
}
{# LpPool (2-D, NCHW): out = (sum over kernel of |x|^p) ^ (1/p).
   Out-of-bounds taps (from padding) contribute nothing to the sum. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
for (idx_t n = 0; n < {{ batch }}; ++n) {
for (idx_t c = 0; c < {{ channels }}; ++c) {
for (idx_t oh = 0; oh < {{ out_h }}; ++oh) {
for (idx_t ow = 0; ow < {{ out_w }}; ++ow) {
{{ c_type }} acc = {{ zero_literal }};
{# Top-left corner of the receptive field; may be negative inside padding. #}
const idx_t h_start = oh * {{ stride_h }} - {{ pad_top }};
const idx_t w_start = ow * {{ stride_w }} - {{ pad_left }};
for (idx_t kh = 0; kh < {{ kernel_h }}; ++kh) {
for (idx_t kw = 0; kw < {{ kernel_w }}; ++kw) {
const idx_t in_h = h_start + kh * {{ dilation_h }};
const idx_t in_w = w_start + kw * {{ dilation_w }};
{# Skip taps that fall in the padded border. #}
if (in_h >= 0 && in_h < {{ in_h }} && in_w >= 0 && in_w < {{ in_w }}) {
{{ c_type }} value = {{ input0 }}[n][c][in_h][in_w];
acc += {{ pow_fn }}({{ abs_fn }}(value), ({{ c_type }}){{ p }});
}
}
}
{# Final p-th root of the accumulated powers. #}
{{ output }}[n][c][oh][ow] = {{ pow_fn }}(acc, ({{ c_type }})1.0 / ({{ c_type }}){{ p }});
}
}
}
}
}
{# LRN (local response normalization) across channels:
   out = x / (bias + alpha/size * sum(x_c^2 over channel window)) ^ beta.
   loop_vars[0] is batch, loop_vars[1] is channel, loop_vars[2:] are spatial. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dim in shape %}
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
{% endfor %}
{# Clamp the [c - half, c + half] window to [0, channels - 1].
   The `>` guard avoids negative values if idx_t is unsigned.
   NOTE(review): the same `half` is used on both sides; for even `size` ONNX
   uses floor((size-1)/2) before and ceil((size-1)/2) after — confirm the
   lowering derives `half` to match. #}
idx_t channel_start = {{ loop_vars[1] }} > {{ half }} ? {{ loop_vars[1] }} - {{ half }} : 0;
idx_t channel_end = {{ loop_vars[1] }} + {{ half }};
if (channel_end >= {{ channels }}) {
channel_end = {{ channels }} - 1;
}
{# Sum of squares over the channel window at this spatial position. #}
{{ c_type }} sum = {{ zero_literal }};
for (idx_t c = channel_start; c <= channel_end; ++c) {
{{ c_type }} val = {{ input0 }}[{{ loop_vars[0] }}][c]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %};
sum += val * val;
}
{# alpha_div_size_literal is alpha/size, precomputed by the lowering. #}
{{ c_type }} scale = {{ bias_literal }} + {{ alpha_div_size_literal }} * sum;
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %} / {{ pow_fn }}(scale, {{ beta_literal }});
{% for _ in shape %}
}
{% endfor %}
}
@@ -0,0 +1,175 @@
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dir in range(num_directions) %}
3
+ {
4
+ const int dir = {{ dir }};
5
+ const int reverse = {% if direction == "reverse" %}1{% elif direction == "bidirectional" and dir == 1 %}1{% else %}0{% endif %};
6
+ {% set act_f = activation_functions[dir * 3] %}
7
+ {% set act_g = activation_functions[dir * 3 + 1] %}
8
+ {% set act_h = activation_functions[dir * 3 + 2] %}
9
+ {{ c_type }} H_prev[{{ batch_size }}][{{ hidden_size }}];
10
+ {{ c_type }} C_prev[{{ batch_size }}][{{ hidden_size }}];
11
+ for (int b = 0; b < {{ batch_size }}; ++b) {
12
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
13
+ {% if input_initial_h %}
14
+ {% if layout == 0 %}
15
+ H_prev[b][h] = {{ input_initial_h }}[dir][b][h];
16
+ {% else %}
17
+ H_prev[b][h] = {{ input_initial_h }}[b][dir][h];
18
+ {% endif %}
19
+ {% else %}
20
+ H_prev[b][h] = {{ zero_literal }};
21
+ {% endif %}
22
+ {% if input_initial_c %}
23
+ {% if layout == 0 %}
24
+ C_prev[b][h] = {{ input_initial_c }}[dir][b][h];
25
+ {% else %}
26
+ C_prev[b][h] = {{ input_initial_c }}[b][dir][h];
27
+ {% endif %}
28
+ {% else %}
29
+ C_prev[b][h] = {{ zero_literal }};
30
+ {% endif %}
31
+ }
32
+ }
33
+ for (int b = 0; b < {{ batch_size }}; ++b) {
34
+ int seq_limit = {{ seq_length }};
35
+ {% if input_sequence_lens %}
36
+ seq_limit = (int){{ input_sequence_lens }}[b];
37
+ if (seq_limit < 0) {
38
+ seq_limit = 0;
39
+ }
40
+ if (seq_limit > {{ seq_length }}) {
41
+ seq_limit = {{ seq_length }};
42
+ }
43
+ {% endif %}
44
+ for (int step = 0; step < seq_limit; ++step) {
45
+ int t = reverse ? (seq_limit - 1 - step) : step;
46
+ {{ c_type }} H_next[{{ hidden_size }}];
47
+ {{ c_type }} C_next[{{ hidden_size }}];
48
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
49
+ {{ c_type }} gate_i = {{ zero_literal }};
50
+ {{ c_type }} gate_o = {{ zero_literal }};
51
+ {{ c_type }} gate_f = {{ zero_literal }};
52
+ {{ c_type }} gate_c = {{ zero_literal }};
53
+ for (int i = 0; i < {{ input_size }}; ++i) {
54
+ {% if layout == 0 %}
55
+ {{ c_type }} x_val = {{ input_x }}[t][b][i];
56
+ {% else %}
57
+ {{ c_type }} x_val = {{ input_x }}[b][t][i];
58
+ {% endif %}
59
+ gate_i += x_val * {{ input_w }}[dir][h][i];
60
+ gate_o += x_val * {{ input_w }}[dir][{{ hidden_size }} + h][i];
61
+ gate_f += x_val * {{ input_w }}[dir][{{ hidden_size }} * 2 + h][i];
62
+ gate_c += x_val * {{ input_w }}[dir][{{ hidden_size }} * 3 + h][i];
63
+ }
64
+ for (int i = 0; i < {{ hidden_size }}; ++i) {
65
+ {{ c_type }} h_val = H_prev[b][i];
66
+ gate_i += h_val * {{ input_r }}[dir][h][i];
67
+ gate_o += h_val * {{ input_r }}[dir][{{ hidden_size }} + h][i];
68
+ gate_f += h_val * {{ input_r }}[dir][{{ hidden_size }} * 2 + h][i];
69
+ gate_c += h_val * {{ input_r }}[dir][{{ hidden_size }} * 3 + h][i];
70
+ }
71
+ {% if input_b %}
72
+ gate_i += {{ input_b }}[dir][h] + {{ input_b }}[dir][{{ hidden_size }} * 4 + h];
73
+ gate_o += {{ input_b }}[dir][{{ hidden_size }} + h] + {{ input_b }}[dir][{{ hidden_size }} * 5 + h];
74
+ gate_f += {{ input_b }}[dir][{{ hidden_size }} * 2 + h] + {{ input_b }}[dir][{{ hidden_size }} * 6 + h];
75
+ gate_c += {{ input_b }}[dir][{{ hidden_size }} * 3 + h] + {{ input_b }}[dir][{{ hidden_size }} * 7 + h];
76
+ {% endif %}
77
+ {% if use_clip %}
78
+ if (gate_i > {{ clip_literal }}) {
79
+ gate_i = {{ clip_literal }};
80
+ } else if (gate_i < -{{ clip_literal }}) {
81
+ gate_i = -{{ clip_literal }};
82
+ }
83
+ if (gate_o > {{ clip_literal }}) {
84
+ gate_o = {{ clip_literal }};
85
+ } else if (gate_o < -{{ clip_literal }}) {
86
+ gate_o = -{{ clip_literal }};
87
+ }
88
+ if (gate_f > {{ clip_literal }}) {
89
+ gate_f = {{ clip_literal }};
90
+ } else if (gate_f < -{{ clip_literal }}) {
91
+ gate_f = -{{ clip_literal }};
92
+ }
93
+ if (gate_c > {{ clip_literal }}) {
94
+ gate_c = {{ clip_literal }};
95
+ } else if (gate_c < -{{ clip_literal }}) {
96
+ gate_c = -{{ clip_literal }};
97
+ }
98
+ {% endif %}
99
+ {% if input_p %}
100
+ {{ c_type }} i_gate = {{ act_f }}(
101
+ gate_i + {{ input_p }}[dir][h] * C_prev[b][h]);
102
+ {% else %}
103
+ {{ c_type }} i_gate = {{ act_f }}(gate_i);
104
+ {% endif %}
105
+ {% if input_forget %}
106
+ {{ c_type }} f_gate = ({{ c_type }}){{ one_literal }} - i_gate;
107
+ {% else %}
108
+ {% if input_p %}
109
+ {{ c_type }} f_gate = {{ act_f }}(
110
+ gate_f + {{ input_p }}[dir][{{ hidden_size }} * 2 + h] * C_prev[b][h]);
111
+ {% else %}
112
+ {{ c_type }} f_gate = {{ act_f }}(gate_f);
113
+ {% endif %}
114
+ {% endif %}
115
+ {{ c_type }} c_gate = {{ act_g }}(gate_c);
116
+ {{ c_type }} c_new = f_gate * C_prev[b][h] + i_gate * c_gate;
117
+ {% if input_p %}
118
+ {{ c_type }} o_gate = {{ act_f }}(
119
+ gate_o + {{ input_p }}[dir][{{ hidden_size }} + h] * c_new);
120
+ {% else %}
121
+ {{ c_type }} o_gate = {{ act_f }}(gate_o);
122
+ {% endif %}
123
+ {{ c_type }} h_new = o_gate * {{ act_h }}(c_new);
124
+ C_next[h] = c_new;
125
+ H_next[h] = h_new;
126
+ {% if output_y %}
127
+ {% if layout == 0 %}
128
+ {{ output_y }}[step][dir][b][h] = h_new;
129
+ {% else %}
130
+ {{ output_y }}[b][step][dir][h] = h_new;
131
+ {% endif %}
132
+ {% endif %}
133
+ }
134
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
135
+ C_prev[b][h] = C_next[h];
136
+ H_prev[b][h] = H_next[h];
137
+ }
138
+ }
139
+ {% if output_y %}
140
+ for (int step = seq_limit; step < {{ seq_length }}; ++step) {
141
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
142
+ {% if layout == 0 %}
143
+ {{ output_y }}[step][dir][b][h] = {{ zero_literal }};
144
+ {% else %}
145
+ {{ output_y }}[b][step][dir][h] = {{ zero_literal }};
146
+ {% endif %}
147
+ }
148
+ }
149
+ {% endif %}
150
+ }
151
+ {% if output_y_h %}
152
+ for (int b = 0; b < {{ batch_size }}; ++b) {
153
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
154
+ {% if layout == 0 %}
155
+ {{ output_y_h }}[dir][b][h] = H_prev[b][h];
156
+ {% else %}
157
+ {{ output_y_h }}[b][dir][h] = H_prev[b][h];
158
+ {% endif %}
159
+ }
160
+ }
161
+ {% endif %}
162
+ {% if output_y_c %}
163
+ for (int b = 0; b < {{ batch_size }}; ++b) {
164
+ for (int h = 0; h < {{ hidden_size }}; ++h) {
165
+ {% if layout == 0 %}
166
+ {{ output_y_c }}[dir][b][h] = C_prev[b][h];
167
+ {% else %}
168
+ {{ output_y_c }}[b][dir][h] = C_prev[b][h];
169
+ {% endif %}
170
+ }
171
+ }
172
+ {% endif %}
173
+ }
174
+ {% endfor %}
175
+ }
{# Generic MatMul: one output loop per output dim (vars/bounds supplied by the
   lowering), a k-reduction innermost, and index expressions that already
   encode any broadcasting. The inner `{% for indent %}` loops only emit
   leading whitespace so the generated C is indented to the loop depth. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for idx in range(output_loop_vars | length) %}
{% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ output_loop_vars[idx] }} = 0; {{ output_loop_vars[idx] }} < {{ output_loop_bounds[idx] }}; ++{{ output_loop_vars[idx] }}) {
{% endfor %}
{# Accumulate in acc_type (may be wider than the operand type). #}
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ acc_type }} acc = {{ zero_literal }};
{% for indent in range(output_loop_vars | length) %} {% endfor %}for (idx_t k = 0; k < {{ k }}; ++k) {
{% for indent in range(output_loop_vars | length + 1) %} {% endfor %}acc += {{ input0_index_expr }} * {{ input1_index_expr }};
{% for indent in range(output_loop_vars | length) %} {% endfor %}}
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ output_index_expr }} = acc;
{# Close the output loops in reverse order. #}
{% for idx in range(output_loop_vars | length) | reverse %}
{% for indent in range(loop.index0) %} {% endfor %}}
{% endfor %}
}
{# MaxPool for 1-D / 2-D / 3-D spatial ranks (selected at render time), NCHW-style
   layout. Optionally emits the flat argmax Indices tensor; storage_order picks
   row-major (0) vs column-major (1) flattening of the index, per ONNX MaxPool.
   Ties keep the first (lowest-index) maximum because the update uses `>`.
   NOTE(review): the `ix >= 0` style bounds checks assume idx_t is a signed
   type; with an unsigned idx_t a padded (negative) coordinate would rely on
   wraparound failing the upper-bound check — confirm idx_t's definition. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
for (idx_t n = 0; n < {{ batch }}; ++n) {
for (idx_t c = 0; c < {{ channels }}; ++c) {
{% if spatial_rank == 1 %}
for (idx_t ox = 0; ox < {{ out_spatial[0] }}; ++ox) {
{# Start from the type's minimum so any in-bounds value wins. #}
{{ c_type }} max_value = {{ min_literal }};
{% if indices %}
{{ indices_c_type }} max_index = 0;
{% endif %}
for (idx_t kx = 0; kx < {{ kernel_shape[0] }}; ++kx) {
const idx_t ix = ox * {{ strides[0] }} + kx * {{ dilations[0] }} - {{ pads[0] }};
if (ix >= 0 && ix < {{ in_spatial[0] }}) {
{{ c_type }} val = {{ input0 }}[n][c][ix];
{% if indices %}
{# Record the flat input index only on a strict improvement. #}
const {{ c_type }} prev_max = max_value;
max_value = {{ max_fn }}(max_value, val);
max_index = (val > prev_max)
? ({{ indices_c_type }})(((({{ indices_c_type }})n * {{ channels }} + ({{ indices_c_type }})c) * {{ in_spatial[0] }}) + ({{ indices_c_type }})ix)
: max_index;
{% else %}
max_value = {{ max_fn }}(max_value, val);
{% endif %}
}
}
{{ output }}[n][c][ox] = max_value;
{% if indices %}
{{ indices }}[n][c][ox] = max_index;
{% endif %}
}
{% elif spatial_rank == 2 %}
for (idx_t oh = 0; oh < {{ out_spatial[0] }}; ++oh) {
for (idx_t ow = 0; ow < {{ out_spatial[1] }}; ++ow) {
{{ c_type }} max_value = {{ min_literal }};
{% if indices %}
{{ indices_c_type }} max_index = 0;
{% endif %}
for (idx_t kh = 0; kh < {{ kernel_shape[0] }}; ++kh) {
const idx_t ih = oh * {{ strides[0] }} + kh * {{ dilations[0] }} - {{ pads[0] }};
if (ih >= 0 && ih < {{ in_spatial[0] }}) {
for (idx_t kw = 0; kw < {{ kernel_shape[1] }}; ++kw) {
const idx_t iw = ow * {{ strides[1] }} + kw * {{ dilations[1] }} - {{ pads[1] }};
if (iw >= 0 && iw < {{ in_spatial[1] }}) {
{{ c_type }} val = {{ input0 }}[n][c][ih][iw];
{% if indices %}
const {{ c_type }} prev_max = max_value;
max_value = {{ max_fn }}(max_value, val);
max_index = (val > prev_max)
? (
{# storage_order 0: row-major flat index; 1: column-major over the spatial dims. #}
{% if storage_order == 0 %}
({{ indices_c_type }})((((({{ indices_c_type }})n * {{ channels }} + ({{ indices_c_type }})c) * {{ in_spatial[0] }} + ({{ indices_c_type }})ih) * {{ in_spatial[1] }}) + ({{ indices_c_type }})iw)
{% else %}
({{ indices_c_type }})(((({{ indices_c_type }})n * {{ channels }} + ({{ indices_c_type }})c) * {{ in_spatial[0] }} * {{ in_spatial[1] }}) + ({{ indices_c_type }})ih + ({{ indices_c_type }})iw * {{ in_spatial[0] }})
{% endif %}
)
: max_index;
{% else %}
max_value = {{ max_fn }}(max_value, val);
{% endif %}
}
}
}
}
{{ output }}[n][c][oh][ow] = max_value;
{% if indices %}
{{ indices }}[n][c][oh][ow] = max_index;
{% endif %}
}
}
{% elif spatial_rank == 3 %}
for (idx_t od = 0; od < {{ out_spatial[0] }}; ++od) {
for (idx_t oh = 0; oh < {{ out_spatial[1] }}; ++oh) {
for (idx_t ow = 0; ow < {{ out_spatial[2] }}; ++ow) {
{{ c_type }} max_value = {{ min_literal }};
{% if indices %}
{{ indices_c_type }} max_index = 0;
{% endif %}
for (idx_t kd = 0; kd < {{ kernel_shape[0] }}; ++kd) {
const idx_t id = od * {{ strides[0] }} + kd * {{ dilations[0] }} - {{ pads[0] }};
if (id >= 0 && id < {{ in_spatial[0] }}) {
for (idx_t kh = 0; kh < {{ kernel_shape[1] }}; ++kh) {
const idx_t ih = oh * {{ strides[1] }} + kh * {{ dilations[1] }} - {{ pads[1] }};
if (ih >= 0 && ih < {{ in_spatial[1] }}) {
for (idx_t kw = 0; kw < {{ kernel_shape[2] }}; ++kw) {
const idx_t iw = ow * {{ strides[2] }} + kw * {{ dilations[2] }} - {{ pads[2] }};
if (iw >= 0 && iw < {{ in_spatial[2] }}) {
{{ c_type }} val = {{ input0 }}[n][c][id][ih][iw];
{% if indices %}
const {{ c_type }} prev_max = max_value;
max_value = {{ max_fn }}(max_value, val);
max_index = (val > prev_max)
? (
{% if storage_order == 0 %}
({{ indices_c_type }})(((((({{ indices_c_type }})n * {{ channels }} + ({{ indices_c_type }})c) * {{ in_spatial[0] }} + ({{ indices_c_type }})id) * {{ in_spatial[1] }} + ({{ indices_c_type }})ih) * {{ in_spatial[2] }}) + ({{ indices_c_type }})iw)
{% else %}
({{ indices_c_type }})(((({{ indices_c_type }})n * {{ channels }} + ({{ indices_c_type }})c) * {{ in_spatial[0] }} * {{ in_spatial[1] }} * {{ in_spatial[2] }}) + ({{ indices_c_type }})id + ({{ indices_c_type }})ih * {{ in_spatial[0] }} + ({{ indices_c_type }})iw * {{ in_spatial[0] }} * {{ in_spatial[1] }})
{% endif %}
)
: max_index;
{% else %}
max_value = {{ max_fn }}(max_value, val);
{% endif %}
}
}
}
}
}
}
{{ output }}[n][c][od][oh][ow] = max_value;
{% if indices %}
{{ indices }}[n][c][od][oh][ow] = max_index;
{% endif %}
}
}
}
{% endif %}
}
}
}
{# MeanVarianceNormalization: out = (x - mean) / sqrt(var + epsilon), where
   mean/var are reduced over `axes` for each combination of the remaining
   `non_axes` dims. Loop variables are addressed by axis position so the
   reduction axes can be arbitrary. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{# Outer loops: the dims NOT being reduced. #}
{% for axis in non_axes %}
for (idx_t {{ loop_vars[axis] }} = 0; {{ loop_vars[axis] }} < {{ shape[axis] }}; ++{{ loop_vars[axis] }}) {
{% endfor %}
{# Pass 1: sum over the reduction axes. #}
{{ c_type }} sum = {{ zero_literal }};
{% for axis in axes %}
for (idx_t {{ loop_vars[axis] }} = 0; {{ loop_vars[axis] }} < {{ shape[axis] }}; ++{{ loop_vars[axis] }}) {
{% endfor %}
sum += {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
{% for _ in axes %}
}
{% endfor %}
{# reduce_count = product of the reduced dims. #}
{{ c_type }} mean = sum / {{ reduce_count }};
{# Pass 2: squared deviations (population variance). #}
{{ c_type }} var = {{ zero_literal }};
{% for axis in axes %}
for (idx_t {{ loop_vars[axis] }} = 0; {{ loop_vars[axis] }} < {{ shape[axis] }}; ++{{ loop_vars[axis] }}) {
{% endfor %}
{{ c_type }} diff = {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %} - mean;
var += diff * diff;
{% for _ in axes %}
}
{% endfor %}
{{ c_type }} denom = {{ sqrt_fn }}(var / {{ reduce_count }} + {{ epsilon_literal }});
{# Pass 3: write normalized values. #}
{% for axis in axes %}
for (idx_t {{ loop_vars[axis] }} = 0; {{ loop_vars[axis] }} < {{ shape[axis] }}; ++{{ loop_vars[axis] }}) {
{% endfor %}
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %} - mean) / denom;
{% for _ in axes %}
}
{% endfor %}
{% for _ in non_axes %}
}
{% endfor %}
}
{# Variadic elementwise op (Sum/Max/Min/Mean/...): folds input_exprs pairwise
   into output_expr. operator_kind selects the combining form — "func" for
   fn(out, x), "expr" for a lowering-supplied expression, otherwise an infix
   operator. is_mean divides the folded sum by the input count (mean_scale). #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dim in shape %}
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
{% endfor %}
{# Seed with the first input, then fold the rest in. #}
{{ output_expr }} = {{ input_exprs[0] }};
{% for expr in input_exprs[1:] %}
{{ output_expr }} = {% if operator_kind == "func" %}{{ operator }}({{ output_expr }}, {{ expr }}){% elif operator_kind == "expr" %}{{ operator_expr }}{% else %}{{ output_expr }} {{ operator }} {{ expr }}{% endif %};
{% endfor %}
{% if is_mean %}
{{ output_expr }} = {{ output_expr }} / {{ mean_scale }};
{% endif %}
{% for _ in shape %}
}
{% endfor %}
}
{# NegativeLogLikelihoodLoss over input (N, C, d1*...*dk) flattened to
   (n, c, d). loss = -input[n, target, d], optionally scaled by
   weight[target]. reduction selects "none" (per-element output),
   "sum", or "mean" (weighted mean; weight_sum is the element count when
   no weight tensor is given). Targets equal to ignore_index contribute
   nothing (and produce 0 in the "none" output). #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
const {{ target_c_type }} *target_flat = (const {{ target_c_type }} *){{ target }};
{{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
const idx_t n = {{ n }};
const idx_t c = {{ c }};
const idx_t d = {{ d }};
{% if reduction != "none" %}
{{ acc_type }} loss_sum = {{ acc_zero_literal }};
{% if reduction == "mean" %}
{{ acc_type }} weight_sum = {{ acc_zero_literal }};
{% endif %}
{% endif %}
for (idx_t n_idx = 0; n_idx < n; ++n_idx) {
for (idx_t d_idx = 0; d_idx < d; ++d_idx) {
idx_t target_index = n_idx * d + d_idx;
{{ target_c_type }} target_value = target_flat[target_index];
{# Ignored positions are skipped (written as 0 only when reduction == "none"). #}
if ((int64_t)target_value == {{ ignore_index }}) {
{% if reduction == "none" %}
output_flat[target_index] = {{ zero_literal }};
{% endif %}
} else {
idx_t class_index = (idx_t)target_value;
{# Flat offset of input[n_idx][class_index][d_idx]. #}
idx_t input_index = (n_idx * c + class_index) * d + d_idx;
{{ acc_type }} value = -({{ acc_type }})input_flat[input_index];
{% if weight %}
{{ acc_type }} sample_weight = {{ weight }}[class_index];
value *= sample_weight;
{% endif %}
{% if reduction == "none" %}
output_flat[target_index] = value;
{% else %}
loss_sum += value;
{% if reduction == "mean" %}
{% if weight %}
weight_sum += sample_weight;
{% else %}
weight_sum += {{ acc_one_literal }};
{% endif %}
{% endif %}
{% endif %}
}
}
}
{% if reduction == "mean" %}
{# All-ignored input: emit 0 rather than dividing by zero. #}
if (weight_sum == {{ acc_zero_literal }}) {
output_flat[0] = {{ zero_literal }};
} else {
output_flat[0] = loss_sum / weight_sum;
}
{% elif reduction == "sum" %}
output_flat[0] = loss_sum;
{% endif %}
}