emx-onnx-cgen 0.3.8 (py3-none-any wheel) → 0.4.2.dev0 (py3-none-any wheel)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of emx-onnx-cgen might be problematic; the original registry page linked to further details.

Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,18 @@
+ {# Arg-reduction kernel: for every output position, scan the reduced axis and store the index of the winning element. #}
+ {# compare_op comes from the emitter: a strict comparison keeps the first of equal values, a non-strict one keeps the last (presumably how select_last_index is realised; confirm in c_emitter). #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in output_shape %}
3
+ for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
4
+ {% endfor %}
+ {# Seed the running best with the element at init_index_expr paired with index 0 (presumably the first element of the reduced axis), so the scan below starts at 1. The unconditional read assumes reduce_dim >= 1. #}
5
+ {{ input_c_type }} best = {{ input0 }}{{ init_index_expr }};
6
+ {{ output_c_type }} best_index = 0;
7
+ for (idx_t {{ reduce_var }} = 1; {{ reduce_var }} < {{ reduce_dim }}; ++{{ reduce_var }}) {
8
+ {{ input_c_type }} candidate = {{ input0 }}{{ input_index_expr }};
9
+ if (candidate {{ compare_op }} best) {
10
+ best = candidate;
11
+ best_index = ({{ output_c_type }}){{ reduce_var }};
12
+ }
13
+ }
14
+ {{ output }}{{ output_index_expr }} = best_index;
15
+ {% for _ in output_shape %}
16
+ }
17
+ {% endfor %}
18
+ }
@@ -0,0 +1,189 @@
+ {# Scaled dot-product attention. Optionally materialises present_key / present_value (past followed by current along the sequence axis). #}
+ {# Then, for each (batch, query head, query position): scores = scale * q.k, add mask / padding / causal bias, optional tanh softcap, softmax, weighted sum of V. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const {{ c_type }} scale = {{ scale_literal }};
3
+ const {{ c_type }} softcap = {{ softcap_literal }};
4
+ {% if output_present_key %}
+ {# present_key[b][h] = past_key for the first past_seq positions, then the current K. Rank-3 K is read as [b][seq][head * qk_head_size + d]. #}
+ {# NOTE(review): input_past_key is referenced here without the input_past_key guard used in the score loop below; this only renders valid C if the emitter always binds a past_key name when present_key is requested. Confirm. #}
5
+ for (int b = 0; b < {{ batch }}; ++b) {
6
+ for (int h = 0; h < {{ kv_heads }}; ++h) {
7
+ for (int ki = 0; ki < {{ total_seq }}; ++ki) {
8
+ if (ki < {{ past_seq }}) {
9
+ for (int d = 0; d < {{ qk_head_size }}; ++d) {
10
+ {{ output_present_key }}[b][h][ki][d] = {{ input_past_key }}[b][h][ki][d];
11
+ }
12
+ } else {
13
+ int k_index = ki - {{ past_seq }};
14
+ for (int d = 0; d < {{ qk_head_size }}; ++d) {
15
+ {% if k_rank == 4 %}
16
+ {{ output_present_key }}[b][h][ki][d] = {{ input_k }}[b][h][k_index][d];
17
+ {% else %}
18
+ {{ output_present_key }}[b][h][ki][d] = {{ input_k }}[b][k_index][h * {{ qk_head_size }} + d];
19
+ {% endif %}
20
+ }
21
+ }
22
+ }
23
+ }
24
+ }
25
+ {% endif %}
26
+ {% if output_present_value %}
+ {# present_value mirrors present_key using v_head_size. The same NOTE(review) about the unguarded past reference applies to input_past_value. #}
27
+ for (int b = 0; b < {{ batch }}; ++b) {
28
+ for (int h = 0; h < {{ kv_heads }}; ++h) {
29
+ for (int ki = 0; ki < {{ total_seq }}; ++ki) {
30
+ if (ki < {{ past_seq }}) {
31
+ for (int d = 0; d < {{ v_head_size }}; ++d) {
32
+ {{ output_present_value }}[b][h][ki][d] = {{ input_past_value }}[b][h][ki][d];
33
+ }
34
+ } else {
35
+ int v_index = ki - {{ past_seq }};
36
+ for (int d = 0; d < {{ v_head_size }}; ++d) {
37
+ {% if v_rank == 4 %}
38
+ {{ output_present_value }}[b][h][ki][d] = {{ input_v }}[b][h][v_index][d];
39
+ {% else %}
40
+ {{ output_present_value }}[b][h][ki][d] = {{ input_v }}[b][v_index][h * {{ v_head_size }} + d];
41
+ {% endif %}
42
+ }
43
+ }
44
+ }
45
+ }
46
+ }
47
+ {% endif %}
48
+ for (int b = 0; b < {{ batch }}; ++b) {
49
+ for (int h = 0; h < {{ q_heads }}; ++h) {
+ {# Grouped-query attention: each run of head_group_size consecutive query heads shares one KV head. #}
50
+ int kv_head = h / {{ head_group_size }};
51
+ for (int qi = 0; qi < {{ q_seq }}; ++qi) {
+ {# Per-query score buffer: an automatic array of total_seq elements. #}
52
+ {{ c_type }} scores[{{ total_seq }}];
53
+ {{ c_type }} max_score = {{ min_literal }};
54
+ for (int ki = 0; ki < {{ total_seq }}; ++ki) {
55
+ {{ c_type }} score = {{ zero_literal }};
+ {# Position within the current K. Without past_key this assumes past_seq renders as 0 (TODO confirm in the emitter). #}
56
+ int k_index = ki - {{ past_seq }};
57
+ for (int d = 0; d < {{ qk_head_size }}; ++d) {
58
+ {{ c_type }} q_val;
59
+ {{ c_type }} k_val;
60
+ {% if q_rank == 4 %}
61
+ q_val = {{ input_q }}[b][h][qi][d];
62
+ {% else %}
63
+ q_val = {{ input_q }}[b][qi][h * {{ qk_head_size }} + d];
64
+ {% endif %}
65
+ {% if input_past_key %}
66
+ if (ki < {{ past_seq }}) {
67
+ k_val = {{ input_past_key }}[b][kv_head][ki][d];
68
+ } else {
69
+ {% if k_rank == 4 %}
70
+ k_val = {{ input_k }}[b][kv_head][k_index][d];
71
+ {% else %}
72
+ k_val = {{ input_k }}[b][k_index][kv_head * {{ qk_head_size }} + d];
73
+ {% endif %}
74
+ }
75
+ {% else %}
76
+ {% if k_rank == 4 %}
77
+ k_val = {{ input_k }}[b][kv_head][k_index][d];
78
+ {% else %}
79
+ k_val = {{ input_k }}[b][k_index][kv_head * {{ qk_head_size }} + d];
80
+ {% endif %}
81
+ {% endif %}
82
+ score += q_val * k_val;
83
+ }
84
+ score *= scale;
85
+ {{ c_type }} bias = {{ zero_literal }};
86
+ {% if input_attn_mask %}
+ {# Keys beyond the mask's kv length are masked out. Boolean masks map true to 0 and false to min_literal; other masks are added as-is. The broadcast flags pin the corresponding mask index to 0. #}
87
+ if (ki >= {{ mask_kv_seq }}) {
88
+ bias = {{ min_literal }};
89
+ } else {
90
+ int mask_q = {{ '0' if mask_broadcast_q_seq else 'qi' }};
91
+ {% if mask_rank == 2 %}
92
+ {% if mask_is_bool %}
93
+ bias = {{ input_attn_mask }}[mask_q][ki] ? {{ zero_literal }} : {{ min_literal }};
94
+ {% else %}
95
+ bias = {{ input_attn_mask }}[mask_q][ki];
96
+ {% endif %}
97
+ {% elif mask_rank == 3 %}
98
+ int mask_b = {{ '0' if mask_broadcast_batch else 'b' }};
99
+ {% if mask_is_bool %}
100
+ bias = {{ input_attn_mask }}[mask_b][mask_q][ki] ? {{ zero_literal }} : {{ min_literal }};
101
+ {% else %}
102
+ bias = {{ input_attn_mask }}[mask_b][mask_q][ki];
103
+ {% endif %}
104
+ {% else %}
105
+ int mask_b = {{ '0' if mask_broadcast_batch else 'b' }};
106
+ int mask_h = {{ '0' if mask_broadcast_heads else 'h' }};
107
+ {% if mask_is_bool %}
108
+ bias = {{ input_attn_mask }}[mask_b][mask_h][mask_q][ki] ? {{ zero_literal }} : {{ min_literal }};
109
+ {% else %}
110
+ bias = {{ input_attn_mask }}[mask_b][mask_h][mask_q][ki];
111
+ {% endif %}
112
+ {% endif %}
113
+ }
114
+ {% endif %}
115
+ {% if input_nonpad_kv_seqlen %}
+ {# Keys at or after the per-batch non-padded KV length are masked out; this overrides any mask bias. #}
116
+ if (ki >= {{ input_nonpad_kv_seqlen }}[b]) {
117
+ bias = {{ min_literal }};
118
+ }
119
+ {% endif %}
+ {# is_causal renders as a constant. Key ki is visible to query qi only when ki <= qi + past_seq. #}
120
+ if ({{ is_causal }} && ki > qi + {{ past_seq }}) {
121
+ bias = {{ min_literal }};
122
+ }
123
+ {{ c_type }} score_bias = score + bias;
124
+ {{ c_type }} score_softcap = score_bias;
+ {# softcap == 0 disables capping; otherwise the score is squashed as softcap * tanh_fn(score / softcap). #}
125
+ if (softcap != {{ zero_literal }}) {
126
+ score_softcap = softcap * {{ tanh_fn }}(score_bias / softcap);
127
+ }
128
+ {% if output_qk_matmul %}
+ {# qk_matmul_output_mode selects the exported stage: 0 scaled scores, 1 after bias, 2 after softcap. Mode 3 (softmax probabilities) is written after normalisation below. #}
129
+ if ({{ qk_matmul_output_mode }} == 0) {
130
+ {{ output_qk_matmul }}[b][h][qi][ki] = score;
131
+ } else if ({{ qk_matmul_output_mode }} == 1) {
132
+ {{ output_qk_matmul }}[b][h][qi][ki] = score_bias;
133
+ } else if ({{ qk_matmul_output_mode }} == 2) {
134
+ {{ output_qk_matmul }}[b][h][qi][ki] = score_softcap;
135
+ }
136
+ {% endif %}
137
+ scores[ki] = score_softcap;
138
+ max_score = {{ max_fn }}(max_score, score_softcap);
139
+ }
+ {# Softmax relative to the row maximum. If the maximum is still min_literal (e.g. a fully masked row, when min_literal is -infinity; confirm its rendered value), all weights stay 0 and inv_sum is 0, so the output row is zero instead of NaN. #}
140
+ {{ c_type }} weights[{{ total_seq }}];
141
+ {{ c_type }} sum = {{ zero_literal }};
142
+ for (int ki = 0; ki < {{ total_seq }}; ++ki) {
143
+ {{ c_type }} weight = {{ zero_literal }};
144
+ if (max_score != {{ min_literal }}) {
145
+ weight = {{ exp_fn }}(scores[ki] - max_score);
146
+ }
147
+ weights[ki] = weight;
148
+ sum += weight;
149
+ }
150
+ {{ c_type }} inv_sum = sum == {{ zero_literal }} ? {{ zero_literal }} : ({{ c_type }}){{ one_literal }} / sum;
151
+ {% if output_qk_matmul %}
152
+ if ({{ qk_matmul_output_mode }} == 3) {
153
+ for (int ki = 0; ki < {{ total_seq }}; ++ki) {
154
+ {{ output_qk_matmul }}[b][h][qi][ki] = weights[ki] * inv_sum;
155
+ }
156
+ }
157
+ {% endif %}
+ {# Weighted sum over all total_seq keys of past_value (if bound) followed by the current V, one output feature at a time. #}
158
+ for (int vd = 0; vd < {{ v_head_size }}; ++vd) {
159
+ {{ c_type }} acc = {{ zero_literal }};
160
+ for (int ki = 0; ki < {{ total_seq }}; ++ki) {
161
+ {{ c_type }} weight = weights[ki] * inv_sum;
162
+ {% if input_past_value %}
163
+ if (ki < {{ past_seq }}) {
164
+ acc += weight * {{ input_past_value }}[b][kv_head][ki][vd];
165
+ } else {
166
+ {% if v_rank == 4 %}
167
+ acc += weight * {{ input_v }}[b][kv_head][ki - {{ past_seq }}][vd];
168
+ {% else %}
169
+ acc += weight * {{ input_v }}[b][ki - {{ past_seq }}][kv_head * {{ v_head_size }} + vd];
170
+ {% endif %}
171
+ }
172
+ {% else %}
+ {# Without past_value, V is indexed by ki directly, which again assumes past_seq is 0 in this configuration. #}
173
+ {% if v_rank == 4 %}
174
+ acc += weight * {{ input_v }}[b][kv_head][ki][vd];
175
+ {% else %}
176
+ acc += weight * {{ input_v }}[b][ki][kv_head * {{ v_head_size }} + vd];
177
+ {% endif %}
178
+ {% endif %}
179
+ }
+ {# The output is either rank 4 [b][h][qi][vd] or rank 3 [b][qi][h * v_head_size + vd]. #}
180
+ {% if output_rank == 4 %}
181
+ {{ output }}[b][h][qi][vd] = acc;
182
+ {% else %}
183
+ {{ output }}[b][qi][h * {{ v_head_size }} + vd] = acc;
184
+ {% endif %}
185
+ }
186
+ }
187
+ }
188
+ }
189
+ }
@@ -0,0 +1,126 @@
+ {# Average pooling for spatial_rank 3 (N,C,D,H,W), 1 (N,C,W); any other value renders the 2-D (N,C,H,W) path. #}
+ {# Window taps are buffered and summed pairwise (tree reduction) rather than left to right, presumably to mirror a reference summation order and limit rounding error; confirm. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% if spatial_rank == 3 %}
3
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
4
+ for (idx_t c = 0; c < {{ channels }}; ++c) {
5
+ for (idx_t od = 0; od < {{ out_d }}; ++od) {
6
+ for (idx_t oh = 0; oh < {{ out_h }}; ++oh) {
7
+ for (idx_t ow = 0; ow < {{ out_w }}; ++ow) {
8
+ {{ c_type }} values[{{ kernel_d * kernel_h * kernel_w }}];
9
+ idx_t count = 0;
10
+ for (idx_t kd = 0; kd < {{ kernel_d }}; ++kd) {
11
+ const idx_t id = od * {{ stride_d }} + kd * {{ dilation_d }} - {{ pad_front }};
12
+ for (idx_t kh = 0; kh < {{ kernel_h }}; ++kh) {
13
+ const idx_t ih = oh * {{ stride_h }} + kh * {{ dilation_h }} - {{ pad_top }};
14
+ for (idx_t kw = 0; kw < {{ kernel_w }}; ++kw) {
15
+ const idx_t iw = ow * {{ stride_w }} + kw * {{ dilation_w }} - {{ pad_left }};
+ {# The >= 0 checks assume idx_t is a signed type; TODO confirm its typedef. #}
+ {# Only begin-side pads appear in this template, so every out-of-bounds tap, including any past the end padding, is treated as padding. Whether that matches the ONNX ceil_mode + count_include_pad rule should be confirmed. #}
16
+ if (id >= 0 && id < {{ in_d }}
17
+ && ih >= 0 && ih < {{ in_h }}
18
+ && iw >= 0 && iw < {{ in_w }}) {
19
+ values[count++] = {{ input0 }}[n][c][id][ih][iw];
20
+ } else if ({{ count_include_pad }}) {
21
+ values[count++] = {{ zero_literal }};
22
+ }
23
+ }
24
+ }
25
+ }
+ {# With count_include_pad the divisor is the full kernel volume; otherwise it is the number of in-bounds taps. #}
26
+ const idx_t denom = {{ count_include_pad }}
27
+ ? {{ kernel_d * kernel_h * kernel_w }}
28
+ : count;
29
+ {{ c_type }} acc = {{ zero_literal }};
+ {# Pairwise reduction in place: add adjacent pairs and carry an odd trailing element until one value remains. Writes at next never overtake the reads at i. #}
30
+ if (count > 0) {
31
+ idx_t reduce_count = count;
32
+ while (reduce_count > 1) {
33
+ idx_t next = 0;
34
+ for (idx_t i = 0; i + 1 < reduce_count; i += 2) {
35
+ values[next++] = values[i] + values[i + 1];
36
+ }
37
+ if (reduce_count % 2) {
38
+ values[next++] = values[reduce_count - 1];
39
+ }
40
+ reduce_count = next;
41
+ }
42
+ acc = values[0];
43
+ }
+ {# A window with no contributing taps yields 0 instead of dividing by zero. #}
44
+ {{ output }}[n][c][od][oh][ow] = denom ? acc / ({{ c_type }})denom : {{ zero_literal }};
45
+ }
46
+ }
47
+ }
48
+ }
49
+ }
50
+ {% elif spatial_rank == 1 %}
+ {# 1-D path: same gather / pairwise-sum / divide scheme over the W axis only. #}
51
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
52
+ for (idx_t c = 0; c < {{ channels }}; ++c) {
53
+ for (idx_t ow = 0; ow < {{ out_w }}; ++ow) {
54
+ {{ c_type }} values[{{ kernel_w }}];
55
+ idx_t count = 0;
56
+ for (idx_t kw = 0; kw < {{ kernel_w }}; ++kw) {
57
+ const idx_t iw = ow * {{ stride_w }} + kw * {{ dilation_w }} - {{ pad_left }};
58
+ if (iw >= 0 && iw < {{ in_w }}) {
59
+ values[count++] = {{ input0 }}[n][c][iw];
60
+ } else if ({{ count_include_pad }}) {
61
+ values[count++] = {{ zero_literal }};
62
+ }
63
+ }
64
+ const idx_t denom = {{ count_include_pad }} ? {{ kernel_w }} : count;
65
+ {{ c_type }} acc = {{ zero_literal }};
66
+ if (count > 0) {
67
+ idx_t reduce_count = count;
68
+ while (reduce_count > 1) {
69
+ idx_t next = 0;
70
+ for (idx_t i = 0; i + 1 < reduce_count; i += 2) {
71
+ values[next++] = values[i] + values[i + 1];
72
+ }
73
+ if (reduce_count % 2) {
74
+ values[next++] = values[reduce_count - 1];
75
+ }
76
+ reduce_count = next;
77
+ }
78
+ acc = values[0];
79
+ }
80
+ {{ output }}[n][c][ow] = denom ? acc / ({{ c_type }})denom : {{ zero_literal }};
81
+ }
82
+ }
83
+ }
84
+ {% else %}
+ {# 2-D path (default for any spatial_rank other than 1 or 3): same scheme over H and W. #}
85
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
86
+ for (idx_t c = 0; c < {{ channels }}; ++c) {
87
+ for (idx_t oh = 0; oh < {{ out_h }}; ++oh) {
88
+ for (idx_t ow = 0; ow < {{ out_w }}; ++ow) {
89
+ {{ c_type }} values[{{ kernel_h * kernel_w }}];
90
+ idx_t count = 0;
91
+ for (idx_t kh = 0; kh < {{ kernel_h }}; ++kh) {
92
+ const idx_t ih = oh * {{ stride_h }} + kh * {{ dilation_h }} - {{ pad_top }};
93
+ for (idx_t kw = 0; kw < {{ kernel_w }}; ++kw) {
94
+ const idx_t iw = ow * {{ stride_w }} + kw * {{ dilation_w }} - {{ pad_left }};
95
+ if (ih >= 0 && ih < {{ in_h }} && iw >= 0 && iw < {{ in_w }}) {
96
+ values[count++] = {{ input0 }}[n][c][ih][iw];
97
+ } else if ({{ count_include_pad }}) {
98
+ values[count++] = {{ zero_literal }};
99
+ }
100
+ }
101
+ }
102
+ const idx_t denom = {{ count_include_pad }}
103
+ ? {{ kernel_h * kernel_w }}
104
+ : count;
105
+ {{ c_type }} acc = {{ zero_literal }};
106
+ if (count > 0) {
107
+ idx_t reduce_count = count;
108
+ while (reduce_count > 1) {
109
+ idx_t next = 0;
110
+ for (idx_t i = 0; i + 1 < reduce_count; i += 2) {
111
+ values[next++] = values[i] + values[i + 1];
112
+ }
113
+ if (reduce_count % 2) {
114
+ values[next++] = values[reduce_count - 1];
115
+ }
116
+ reduce_count = next;
117
+ }
118
+ acc = values[0];
119
+ }
120
+ {{ output }}[n][c][oh][ow] = denom ? acc / ({{ c_type }})denom : {{ zero_literal }};
121
+ }
122
+ }
123
+ }
124
+ }
125
+ {% endif %}
126
+ }
@@ -0,0 +1,11 @@
+ {# Inference-mode batch normalisation: y = (x - mean[c]) / sqrt_fn(variance[c] + epsilon) * scale[c] + bias[c]. #}
+ {# The channel index is the second loop variable, so the input must be at least rank 2. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ idx_t c = {{ loop_vars[1] }};
+ {# NOTE(review): the per-channel denominator is recomputed for every element; it could be hoisted per channel if this kernel becomes hot. #}
6
+ {{ c_type }} denom = {{ sqrt_fn }}({{ variance }}[c] + {{ epsilon_literal }});
7
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %} - {{ mean }}[c]) / denom * {{ scale }}[c] + {{ bias }}[c];
8
+ {% for _ in shape %}
9
+ }
10
+ {% endfor %}
11
+ }
@@ -0,0 +1,34 @@
+ {# Bernoulli sampling with a self-contained, deterministic PRNG: one generator state per call, seeded from the rendered seed. #}
+ {# The RNG is xorshift64* (shifts 12/25/27, output multiplier 0x2545F4914F6CDD1D). Helper names are prefixed with op_name so several Bernoulli nodes do not collide. #}
1
+ static inline uint64_t {{ op_name }}_rng_next_u64(uint64_t *state) {
2
+ uint64_t x = *state;
3
+ x ^= x >> 12;
4
+ x ^= x << 25;
5
+ x ^= x >> 27;
6
+ *state = x;
7
+ return x * 0x2545f4914f6cdd1dull;
8
+ }
9
+
+ {# Scales the 64-bit output by 2^-64. The uint64 to double conversion rounds, so outputs near 2^64 can give exactly 1.0. That is harmless below because sample < prob is only evaluated for prob < 1. #}
10
+ static inline double {{ op_name }}_rng_next_double(uint64_t *state) {
11
+ return (double){{ op_name }}_rng_next_u64(state) * (1.0 / 18446744073709551616.0);
12
+ }
13
+
14
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ {# The seed is pasted in front of an ull suffix, so it must render as a non-negative integer literal (presumably the emitter converts the ONNX seed attribute; confirm). #}
15
+ uint64_t rng_state = {{ seed }}ull;
+ {# State 0 is a fixed point of xorshift (it would stay 0 forever), so fall back to a fixed non-zero constant. #}
16
+ if (!rng_state) {
17
+ rng_state = 0x243f6a8885a308d3ull;
18
+ }
19
+ {% for dim in shape %}
20
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
21
+ {% endfor %}
22
+ const double prob = (double){{ input0 }}{{ input_index_expr }};
+ {# Only 0 < prob < 1 consumes a random draw, so the draw sequence depends on the input values as well as the seed. prob <= 0 and NaN (both comparisons fail) yield zero_literal. #}
23
+ int is_one = 0;
24
+ if (prob >= 1.0) {
25
+ is_one = 1;
26
+ } else if (prob > 0.0) {
27
+ const double sample = {{ op_name }}_rng_next_double(&rng_state);
28
+ is_one = sample < prob;
29
+ }
30
+ {{ output }}{{ output_index_expr }} = is_one ? {{ one_literal }} : {{ zero_literal }};
31
+ {% for _ in shape %}
32
+ }
33
+ {% endfor %}
34
+ }
@@ -0,0 +1,9 @@
+ {# Element-wise binary operation over the output shape. operator_kind selects the rendering: 'func' gives operator(left_expr, right_expr), 'expr' gives the pre-rendered operator_expr, anything else gives infix left_expr operator right_expr. #}
+ {# left_expr / right_expr come from the emitter (presumably already containing any broadcast indexing; confirm in c_emitter). #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {% if operator_kind == "func" %}{{ operator }}({{ left_expr }}, {{ right_expr }}){% elif operator_kind == "expr" %}{{ operator_expr }}{% else %}{{ left_expr }} {{ operator }} {{ right_expr }}{% endif %};
6
+ {% for _ in shape %}
7
+ }
8
+ {% endfor %}
9
+ }
@@ -0,0 +1,9 @@
+ {# Element-wise cast implemented as a plain C conversion to output_c_type. #}
+ {# C conversion rules apply: float to integer truncates toward zero, and an out-of-range float to integer conversion is undefined behaviour in C. Any saturating or special-type semantics must be handled elsewhere; confirm in the emitter. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ output_c_type }}){{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
6
+ {% for _ in shape %}
7
+ }
8
+ {% endfor %}
9
+ }
@@ -0,0 +1,14 @@
+ {# Element-wise clip: clamp each value with max_fn against min_value, then min_fn against max_value. #}
+ {# Because min_fn is applied last, min_value > max_value yields max_value. min_expr / max_expr are rendered by the emitter and evaluated per element. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ {% for dim in shape %}
3
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
4
+ {% endfor %}
5
+ {{ output_c_type }} value = {{ input_expr }};
6
+ const {{ output_c_type }} min_value = {{ min_expr }};
7
+ const {{ output_c_type }} max_value = {{ max_expr }};
8
+ value = {{ max_fn }}(value, min_value);
9
+ value = {{ min_fn }}(value, max_value);
10
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = value;
11
+ {% for _ in shape %}
12
+ }
13
+ {% endfor %}
14
+ }
@@ -0,0 +1,28 @@
+ {# Concat as a series of memcpy calls. Each input is viewed as outer x axis_size x inner elements. For every outer slice, each input's contiguous run of axis_size * inner elements is copied back to back into the output. #}
+ {# Inputs are held as const void pointers so inputs with different array types fit one table. Requires memcpy (string.h) to be declared in the generated translation unit. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const void *inputs[] = { {% for input in inputs %}{{ input }}{% if not loop.last %}, {% endif %}{% endfor %} };
3
+ const idx_t axis_sizes[] = { {% for axis in axis_sizes %}{{ axis }}{% if not loop.last %}, {% endif %}{% endfor %} };
4
+ idx_t concat_axis = 0;
5
+ for (idx_t idx = 0; idx < {{ input_count }}; ++idx) {
6
+ concat_axis += axis_sizes[idx];
7
+ }
+ {# An empty concatenated axis means there is nothing to copy. #}
8
+ if (concat_axis == 0) {
9
+ return;
10
+ }
11
+ for (idx_t outer_idx = 0; outer_idx < {{ outer }}; ++outer_idx) {
12
+ idx_t output_offset = outer_idx * concat_axis * {{ inner }};
13
+ idx_t axis_offset = 0;
14
+ for (idx_t input_idx = 0; input_idx < {{ input_count }}; ++input_idx) {
15
+ idx_t axis = axis_sizes[input_idx];
16
+ idx_t copy_elems = axis * {{ inner }};
17
+ const unsigned char *input_bytes =
18
+ (const unsigned char *)inputs[input_idx];
+ {# Offsets are in elements; they are converted to bytes via sizeof(c_type) at the memcpy. #}
19
+ idx_t input_offset = outer_idx * copy_elems;
20
+ memcpy(
21
+ ((unsigned char *){{ output }}) + (output_offset + axis_offset) * sizeof({{ c_type }}),
22
+ input_bytes + input_offset * sizeof({{ c_type }}),
23
+ copy_elems * sizeof({{ c_type }})
24
+ );
25
+ axis_offset += copy_elems;
26
+ }
27
+ }
28
+ }
@@ -0,0 +1,10 @@
+ {# ConstantOfShape: the output shape is baked into the loop bounds at code-generation time, so the runtime shape tensor input0 is not read. The void cast silences unused-parameter warnings. Every element is set to value_literal. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ (void){{ input0 }};
3
+ {% for dim in shape %}
4
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
5
+ {% endfor %}
6
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ value_literal }};
7
+ {% for _ in shape %}
8
+ }
9
+ {% endfor %}
10
+ }
@@ -0,0 +1,34 @@
+ {# Integer convolution: direct grouped N-D convolution. Input and weight values are shifted by their zero points (x_zero_expr / w_zero_expr, rendered by the emitter), multiply-accumulated in acc_type, then cast to c_type. There is no bias term. #}
+ {# Weights are indexed [oc_global][ic]: ic is the channel index within the group. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
3
+ for (idx_t g = 0; g < {{ group }}; ++g) {
4
+ for (idx_t oc = 0; oc < {{ group_out_channels }}; ++oc) {
5
+ const idx_t oc_global = g * {{ group_out_channels }} + oc;
6
+ {% for dim in range(spatial_rank) %}
7
+ for (idx_t {{ out_indices[dim] }} = 0; {{ out_indices[dim] }} < {{ out_spatial[dim] }}; ++{{ out_indices[dim] }}) {
8
+ {% endfor %}
9
+ {{ acc_type }} acc = {{ acc_zero_literal }};
10
+ for (idx_t ic = 0; ic < {{ group_in_channels }}; ++ic) {
11
+ const idx_t ic_global = g * {{ group_in_channels }} + ic;
12
+ {% for dim in range(spatial_rank) %}
13
+ for (idx_t {{ kernel_indices[dim] }} = 0; {{ kernel_indices[dim] }} < {{ kernel_shape[dim] }}; ++{{ kernel_indices[dim] }}) {
14
+ {% endfor %}
15
+ {% for dim in range(spatial_rank) %}
16
+ const idx_t {{ in_indices[dim] }} = {{ out_indices[dim] }} * {{ strides[dim] }} + {{ kernel_indices[dim] }} * {{ dilations[dim] }} - {{ pads_begin[dim] }};
17
+ {% endfor %}
+ {# Taps outside the input are skipped, so padding contributes zero after the zero-point shift. The >= 0 checks assume idx_t is signed; TODO confirm. #}
18
+ if ({% for dim in range(spatial_rank) %}{{ in_indices[dim] }} >= 0 && {{ in_indices[dim] }} < {{ in_spatial[dim] }}{% if not loop.last %} && {% endif %}{% endfor %}) {
19
+ {{ acc_type }} input_value = ({{ acc_type }}){{ input0 }}[n][ic_global]{% for dim in range(spatial_rank) %}[{{ in_indices[dim] }}]{% endfor %} - ({{ acc_type }})({{ x_zero_expr }});
20
+ {{ acc_type }} weight_value = ({{ acc_type }}){{ weights }}[oc_global][ic]{% for dim in range(spatial_rank) %}[{{ kernel_indices[dim] }}]{% endfor %} - ({{ acc_type }})({{ w_zero_expr }});
21
+ acc += input_value * weight_value;
22
+ }
23
+ {% for dim in range(spatial_rank) %}
24
+ }
25
+ {% endfor %}
26
+ }
27
+ {{ output }}[n][oc_global]{% for dim in range(spatial_rank) %}[{{ out_indices[dim] }}]{% endfor %} = ({{ c_type }})acc;
28
+ {% for dim in range(spatial_rank) %}
29
+ }
30
+ {% endfor %}
31
+ }
32
+ }
33
+ }
34
+ }
@@ -0,0 +1,32 @@
+ {# Convolution: direct grouped N-D convolution (no im2col). Each output element starts from bias[oc_global] when a bias is bound, otherwise zero. It accumulates input * weight in acc_type over the group's input channels and kernel window, and is cast to c_type on store. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
3
+ for (idx_t g = 0; g < {{ group }}; ++g) {
4
+ for (idx_t oc = 0; oc < {{ group_out_channels }}; ++oc) {
5
+ const idx_t oc_global = g * {{ group_out_channels }} + oc;
6
+ {% for dim in range(spatial_rank) %}
7
+ for (idx_t {{ out_indices[dim] }} = 0; {{ out_indices[dim] }} < {{ out_spatial[dim] }}; ++{{ out_indices[dim] }}) {
8
+ {% endfor %}
9
+ {{ acc_type }} acc = {% if bias %}({{ acc_type }}){{ bias }}[oc_global]{% else %}{{ acc_zero_literal }}{% endif %};
10
+ for (idx_t ic = 0; ic < {{ group_in_channels }}; ++ic) {
11
+ const idx_t ic_global = g * {{ group_in_channels }} + ic;
12
+ {% for dim in range(spatial_rank) %}
13
+ for (idx_t {{ kernel_indices[dim] }} = 0; {{ kernel_indices[dim] }} < {{ kernel_shape[dim] }}; ++{{ kernel_indices[dim] }}) {
14
+ {% endfor %}
15
+ {% for dim in range(spatial_rank) %}
16
+ const idx_t {{ in_indices[dim] }} = {{ out_indices[dim] }} * {{ strides[dim] }} + {{ kernel_indices[dim] }} * {{ dilations[dim] }} - {{ pads_begin[dim] }};
17
+ {% endfor %}
+ {# Taps that land in padding are skipped (implicit zero padding). The >= 0 checks assume idx_t is signed; TODO confirm. #}
18
+ if ({% for dim in range(spatial_rank) %}{{ in_indices[dim] }} >= 0 && {{ in_indices[dim] }} < {{ in_spatial[dim] }}{% if not loop.last %} && {% endif %}{% endfor %}) {
19
+ acc += ({{ acc_type }}){{ input0 }}[n][ic_global]{% for dim in range(spatial_rank) %}[{{ in_indices[dim] }}]{% endfor %} * ({{ acc_type }}){{ weights }}[oc_global][ic]{% for dim in range(spatial_rank) %}[{{ kernel_indices[dim] }}]{% endfor %};
20
+ }
21
+ {% for dim in range(spatial_rank) %}
22
+ }
23
+ {% endfor %}
24
+ }
25
+ {{ output }}[n][oc_global]{% for dim in range(spatial_rank) %}[{{ out_indices[dim] }}]{% endfor %} = ({{ c_type }})acc;
26
+ {% for dim in range(spatial_rank) %}
27
+ }
28
+ {% endfor %}
29
+ }
30
+ }
31
+ }
32
+ }
@@ -0,0 +1,43 @@
+ {# Transposed convolution in scatter form. Pass 1 initialises every output element to bias[oc] (or zero). #}
+ {# Pass 2 visits each input position and kernel tap, sums over the group's input channels, and adds the result at out = in * stride + k * dilation - pad when that lands inside the output. #}
+ {# NOTE(review): unlike conv_op / conv_integer_op, this kernel accumulates in c_type rather than a separate acc_type. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
3
+ for (idx_t oc = 0; oc < {{ out_channels }}; ++oc) {
4
+ {% for dim in range(spatial_rank) %}
5
+ for (idx_t {{ out_indices[dim] }} = 0; {{ out_indices[dim] }} < {{ out_spatial[dim] }}; ++{{ out_indices[dim] }}) {
6
+ {% endfor %}
7
+ {{ output }}[n][oc]{% for dim in range(spatial_rank) %}[{{ out_indices[dim] }}]{% endfor %} = {% if bias %}{{ bias }}[oc]{% else %}{{ zero_literal }}{% endif %};
8
+ {% for dim in range(spatial_rank) %}
9
+ }
10
+ {% endfor %}
11
+ }
12
+ }
13
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
14
+ for (idx_t g = 0; g < {{ group }}; ++g) {
15
+ for (idx_t oc = 0; oc < {{ group_out_channels }}; ++oc) {
16
+ const idx_t oc_global = g * {{ group_out_channels }} + oc;
17
+ {% for dim in range(spatial_rank) %}
18
+ for (idx_t {{ kernel_indices[dim] }} = 0; {{ kernel_indices[dim] }} < {{ kernel_shape[dim] }}; ++{{ kernel_indices[dim] }}) {
19
+ {% endfor %}
20
+ {% for dim in range(spatial_rank) %}
21
+ for (idx_t {{ in_indices[dim] }} = 0; {{ in_indices[dim] }} < {{ in_spatial[dim] }}; ++{{ in_indices[dim] }}) {
22
+ {% endfor %}
23
+ {{ c_type }} acc = {{ zero_literal }};
+ {# Weights are indexed [ic_global][oc], with oc being the output channel within the group. #}
24
+ for (idx_t ic = 0; ic < {{ group_in_channels }}; ++ic) {
25
+ const idx_t ic_global = g * {{ group_in_channels }} + ic;
26
+ acc += {{ input0 }}[n][ic_global]{% for dim in range(spatial_rank) %}[{{ in_indices[dim] }}]{% endfor %} * {{ weights }}[ic_global][oc]{% for dim in range(spatial_rank) %}[{{ kernel_indices[dim] }}]{% endfor %};
27
+ }
+ {# Target output coordinates, declared here as new locals in this inner scope (distinct from the pass-1 loop variables). The >= 0 checks assume idx_t is signed; TODO confirm. #}
28
+ {% for dim in range(spatial_rank) %}
29
+ const idx_t {{ out_indices[dim] }} = {{ in_indices[dim] }} * {{ strides[dim] }} + {{ kernel_indices[dim] }} * {{ dilations[dim] }} - {{ pads_begin[dim] }};
30
+ {% endfor %}
31
+ if ({% for dim in range(spatial_rank) %}{{ out_indices[dim] }} >= 0 && {{ out_indices[dim] }} < {{ out_spatial[dim] }}{% if not loop.last %} && {% endif %}{% endfor %}) {
32
+ {{ output }}[n][oc_global]{% for dim in range(spatial_rank) %}[{{ out_indices[dim] }}]{% endfor %} += acc;
33
+ }
34
+ {% for dim in range(spatial_rank) %}
35
+ }
36
+ {% endfor %}
37
+ {% for dim in range(spatial_rank) %}
38
+ }
39
+ {% endfor %}
40
+ }
41
+ }
42
+ }
43
+ }
@@ -0,0 +1,51 @@
+ {# Cumulative sum along one axis over a flat view of the data (outer x axis_dim x inner). #}
+ {# The axis is a code-generation-time literal when axis_literal is set, otherwise it is read from axis_input[0] at run time. Negative axes are normalised; an out-of-range axis returns without writing the output. #}
1
+ static inline void {{ op_name }}({{ dim_args }}const {{ c_type }} {{ input0 }}{{ input_suffix }}{% if axis_input %}, const {{ axis_c_type }} {{ axis_input }}{{ axis_suffix }}{% endif %}, {{ c_type }} {{ output }}{{ output_suffix }}) {
2
+ const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
3
+ {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
4
+ const idx_t dims[{{ rank }}] = { {% for dim in input_shape %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
5
+ int axis = {% if axis_literal is not none %}{{ axis_literal }}{% else %}(int){{ axis_input }}[0]{% endif %};
6
+ if (axis < 0) {
7
+ axis += {{ rank }};
8
+ }
9
+ if (axis < 0 || axis >= {{ rank }}) {
10
+ return;
11
+ }
12
+ idx_t outer = 1;
13
+ for (int i = 0; i < axis; ++i) {
14
+ outer *= dims[i];
15
+ }
16
+ idx_t inner = 1;
17
+ for (int i = axis + 1; i < {{ rank }}; ++i) {
18
+ inner *= dims[i];
19
+ }
20
+ idx_t axis_dim = dims[axis];
21
+ for (idx_t outer_index = 0; outer_index < outer; ++outer_index) {
22
+ for (idx_t inner_index = 0; inner_index < inner; ++inner_index) {
23
+ {{ c_type }} acc = ({{ c_type }})0;
24
+ idx_t base = (outer_index * axis_dim * inner) + inner_index;
+ {# exclusive stores the running sum before adding the current element; reverse walks the axis from its last index. #}
25
+ {% if reverse %}
26
+ for (idx_t axis_offset = 0; axis_offset < axis_dim; ++axis_offset) {
27
+ idx_t axis_index = axis_dim - 1 - axis_offset;
28
+ idx_t offset = base + axis_index * inner;
29
+ {% if exclusive %}
30
+ output_data[offset] = acc;
31
+ acc += input_data[offset];
32
+ {% else %}
33
+ acc += input_data[offset];
34
+ output_data[offset] = acc;
35
+ {% endif %}
36
+ }
37
+ {% else %}
38
+ for (idx_t axis_index = 0; axis_index < axis_dim; ++axis_index) {
39
+ idx_t offset = base + axis_index * inner;
40
+ {% if exclusive %}
41
+ output_data[offset] = acc;
42
+ acc += input_data[offset];
43
+ {% else %}
44
+ acc += input_data[offset];
45
+ output_data[offset] = acc;
46
+ {% endif %}
47
+ }
48
+ {% endif %}
49
+ }
50
+ }
51
+ }
@@ -0,0 +1,26 @@
+ {# DepthToSpace on data indexed as ((n * C + c) * H + h) * W + w, writing the output sequentially in row-major order. #}
+ {# Each output pixel maps back to an input channel built from its in-block offsets (offset_h, offset_w). mode "DCR" uses a depth-major channel order; any other mode (CRD) uses a channel-major order. #}
1
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
2
+ const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
3
+ {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
4
+ idx_t output_index = 0;
5
+ for (idx_t n = 0; n < {{ batch }}; ++n) {
6
+ for (idx_t c_out = 0; c_out < {{ out_channels }}; ++c_out) {
7
+ for (idx_t h_out = 0; h_out < {{ out_h }}; ++h_out) {
8
+ idx_t h_in = h_out / {{ blocksize }};
9
+ idx_t offset_h = h_out % {{ blocksize }};
10
+ for (idx_t w_out = 0; w_out < {{ out_w }}; ++w_out) {
11
+ idx_t w_in = w_out / {{ blocksize }};
12
+ idx_t offset_w = w_out % {{ blocksize }};
13
+ idx_t c_in;
14
+ {% if mode == "DCR" %}
15
+ c_in = (offset_h * {{ blocksize }} + offset_w) * {{ out_channels }} + c_out;
16
+ {% else %}
17
+ c_in = (c_out * {{ blocksize }} + offset_h) * {{ blocksize }} + offset_w;
18
+ {% endif %}
19
+ idx_t input_index = ((n * {{ in_channels }} + c_in) * {{ in_h }} + h_in) * {{ in_w }} + w_in;
20
+ output_data[output_index] = input_data[input_index];
21
+ output_index++;
22
+ }
23
+ }
24
+ }
25
+ }
26
+ }