emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for dim in output_shape %}
|
|
3
|
+
for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
{{ input_c_type }} best = {{ input0 }}{{ init_index_expr }};
|
|
6
|
+
{{ output_c_type }} best_index = 0;
|
|
7
|
+
for (idx_t {{ reduce_var }} = 1; {{ reduce_var }} < {{ reduce_dim }}; ++{{ reduce_var }}) {
|
|
8
|
+
{{ input_c_type }} candidate = {{ input0 }}{{ input_index_expr }};
|
|
9
|
+
if (candidate {{ compare_op }} best) {
|
|
10
|
+
best = candidate;
|
|
11
|
+
best_index = ({{ output_c_type }}){{ reduce_var }};
|
|
12
|
+
}
|
|
13
|
+
}
|
|
14
|
+
{{ output }}{{ output_index_expr }} = best_index;
|
|
15
|
+
{% for _ in output_shape %}
|
|
16
|
+
}
|
|
17
|
+
{% endfor %}
|
|
18
|
+
}
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
const {{ c_type }} scale = {{ scale_literal }};
|
|
3
|
+
const {{ c_type }} softcap = {{ softcap_literal }};
|
|
4
|
+
{% if output_present_key %}
|
|
5
|
+
for (int b = 0; b < {{ batch }}; ++b) {
|
|
6
|
+
for (int h = 0; h < {{ kv_heads }}; ++h) {
|
|
7
|
+
for (int ki = 0; ki < {{ total_seq }}; ++ki) {
|
|
8
|
+
if (ki < {{ past_seq }}) {
|
|
9
|
+
for (int d = 0; d < {{ qk_head_size }}; ++d) {
|
|
10
|
+
{{ output_present_key }}[b][h][ki][d] = {{ input_past_key }}[b][h][ki][d];
|
|
11
|
+
}
|
|
12
|
+
} else {
|
|
13
|
+
int k_index = ki - {{ past_seq }};
|
|
14
|
+
for (int d = 0; d < {{ qk_head_size }}; ++d) {
|
|
15
|
+
{% if k_rank == 4 %}
|
|
16
|
+
{{ output_present_key }}[b][h][ki][d] = {{ input_k }}[b][h][k_index][d];
|
|
17
|
+
{% else %}
|
|
18
|
+
{{ output_present_key }}[b][h][ki][d] = {{ input_k }}[b][k_index][h * {{ qk_head_size }} + d];
|
|
19
|
+
{% endif %}
|
|
20
|
+
}
|
|
21
|
+
}
|
|
22
|
+
}
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
{% endif %}
|
|
26
|
+
{% if output_present_value %}
|
|
27
|
+
for (int b = 0; b < {{ batch }}; ++b) {
|
|
28
|
+
for (int h = 0; h < {{ kv_heads }}; ++h) {
|
|
29
|
+
for (int ki = 0; ki < {{ total_seq }}; ++ki) {
|
|
30
|
+
if (ki < {{ past_seq }}) {
|
|
31
|
+
for (int d = 0; d < {{ v_head_size }}; ++d) {
|
|
32
|
+
{{ output_present_value }}[b][h][ki][d] = {{ input_past_value }}[b][h][ki][d];
|
|
33
|
+
}
|
|
34
|
+
} else {
|
|
35
|
+
int v_index = ki - {{ past_seq }};
|
|
36
|
+
for (int d = 0; d < {{ v_head_size }}; ++d) {
|
|
37
|
+
{% if v_rank == 4 %}
|
|
38
|
+
{{ output_present_value }}[b][h][ki][d] = {{ input_v }}[b][h][v_index][d];
|
|
39
|
+
{% else %}
|
|
40
|
+
{{ output_present_value }}[b][h][ki][d] = {{ input_v }}[b][v_index][h * {{ v_head_size }} + d];
|
|
41
|
+
{% endif %}
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
{% endif %}
|
|
48
|
+
for (int b = 0; b < {{ batch }}; ++b) {
|
|
49
|
+
for (int h = 0; h < {{ q_heads }}; ++h) {
|
|
50
|
+
int kv_head = h / {{ head_group_size }};
|
|
51
|
+
for (int qi = 0; qi < {{ q_seq }}; ++qi) {
|
|
52
|
+
{{ c_type }} scores[{{ total_seq }}];
|
|
53
|
+
{{ c_type }} max_score = {{ min_literal }};
|
|
54
|
+
for (int ki = 0; ki < {{ total_seq }}; ++ki) {
|
|
55
|
+
{{ c_type }} score = {{ zero_literal }};
|
|
56
|
+
int k_index = ki - {{ past_seq }};
|
|
57
|
+
for (int d = 0; d < {{ qk_head_size }}; ++d) {
|
|
58
|
+
{{ c_type }} q_val;
|
|
59
|
+
{{ c_type }} k_val;
|
|
60
|
+
{% if q_rank == 4 %}
|
|
61
|
+
q_val = {{ input_q }}[b][h][qi][d];
|
|
62
|
+
{% else %}
|
|
63
|
+
q_val = {{ input_q }}[b][qi][h * {{ qk_head_size }} + d];
|
|
64
|
+
{% endif %}
|
|
65
|
+
{% if input_past_key %}
|
|
66
|
+
if (ki < {{ past_seq }}) {
|
|
67
|
+
k_val = {{ input_past_key }}[b][kv_head][ki][d];
|
|
68
|
+
} else {
|
|
69
|
+
{% if k_rank == 4 %}
|
|
70
|
+
k_val = {{ input_k }}[b][kv_head][k_index][d];
|
|
71
|
+
{% else %}
|
|
72
|
+
k_val = {{ input_k }}[b][k_index][kv_head * {{ qk_head_size }} + d];
|
|
73
|
+
{% endif %}
|
|
74
|
+
}
|
|
75
|
+
{% else %}
|
|
76
|
+
{% if k_rank == 4 %}
|
|
77
|
+
k_val = {{ input_k }}[b][kv_head][k_index][d];
|
|
78
|
+
{% else %}
|
|
79
|
+
k_val = {{ input_k }}[b][k_index][kv_head * {{ qk_head_size }} + d];
|
|
80
|
+
{% endif %}
|
|
81
|
+
{% endif %}
|
|
82
|
+
score += q_val * k_val;
|
|
83
|
+
}
|
|
84
|
+
score *= scale;
|
|
85
|
+
{{ c_type }} bias = {{ zero_literal }};
|
|
86
|
+
{% if input_attn_mask %}
|
|
87
|
+
if (ki >= {{ mask_kv_seq }}) {
|
|
88
|
+
bias = {{ min_literal }};
|
|
89
|
+
} else {
|
|
90
|
+
int mask_q = {{ '0' if mask_broadcast_q_seq else 'qi' }};
|
|
91
|
+
{% if mask_rank == 2 %}
|
|
92
|
+
{% if mask_is_bool %}
|
|
93
|
+
bias = {{ input_attn_mask }}[mask_q][ki] ? {{ zero_literal }} : {{ min_literal }};
|
|
94
|
+
{% else %}
|
|
95
|
+
bias = {{ input_attn_mask }}[mask_q][ki];
|
|
96
|
+
{% endif %}
|
|
97
|
+
{% elif mask_rank == 3 %}
|
|
98
|
+
int mask_b = {{ '0' if mask_broadcast_batch else 'b' }};
|
|
99
|
+
{% if mask_is_bool %}
|
|
100
|
+
bias = {{ input_attn_mask }}[mask_b][mask_q][ki] ? {{ zero_literal }} : {{ min_literal }};
|
|
101
|
+
{% else %}
|
|
102
|
+
bias = {{ input_attn_mask }}[mask_b][mask_q][ki];
|
|
103
|
+
{% endif %}
|
|
104
|
+
{% else %}
|
|
105
|
+
int mask_b = {{ '0' if mask_broadcast_batch else 'b' }};
|
|
106
|
+
int mask_h = {{ '0' if mask_broadcast_heads else 'h' }};
|
|
107
|
+
{% if mask_is_bool %}
|
|
108
|
+
bias = {{ input_attn_mask }}[mask_b][mask_h][mask_q][ki] ? {{ zero_literal }} : {{ min_literal }};
|
|
109
|
+
{% else %}
|
|
110
|
+
bias = {{ input_attn_mask }}[mask_b][mask_h][mask_q][ki];
|
|
111
|
+
{% endif %}
|
|
112
|
+
{% endif %}
|
|
113
|
+
}
|
|
114
|
+
{% endif %}
|
|
115
|
+
{% if input_nonpad_kv_seqlen %}
|
|
116
|
+
if (ki >= {{ input_nonpad_kv_seqlen }}[b]) {
|
|
117
|
+
bias = {{ min_literal }};
|
|
118
|
+
}
|
|
119
|
+
{% endif %}
|
|
120
|
+
if ({{ is_causal }} && ki > qi + {{ past_seq }}) {
|
|
121
|
+
bias = {{ min_literal }};
|
|
122
|
+
}
|
|
123
|
+
{{ c_type }} score_bias = score + bias;
|
|
124
|
+
{{ c_type }} score_softcap = score_bias;
|
|
125
|
+
if (softcap != {{ zero_literal }}) {
|
|
126
|
+
score_softcap = softcap * {{ tanh_fn }}(score_bias / softcap);
|
|
127
|
+
}
|
|
128
|
+
{% if output_qk_matmul %}
|
|
129
|
+
if ({{ qk_matmul_output_mode }} == 0) {
|
|
130
|
+
{{ output_qk_matmul }}[b][h][qi][ki] = score;
|
|
131
|
+
} else if ({{ qk_matmul_output_mode }} == 1) {
|
|
132
|
+
{{ output_qk_matmul }}[b][h][qi][ki] = score_bias;
|
|
133
|
+
} else if ({{ qk_matmul_output_mode }} == 2) {
|
|
134
|
+
{{ output_qk_matmul }}[b][h][qi][ki] = score_softcap;
|
|
135
|
+
}
|
|
136
|
+
{% endif %}
|
|
137
|
+
scores[ki] = score_softcap;
|
|
138
|
+
max_score = {{ max_fn }}(max_score, score_softcap);
|
|
139
|
+
}
|
|
140
|
+
{{ c_type }} weights[{{ total_seq }}];
|
|
141
|
+
{{ c_type }} sum = {{ zero_literal }};
|
|
142
|
+
for (int ki = 0; ki < {{ total_seq }}; ++ki) {
|
|
143
|
+
{{ c_type }} weight = {{ zero_literal }};
|
|
144
|
+
if (max_score != {{ min_literal }}) {
|
|
145
|
+
weight = {{ exp_fn }}(scores[ki] - max_score);
|
|
146
|
+
}
|
|
147
|
+
weights[ki] = weight;
|
|
148
|
+
sum += weight;
|
|
149
|
+
}
|
|
150
|
+
{{ c_type }} inv_sum = sum == {{ zero_literal }} ? {{ zero_literal }} : ({{ c_type }}){{ one_literal }} / sum;
|
|
151
|
+
{% if output_qk_matmul %}
|
|
152
|
+
if ({{ qk_matmul_output_mode }} == 3) {
|
|
153
|
+
for (int ki = 0; ki < {{ total_seq }}; ++ki) {
|
|
154
|
+
{{ output_qk_matmul }}[b][h][qi][ki] = weights[ki] * inv_sum;
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
{% endif %}
|
|
158
|
+
for (int vd = 0; vd < {{ v_head_size }}; ++vd) {
|
|
159
|
+
{{ c_type }} acc = {{ zero_literal }};
|
|
160
|
+
for (int ki = 0; ki < {{ total_seq }}; ++ki) {
|
|
161
|
+
{{ c_type }} weight = weights[ki] * inv_sum;
|
|
162
|
+
{% if input_past_value %}
|
|
163
|
+
if (ki < {{ past_seq }}) {
|
|
164
|
+
acc += weight * {{ input_past_value }}[b][kv_head][ki][vd];
|
|
165
|
+
} else {
|
|
166
|
+
{% if v_rank == 4 %}
|
|
167
|
+
acc += weight * {{ input_v }}[b][kv_head][ki - {{ past_seq }}][vd];
|
|
168
|
+
{% else %}
|
|
169
|
+
acc += weight * {{ input_v }}[b][ki - {{ past_seq }}][kv_head * {{ v_head_size }} + vd];
|
|
170
|
+
{% endif %}
|
|
171
|
+
}
|
|
172
|
+
{% else %}
|
|
173
|
+
{% if v_rank == 4 %}
|
|
174
|
+
acc += weight * {{ input_v }}[b][kv_head][ki][vd];
|
|
175
|
+
{% else %}
|
|
176
|
+
acc += weight * {{ input_v }}[b][ki][kv_head * {{ v_head_size }} + vd];
|
|
177
|
+
{% endif %}
|
|
178
|
+
{% endif %}
|
|
179
|
+
}
|
|
180
|
+
{% if output_rank == 4 %}
|
|
181
|
+
{{ output }}[b][h][qi][vd] = acc;
|
|
182
|
+
{% else %}
|
|
183
|
+
{{ output }}[b][qi][h * {{ v_head_size }} + vd] = acc;
|
|
184
|
+
{% endif %}
|
|
185
|
+
}
|
|
186
|
+
}
|
|
187
|
+
}
|
|
188
|
+
}
|
|
189
|
+
}
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% if spatial_rank == 3 %}
|
|
3
|
+
for (idx_t n = 0; n < {{ batch }}; ++n) {
|
|
4
|
+
for (idx_t c = 0; c < {{ channels }}; ++c) {
|
|
5
|
+
for (idx_t od = 0; od < {{ out_d }}; ++od) {
|
|
6
|
+
for (idx_t oh = 0; oh < {{ out_h }}; ++oh) {
|
|
7
|
+
for (idx_t ow = 0; ow < {{ out_w }}; ++ow) {
|
|
8
|
+
{{ c_type }} values[{{ kernel_d * kernel_h * kernel_w }}];
|
|
9
|
+
idx_t count = 0;
|
|
10
|
+
for (idx_t kd = 0; kd < {{ kernel_d }}; ++kd) {
|
|
11
|
+
const idx_t id = od * {{ stride_d }} + kd * {{ dilation_d }} - {{ pad_front }};
|
|
12
|
+
for (idx_t kh = 0; kh < {{ kernel_h }}; ++kh) {
|
|
13
|
+
const idx_t ih = oh * {{ stride_h }} + kh * {{ dilation_h }} - {{ pad_top }};
|
|
14
|
+
for (idx_t kw = 0; kw < {{ kernel_w }}; ++kw) {
|
|
15
|
+
const idx_t iw = ow * {{ stride_w }} + kw * {{ dilation_w }} - {{ pad_left }};
|
|
16
|
+
if (id >= 0 && id < {{ in_d }}
|
|
17
|
+
&& ih >= 0 && ih < {{ in_h }}
|
|
18
|
+
&& iw >= 0 && iw < {{ in_w }}) {
|
|
19
|
+
values[count++] = {{ input0 }}[n][c][id][ih][iw];
|
|
20
|
+
} else if ({{ count_include_pad }}) {
|
|
21
|
+
values[count++] = {{ zero_literal }};
|
|
22
|
+
}
|
|
23
|
+
}
|
|
24
|
+
}
|
|
25
|
+
}
|
|
26
|
+
const idx_t denom = {{ count_include_pad }}
|
|
27
|
+
? {{ kernel_d * kernel_h * kernel_w }}
|
|
28
|
+
: count;
|
|
29
|
+
{{ c_type }} acc = {{ zero_literal }};
|
|
30
|
+
if (count > 0) {
|
|
31
|
+
idx_t reduce_count = count;
|
|
32
|
+
while (reduce_count > 1) {
|
|
33
|
+
idx_t next = 0;
|
|
34
|
+
for (idx_t i = 0; i + 1 < reduce_count; i += 2) {
|
|
35
|
+
values[next++] = values[i] + values[i + 1];
|
|
36
|
+
}
|
|
37
|
+
if (reduce_count % 2) {
|
|
38
|
+
values[next++] = values[reduce_count - 1];
|
|
39
|
+
}
|
|
40
|
+
reduce_count = next;
|
|
41
|
+
}
|
|
42
|
+
acc = values[0];
|
|
43
|
+
}
|
|
44
|
+
{{ output }}[n][c][od][oh][ow] = denom ? acc / ({{ c_type }})denom : {{ zero_literal }};
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
}
|
|
50
|
+
{% elif spatial_rank == 1 %}
|
|
51
|
+
for (idx_t n = 0; n < {{ batch }}; ++n) {
|
|
52
|
+
for (idx_t c = 0; c < {{ channels }}; ++c) {
|
|
53
|
+
for (idx_t ow = 0; ow < {{ out_w }}; ++ow) {
|
|
54
|
+
{{ c_type }} values[{{ kernel_w }}];
|
|
55
|
+
idx_t count = 0;
|
|
56
|
+
for (idx_t kw = 0; kw < {{ kernel_w }}; ++kw) {
|
|
57
|
+
const idx_t iw = ow * {{ stride_w }} + kw * {{ dilation_w }} - {{ pad_left }};
|
|
58
|
+
if (iw >= 0 && iw < {{ in_w }}) {
|
|
59
|
+
values[count++] = {{ input0 }}[n][c][iw];
|
|
60
|
+
} else if ({{ count_include_pad }}) {
|
|
61
|
+
values[count++] = {{ zero_literal }};
|
|
62
|
+
}
|
|
63
|
+
}
|
|
64
|
+
const idx_t denom = {{ count_include_pad }} ? {{ kernel_w }} : count;
|
|
65
|
+
{{ c_type }} acc = {{ zero_literal }};
|
|
66
|
+
if (count > 0) {
|
|
67
|
+
idx_t reduce_count = count;
|
|
68
|
+
while (reduce_count > 1) {
|
|
69
|
+
idx_t next = 0;
|
|
70
|
+
for (idx_t i = 0; i + 1 < reduce_count; i += 2) {
|
|
71
|
+
values[next++] = values[i] + values[i + 1];
|
|
72
|
+
}
|
|
73
|
+
if (reduce_count % 2) {
|
|
74
|
+
values[next++] = values[reduce_count - 1];
|
|
75
|
+
}
|
|
76
|
+
reduce_count = next;
|
|
77
|
+
}
|
|
78
|
+
acc = values[0];
|
|
79
|
+
}
|
|
80
|
+
{{ output }}[n][c][ow] = denom ? acc / ({{ c_type }})denom : {{ zero_literal }};
|
|
81
|
+
}
|
|
82
|
+
}
|
|
83
|
+
}
|
|
84
|
+
{% else %}
|
|
85
|
+
for (idx_t n = 0; n < {{ batch }}; ++n) {
|
|
86
|
+
for (idx_t c = 0; c < {{ channels }}; ++c) {
|
|
87
|
+
for (idx_t oh = 0; oh < {{ out_h }}; ++oh) {
|
|
88
|
+
for (idx_t ow = 0; ow < {{ out_w }}; ++ow) {
|
|
89
|
+
{{ c_type }} values[{{ kernel_h * kernel_w }}];
|
|
90
|
+
idx_t count = 0;
|
|
91
|
+
for (idx_t kh = 0; kh < {{ kernel_h }}; ++kh) {
|
|
92
|
+
const idx_t ih = oh * {{ stride_h }} + kh * {{ dilation_h }} - {{ pad_top }};
|
|
93
|
+
for (idx_t kw = 0; kw < {{ kernel_w }}; ++kw) {
|
|
94
|
+
const idx_t iw = ow * {{ stride_w }} + kw * {{ dilation_w }} - {{ pad_left }};
|
|
95
|
+
if (ih >= 0 && ih < {{ in_h }} && iw >= 0 && iw < {{ in_w }}) {
|
|
96
|
+
values[count++] = {{ input0 }}[n][c][ih][iw];
|
|
97
|
+
} else if ({{ count_include_pad }}) {
|
|
98
|
+
values[count++] = {{ zero_literal }};
|
|
99
|
+
}
|
|
100
|
+
}
|
|
101
|
+
}
|
|
102
|
+
const idx_t denom = {{ count_include_pad }}
|
|
103
|
+
? {{ kernel_h * kernel_w }}
|
|
104
|
+
: count;
|
|
105
|
+
{{ c_type }} acc = {{ zero_literal }};
|
|
106
|
+
if (count > 0) {
|
|
107
|
+
idx_t reduce_count = count;
|
|
108
|
+
while (reduce_count > 1) {
|
|
109
|
+
idx_t next = 0;
|
|
110
|
+
for (idx_t i = 0; i + 1 < reduce_count; i += 2) {
|
|
111
|
+
values[next++] = values[i] + values[i + 1];
|
|
112
|
+
}
|
|
113
|
+
if (reduce_count % 2) {
|
|
114
|
+
values[next++] = values[reduce_count - 1];
|
|
115
|
+
}
|
|
116
|
+
reduce_count = next;
|
|
117
|
+
}
|
|
118
|
+
acc = values[0];
|
|
119
|
+
}
|
|
120
|
+
{{ output }}[n][c][oh][ow] = denom ? acc / ({{ c_type }})denom : {{ zero_literal }};
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
}
|
|
124
|
+
}
|
|
125
|
+
{% endif %}
|
|
126
|
+
}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for dim in shape %}
|
|
3
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
idx_t c = {{ loop_vars[1] }};
|
|
6
|
+
{{ c_type }} denom = {{ sqrt_fn }}({{ variance }}[c] + {{ epsilon_literal }});
|
|
7
|
+
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %} - {{ mean }}[c]) / denom * {{ scale }}[c] + {{ bias }}[c];
|
|
8
|
+
{% for _ in shape %}
|
|
9
|
+
}
|
|
10
|
+
{% endfor %}
|
|
11
|
+
}
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
static inline uint64_t {{ op_name }}_rng_next_u64(uint64_t *state) {
|
|
2
|
+
uint64_t x = *state;
|
|
3
|
+
x ^= x >> 12;
|
|
4
|
+
x ^= x << 25;
|
|
5
|
+
x ^= x >> 27;
|
|
6
|
+
*state = x;
|
|
7
|
+
return x * 0x2545f4914f6cdd1dull;
|
|
8
|
+
}
|
|
9
|
+
|
|
10
|
+
static inline double {{ op_name }}_rng_next_double(uint64_t *state) {
|
|
11
|
+
return (double){{ op_name }}_rng_next_u64(state) * (1.0 / 18446744073709551616.0);
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
15
|
+
uint64_t rng_state = {{ seed }}ull;
|
|
16
|
+
if (!rng_state) {
|
|
17
|
+
rng_state = 0x243f6a8885a308d3ull;
|
|
18
|
+
}
|
|
19
|
+
{% for dim in shape %}
|
|
20
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
21
|
+
{% endfor %}
|
|
22
|
+
const double prob = (double){{ input0 }}{{ input_index_expr }};
|
|
23
|
+
int is_one = 0;
|
|
24
|
+
if (prob >= 1.0) {
|
|
25
|
+
is_one = 1;
|
|
26
|
+
} else if (prob > 0.0) {
|
|
27
|
+
const double sample = {{ op_name }}_rng_next_double(&rng_state);
|
|
28
|
+
is_one = sample < prob;
|
|
29
|
+
}
|
|
30
|
+
{{ output }}{{ output_index_expr }} = is_one ? {{ one_literal }} : {{ zero_literal }};
|
|
31
|
+
{% for _ in shape %}
|
|
32
|
+
}
|
|
33
|
+
{% endfor %}
|
|
34
|
+
}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for dim in shape %}
|
|
3
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {% if operator_kind == "func" %}{{ operator }}({{ left_expr }}, {{ right_expr }}){% elif operator_kind == "expr" %}{{ operator_expr }}{% else %}{{ left_expr }} {{ operator }} {{ right_expr }}{% endif %};
|
|
6
|
+
{% for _ in shape %}
|
|
7
|
+
}
|
|
8
|
+
{% endfor %}
|
|
9
|
+
}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for dim in shape %}
|
|
3
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ output_c_type }}){{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
|
|
6
|
+
{% for _ in shape %}
|
|
7
|
+
}
|
|
8
|
+
{% endfor %}
|
|
9
|
+
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for dim in shape %}
|
|
3
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
{{ output_c_type }} value = {{ input_expr }};
|
|
6
|
+
const {{ output_c_type }} min_value = {{ min_expr }};
|
|
7
|
+
const {{ output_c_type }} max_value = {{ max_expr }};
|
|
8
|
+
value = {{ max_fn }}(value, min_value);
|
|
9
|
+
value = {{ min_fn }}(value, max_value);
|
|
10
|
+
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = value;
|
|
11
|
+
{% for _ in shape %}
|
|
12
|
+
}
|
|
13
|
+
{% endfor %}
|
|
14
|
+
}
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
const void *inputs[] = { {% for input in inputs %}{{ input }}{% if not loop.last %}, {% endif %}{% endfor %} };
|
|
3
|
+
const idx_t axis_sizes[] = { {% for axis in axis_sizes %}{{ axis }}{% if not loop.last %}, {% endif %}{% endfor %} };
|
|
4
|
+
idx_t concat_axis = 0;
|
|
5
|
+
for (idx_t idx = 0; idx < {{ input_count }}; ++idx) {
|
|
6
|
+
concat_axis += axis_sizes[idx];
|
|
7
|
+
}
|
|
8
|
+
if (concat_axis == 0) {
|
|
9
|
+
return;
|
|
10
|
+
}
|
|
11
|
+
for (idx_t outer_idx = 0; outer_idx < {{ outer }}; ++outer_idx) {
|
|
12
|
+
idx_t output_offset = outer_idx * concat_axis * {{ inner }};
|
|
13
|
+
idx_t axis_offset = 0;
|
|
14
|
+
for (idx_t input_idx = 0; input_idx < {{ input_count }}; ++input_idx) {
|
|
15
|
+
idx_t axis = axis_sizes[input_idx];
|
|
16
|
+
idx_t copy_elems = axis * {{ inner }};
|
|
17
|
+
const unsigned char *input_bytes =
|
|
18
|
+
(const unsigned char *)inputs[input_idx];
|
|
19
|
+
idx_t input_offset = outer_idx * copy_elems;
|
|
20
|
+
memcpy(
|
|
21
|
+
((unsigned char *){{ output }}) + (output_offset + axis_offset) * sizeof({{ c_type }}),
|
|
22
|
+
input_bytes + input_offset * sizeof({{ c_type }}),
|
|
23
|
+
copy_elems * sizeof({{ c_type }})
|
|
24
|
+
);
|
|
25
|
+
axis_offset += copy_elems;
|
|
26
|
+
}
|
|
27
|
+
}
|
|
28
|
+
}
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
(void){{ input0 }};
|
|
3
|
+
{% for dim in shape %}
|
|
4
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
5
|
+
{% endfor %}
|
|
6
|
+
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ value_literal }};
|
|
7
|
+
{% for _ in shape %}
|
|
8
|
+
}
|
|
9
|
+
{% endfor %}
|
|
10
|
+
}
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
for (idx_t n = 0; n < {{ batch }}; ++n) {
|
|
3
|
+
for (idx_t g = 0; g < {{ group }}; ++g) {
|
|
4
|
+
for (idx_t oc = 0; oc < {{ group_out_channels }}; ++oc) {
|
|
5
|
+
const idx_t oc_global = g * {{ group_out_channels }} + oc;
|
|
6
|
+
{% for dim in range(spatial_rank) %}
|
|
7
|
+
for (idx_t {{ out_indices[dim] }} = 0; {{ out_indices[dim] }} < {{ out_spatial[dim] }}; ++{{ out_indices[dim] }}) {
|
|
8
|
+
{% endfor %}
|
|
9
|
+
{{ acc_type }} acc = {{ acc_zero_literal }};
|
|
10
|
+
for (idx_t ic = 0; ic < {{ group_in_channels }}; ++ic) {
|
|
11
|
+
const idx_t ic_global = g * {{ group_in_channels }} + ic;
|
|
12
|
+
{% for dim in range(spatial_rank) %}
|
|
13
|
+
for (idx_t {{ kernel_indices[dim] }} = 0; {{ kernel_indices[dim] }} < {{ kernel_shape[dim] }}; ++{{ kernel_indices[dim] }}) {
|
|
14
|
+
{% endfor %}
|
|
15
|
+
{% for dim in range(spatial_rank) %}
|
|
16
|
+
const idx_t {{ in_indices[dim] }} = {{ out_indices[dim] }} * {{ strides[dim] }} + {{ kernel_indices[dim] }} * {{ dilations[dim] }} - {{ pads_begin[dim] }};
|
|
17
|
+
{% endfor %}
|
|
18
|
+
if ({% for dim in range(spatial_rank) %}{{ in_indices[dim] }} >= 0 && {{ in_indices[dim] }} < {{ in_spatial[dim] }}{% if not loop.last %} && {% endif %}{% endfor %}) {
|
|
19
|
+
{{ acc_type }} input_value = ({{ acc_type }}){{ input0 }}[n][ic_global]{% for dim in range(spatial_rank) %}[{{ in_indices[dim] }}]{% endfor %} - ({{ acc_type }})({{ x_zero_expr }});
|
|
20
|
+
{{ acc_type }} weight_value = ({{ acc_type }}){{ weights }}[oc_global][ic]{% for dim in range(spatial_rank) %}[{{ kernel_indices[dim] }}]{% endfor %} - ({{ acc_type }})({{ w_zero_expr }});
|
|
21
|
+
acc += input_value * weight_value;
|
|
22
|
+
}
|
|
23
|
+
{% for dim in range(spatial_rank) %}
|
|
24
|
+
}
|
|
25
|
+
{% endfor %}
|
|
26
|
+
}
|
|
27
|
+
{{ output }}[n][oc_global]{% for dim in range(spatial_rank) %}[{{ out_indices[dim] }}]{% endfor %} = ({{ c_type }})acc;
|
|
28
|
+
{% for dim in range(spatial_rank) %}
|
|
29
|
+
}
|
|
30
|
+
{% endfor %}
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
}
|
|
34
|
+
}
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
for (idx_t n = 0; n < {{ batch }}; ++n) {
|
|
3
|
+
for (idx_t g = 0; g < {{ group }}; ++g) {
|
|
4
|
+
for (idx_t oc = 0; oc < {{ group_out_channels }}; ++oc) {
|
|
5
|
+
const idx_t oc_global = g * {{ group_out_channels }} + oc;
|
|
6
|
+
{% for dim in range(spatial_rank) %}
|
|
7
|
+
for (idx_t {{ out_indices[dim] }} = 0; {{ out_indices[dim] }} < {{ out_spatial[dim] }}; ++{{ out_indices[dim] }}) {
|
|
8
|
+
{% endfor %}
|
|
9
|
+
{{ acc_type }} acc = {% if bias %}({{ acc_type }}){{ bias }}[oc_global]{% else %}{{ acc_zero_literal }}{% endif %};
|
|
10
|
+
for (idx_t ic = 0; ic < {{ group_in_channels }}; ++ic) {
|
|
11
|
+
const idx_t ic_global = g * {{ group_in_channels }} + ic;
|
|
12
|
+
{% for dim in range(spatial_rank) %}
|
|
13
|
+
for (idx_t {{ kernel_indices[dim] }} = 0; {{ kernel_indices[dim] }} < {{ kernel_shape[dim] }}; ++{{ kernel_indices[dim] }}) {
|
|
14
|
+
{% endfor %}
|
|
15
|
+
{% for dim in range(spatial_rank) %}
|
|
16
|
+
const idx_t {{ in_indices[dim] }} = {{ out_indices[dim] }} * {{ strides[dim] }} + {{ kernel_indices[dim] }} * {{ dilations[dim] }} - {{ pads_begin[dim] }};
|
|
17
|
+
{% endfor %}
|
|
18
|
+
if ({% for dim in range(spatial_rank) %}{{ in_indices[dim] }} >= 0 && {{ in_indices[dim] }} < {{ in_spatial[dim] }}{% if not loop.last %} && {% endif %}{% endfor %}) {
|
|
19
|
+
acc += ({{ acc_type }}){{ input0 }}[n][ic_global]{% for dim in range(spatial_rank) %}[{{ in_indices[dim] }}]{% endfor %} * ({{ acc_type }}){{ weights }}[oc_global][ic]{% for dim in range(spatial_rank) %}[{{ kernel_indices[dim] }}]{% endfor %};
|
|
20
|
+
}
|
|
21
|
+
{% for dim in range(spatial_rank) %}
|
|
22
|
+
}
|
|
23
|
+
{% endfor %}
|
|
24
|
+
}
|
|
25
|
+
{{ output }}[n][oc_global]{% for dim in range(spatial_rank) %}[{{ out_indices[dim] }}]{% endfor %} = ({{ c_type }})acc;
|
|
26
|
+
{% for dim in range(spatial_rank) %}
|
|
27
|
+
}
|
|
28
|
+
{% endfor %}
|
|
29
|
+
}
|
|
30
|
+
}
|
|
31
|
+
}
|
|
32
|
+
}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
{# Transposed N-D convolution kernel (grouped), implemented as a two-pass
   scatter: pass 1 initializes every output element to the bias (or zero);
   pass 2 walks the INPUT positions and kernel taps and scatters each
   per-(tap, position) channel-reduced product into the output it lands on.
   NOTE(review): weight layout [C_in][C_out/group][k...] matches ONNX
   ConvTranspose — confirm against the lowering pass. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{# Pass 1: seed the whole output tensor so pass 2 can use "+=". #}
for (idx_t n = 0; n < {{ batch }}; ++n) {
for (idx_t oc = 0; oc < {{ out_channels }}; ++oc) {
{% for dim in range(spatial_rank) %}
for (idx_t {{ out_indices[dim] }} = 0; {{ out_indices[dim] }} < {{ out_spatial[dim] }}; ++{{ out_indices[dim] }}) {
{% endfor %}
{{ output }}[n][oc]{% for dim in range(spatial_rank) %}[{{ out_indices[dim] }}]{% endfor %} = {% if bias %}{{ bias }}[oc]{% else %}{{ zero_literal }}{% endif %};
{% for dim in range(spatial_rank) %}
}
{% endfor %}
}
}
{# Pass 2: scatter-accumulate contributions from every input position. #}
for (idx_t n = 0; n < {{ batch }}; ++n) {
for (idx_t g = 0; g < {{ group }}; ++g) {
for (idx_t oc = 0; oc < {{ group_out_channels }}; ++oc) {
{# Absolute output-channel index across all groups. #}
const idx_t oc_global = g * {{ group_out_channels }} + oc;
{# Kernel-tap loops, one per spatial dimension (unrolled at template time). #}
{% for dim in range(spatial_rank) %}
for (idx_t {{ kernel_indices[dim] }} = 0; {{ kernel_indices[dim] }} < {{ kernel_shape[dim] }}; ++{{ kernel_indices[dim] }}) {
{% endfor %}
{# Input spatial loops. #}
{% for dim in range(spatial_rank) %}
for (idx_t {{ in_indices[dim] }} = 0; {{ in_indices[dim] }} < {{ in_spatial[dim] }}; ++{{ in_indices[dim] }}) {
{% endfor %}
{# Reduce over this group's input channels for the fixed (tap, input position). #}
{{ c_type }} acc = {{ zero_literal }};
for (idx_t ic = 0; ic < {{ group_in_channels }}; ++ic) {
const idx_t ic_global = g * {{ group_in_channels }} + ic;
{# Transposed-conv weight indexing: [input channel][within-group output channel][tap]. #}
acc += {{ input0 }}[n][ic_global]{% for dim in range(spatial_rank) %}[{{ in_indices[dim] }}]{% endfor %} * {{ weights }}[ic_global][oc]{% for dim in range(spatial_rank) %}[{{ kernel_indices[dim] }}]{% endfor %};
}
{# Forward mapping: where this (input position, tap) pair lands in the output. #}
{% for dim in range(spatial_rank) %}
const idx_t {{ out_indices[dim] }} = {{ in_indices[dim] }} * {{ strides[dim] }} + {{ kernel_indices[dim] }} * {{ dilations[dim] }} - {{ pads_begin[dim] }};
{% endfor %}
{# Contributions that fall in the cropped/padded margin are dropped. #}
if ({% for dim in range(spatial_rank) %}{{ out_indices[dim] }} >= 0 && {{ out_indices[dim] }} < {{ out_spatial[dim] }}{% if not loop.last %} && {% endif %}{% endfor %}) {
{{ output }}[n][oc_global]{% for dim in range(spatial_rank) %}[{{ out_indices[dim] }}]{% endfor %} += acc;
}
{# Close the input spatial loops. #}
{% for dim in range(spatial_rank) %}
}
{% endfor %}
{# Close the kernel-tap loops. #}
{% for dim in range(spatial_rank) %}
}
{% endfor %}
}
}
}
}
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
{# Cumulative-sum kernel along one axis. The tensor is viewed as
   (outer, axis_dim, inner); each (outer, inner) pair gets an independent
   running sum over the axis dimension. The `exclusive` and `reverse`
   variants are selected at template-render time, so the emitted C contains
   exactly one of the four scan orderings.
   The axis may be a compile-time literal or read from a scalar input tensor;
   negative values are normalized Python-style by adding rank. #}
static inline void {{ op_name }}({{ dim_args }}const {{ c_type }} {{ input0 }}{{ input_suffix }}{% if axis_input %}, const {{ axis_c_type }} {{ axis_input }}{{ axis_suffix }}{% endif %}, {{ c_type }} {{ output }}{{ output_suffix }}) {
{# Flatten the multidimensional arrays to linear pointers for index math. #}
const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
{{ c_type }} *output_data = ({{ c_type }} *){{ output }};
const idx_t dims[{{ rank }}] = { {% for dim in input_shape %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
int axis = {% if axis_literal is not none %}{{ axis_literal }}{% else %}(int){{ axis_input }}[0]{% endif %};
if (axis < 0) {
axis += {{ rank }};
}
{# Guard: an axis that is still out of range after normalization is a no-op
   (the output is left untouched). #}
if (axis < 0 || axis >= {{ rank }}) {
return;
}
{# outer = product of dims before the axis, inner = product after it. #}
idx_t outer = 1;
for (int i = 0; i < axis; ++i) {
outer *= dims[i];
}
idx_t inner = 1;
for (int i = axis + 1; i < {{ rank }}; ++i) {
inner *= dims[i];
}
idx_t axis_dim = dims[axis];
for (idx_t outer_index = 0; outer_index < outer; ++outer_index) {
for (idx_t inner_index = 0; inner_index < inner; ++inner_index) {
{# Fresh running sum per (outer, inner) lane. #}
{{ c_type }} acc = ({{ c_type }})0;
{# Linear offset of element 0 along the axis for this lane; stepping the
   axis index moves by `inner` elements. #}
idx_t base = (outer_index * axis_dim * inner) + inner_index;
{% if reverse %}
{# Reverse scan: walk the axis from the last element to the first. #}
for (idx_t axis_offset = 0; axis_offset < axis_dim; ++axis_offset) {
idx_t axis_index = axis_dim - 1 - axis_offset;
idx_t offset = base + axis_index * inner;
{% if exclusive %}
{# Exclusive: store the sum of elements strictly after this one. #}
output_data[offset] = acc;
acc += input_data[offset];
{% else %}
{# Inclusive: current element participates in its own output. #}
acc += input_data[offset];
output_data[offset] = acc;
{% endif %}
}
{% else %}
{# Forward scan. #}
for (idx_t axis_index = 0; axis_index < axis_dim; ++axis_index) {
idx_t offset = base + axis_index * inner;
{% if exclusive %}
{# Exclusive: store the sum of elements strictly before this one. #}
output_data[offset] = acc;
acc += input_data[offset];
{% else %}
{# Inclusive: current element participates in its own output. #}
acc += input_data[offset];
output_data[offset] = acc;
{% endif %}
}
{% endif %}
}
}
}
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
{# DepthToSpace kernel (NCHW): moves data from the channel dimension into
   blocksize x blocksize spatial blocks, so out_h = in_h * blocksize,
   out_w = in_w * blocksize, out_channels = in_channels / blocksize^2.
   The output is produced in linear NCHW order (output_index just counts up),
   so only the input index needs explicit arithmetic.
   `mode` selects the ONNX channel ordering: DCR splits channels as
   [block_h, block_w, C_out], CRD as [C_out, block_h, block_w]. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{# Flat views for linear index arithmetic. #}
const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
{{ c_type }} *output_data = ({{ c_type }} *){{ output }};
idx_t output_index = 0;
for (idx_t n = 0; n < {{ batch }}; ++n) {
for (idx_t c_out = 0; c_out < {{ out_channels }}; ++c_out) {
for (idx_t h_out = 0; h_out < {{ out_h }}; ++h_out) {
{# Split each output coordinate into source position and intra-block offset. #}
idx_t h_in = h_out / {{ blocksize }};
idx_t offset_h = h_out % {{ blocksize }};
for (idx_t w_out = 0; w_out < {{ out_w }}; ++w_out) {
idx_t w_in = w_out / {{ blocksize }};
idx_t offset_w = w_out % {{ blocksize }};
idx_t c_in;
{% if mode == "DCR" %}
{# DCR: channels laid out as (block_h, block_w, C_out). #}
c_in = (offset_h * {{ blocksize }} + offset_w) * {{ out_channels }} + c_out;
{% else %}
{# CRD: channels laid out as (C_out, block_h, block_w). #}
c_in = (c_out * {{ blocksize }} + offset_h) * {{ blocksize }} + offset_w;
{% endif %}
{# Row-major NCHW linearization of the source element. #}
idx_t input_index = ((n * {{ in_channels }} + c_in) * {{ in_h }} + h_in) * {{ in_w }} + w_in;
output_data[output_index] = input_data[input_index];
output_index++;
}
}
}
}
}
|