emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
{# Instance-normalization kernel template.

   The two outermost loops walk the first two entries of `shape`. For each
   such pair the remaining axes are visited three times: once to accumulate a
   sum (mean), once to accumulate squared deviations (variance), and once to
   write the normalized value combined with `scale` / `bias`, both of which
   are indexed by the second loop variable only. #}
{% set in_elem %}{{ input0 }}[{{ loop_vars[0] }}][{{ loop_vars[1] }}]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %}{% endset %}
{% set out_elem %}{{ output }}[{{ loop_vars[0] }}][{{ loop_vars[1] }}]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %}{% endset %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dim in shape[:2] %}
{% set lv = loop_vars[loop.index0] %}
    for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
{# Pass 1: sum over the trailing axes, then divide by spatial_size. #}
        {{ c_type }} sum = {{ zero_literal }};
{% for dim in shape[2:] %}
{% set lv = loop_vars[loop.index0 + 2] %}
        for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
            sum += {{ in_elem }};
{% for _ in shape[2:] %}
        }
{% endfor %}
        {{ c_type }} mean = sum / {{ spatial_size }};
{# Pass 2: sum of squared deviations from the mean. #}
        {{ c_type }} var = {{ zero_literal }};
{% for dim in shape[2:] %}
{% set lv = loop_vars[loop.index0 + 2] %}
        for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
            {{ c_type }} diff = {{ in_elem }} - mean;
            var += diff * diff;
{% for _ in shape[2:] %}
        }
{% endfor %}
{# epsilon is added to the averaged variance before the square root. #}
        {{ c_type }} denom = {{ sqrt_fn }}(var / {{ spatial_size }} + {{ epsilon_literal }});
{# Pass 3: normalize, then apply scale and bias. #}
{% for dim in shape[2:] %}
{% set lv = loop_vars[loop.index0 + 2] %}
        for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
            {{ out_elem }} =
                ({{ in_elem }} - mean) / denom * {{ scale }}[{{ loop_vars[1] }}] + {{ bias }}[{{ loop_vars[1] }}];
{% for _ in shape[2:] %}
        }
{% endfor %}
{% for _ in shape[:2] %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
{# Layer-normalization kernel template.

   Loops over `prefix_shape` select one slice; the axes in `norm_shape` are
   then reduced to a mean and a variance in `acc_type` precision. When
   `use_kahan` is set, both reductions carry a compensation term. Optional
   `mean_output` / `invstd_output` receive the statistics, and the final pass
   writes (x - mean) * inv_std * scale (+ bias when a bias is given). #}
{% macro accumulate(total, comp, term) %}
{# Adds `term` to `total`; with use_kahan the rounding error is tracked in
   `comp` and fed back into the next addition. #}
{% if use_kahan %}
            {{ acc_type }} kahan_value = {{ term }};
            {{ acc_type }} kahan_y = kahan_value - {{ comp }};
            {{ acc_type }} kahan_t = {{ total }} + kahan_y;
            {{ comp }} = (kahan_t - {{ total }}) - kahan_y;
            {{ total }} = kahan_t;
{% else %}
            {{ total }} += {{ term }};
{% endif %}
{% endmacro %}
{% set elem_index %}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %}{% endset %}
{% set stat_index %}{% for var in mean_index_vars %}[{{ var }}]{% endfor %}{% endset %}
{% set in_elem %}({{ acc_type }}){{ input0 }}{{ elem_index }}{% endset %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dim in prefix_shape %}
{% set lv = prefix_loop_vars[loop.index0] %}
    for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
{# Pass 1: mean over the normalized axes. #}
        {{ acc_type }} sum = {{ acc_zero_literal }};
{% if use_kahan %}
        {{ acc_type }} sum_comp = {{ acc_zero_literal }};
{% endif %}
{% for dim in norm_shape %}
{% set lv = norm_loop_vars[loop.index0] %}
        for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
{{ accumulate('sum', 'sum_comp', in_elem) }}
{% for _ in norm_shape %}
        }
{% endfor %}
        {{ acc_type }} mean = sum / {{ inner }};
{# Pass 2: variance as the average squared deviation. #}
        {{ acc_type }} var = {{ acc_zero_literal }};
{% if use_kahan %}
        {{ acc_type }} var_comp = {{ acc_zero_literal }};
{% endif %}
{% for dim in norm_shape %}
{% set lv = norm_loop_vars[loop.index0] %}
        for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
            {{ acc_type }} diff = {{ in_elem }} - mean;
{{ accumulate('var', 'var_comp', 'diff * diff') }}
{% for _ in norm_shape %}
        }
{% endfor %}
        var = var / {{ inner }};
        {{ acc_type }} inv_std = {{ acc_one_literal }} / {{ acc_sqrt_fn }}(var + {{ acc_epsilon_literal }});
{# Optional statistic outputs share the same index variables. #}
{% if mean_output %}
        {{ mean_output }}{{ stat_index }} = mean;
{% endif %}
{% if invstd_output %}
        {{ invstd_output }}{{ stat_index }} = inv_std;
{% endif %}
{# Pass 3: normalize and apply scale (and bias, when present). #}
{% for dim in norm_shape %}
{% set lv = norm_loop_vars[loop.index0] %}
        for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
            {{ acc_type }} value = ({{ in_elem }} - mean) * inv_std;
            value = value * {{ scale }}{% for var in scale_index_vars %}[{{ var }}]{% endfor %}{% if bias %} + {{ bias }}{% for var in bias_index_vars %}[{{ var }}]{% endfor %}{% endif %};
            {{ output }}{{ elem_index }} = value;
{% for _ in norm_shape %}
        }
{% endfor %}
{% for _ in prefix_shape %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
{# Log-softmax kernel template.

   Input and output are addressed as flat buffers split into
   outer x axis_size x inner. For every (outer_idx, inner_idx) lane the
   kernel finds the maximum along the axis, sums exp(x - max), and writes
   (x - max) - log(sum). Subtracting the maximum first keeps the arguments
   passed to the exp function at or below zero. #}
{% set in_at_axis %}input_flat[base + axis_idx * inner]{% endset %}
{% set out_at_axis %}output_flat[base + axis_idx * inner]{% endset %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
    {{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
    const idx_t outer = {{ outer }};
    const idx_t axis_size = {{ axis_size }};
    const idx_t inner = {{ inner }};
    for (idx_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
        for (idx_t inner_idx = 0; inner_idx < inner; ++inner_idx) {
{# Offset of the first element of this lane; axis steps are `inner` apart. #}
            idx_t base = (outer_idx * axis_size * inner) + inner_idx;

{# Step 1: running maximum, seeded with the first element of the lane. #}
            {{ c_type }} max_value = input_flat[base];
            for (idx_t axis_idx = 1; axis_idx < axis_size; ++axis_idx) {
                {{ c_type }} value = {{ in_at_axis }};
                max_value = {{ max_fn }}(max_value, value);
            }

{# Step 2: sum of shifted exponentials and its logarithm. #}
            {{ c_type }} sum = 0;
            for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
                {{ c_type }} value = {{ exp_fn }}({{ in_at_axis }} - max_value);
                sum += value;
            }
            {{ c_type }} logsum = {{ log_fn }}(sum);

{# Step 3: write the shifted input minus the log of the sum. #}
            for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
                {{ c_type }} value = {{ in_at_axis }} - max_value;
                {{ out_at_axis }} = value - logsum;
            }
        }
    }
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
{# Lp-normalization kernel template.

   Input and output are addressed as flat buffers split into
   outer x axis_size x inner. For every lane along the axis the kernel
   accumulates either absolute values (p == 1) or squares (any other p),
   takes the square root only when p == 2, and divides each element of the
   lane by the result. No guard is emitted for a zero accumulator. #}
{% set in_at_axis %}input_flat[base + axis_idx * inner]{% endset %}
{% set out_at_axis %}output_flat[base + axis_idx * inner]{% endset %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
    {{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
    const idx_t outer = {{ outer }};
    const idx_t axis_size = {{ axis_size }};
    const idx_t inner = {{ inner }};
    for (idx_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
        for (idx_t inner_idx = 0; inner_idx < inner; ++inner_idx) {
{# Offset of the first element of this lane; axis steps are `inner` apart. #}
            idx_t base = (outer_idx * axis_size * inner) + inner_idx;
            {{ c_type }} acc = {{ zero_literal }};
            for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
                {{ c_type }} value = {{ in_at_axis }};
                acc += {% if p == 1 %}{{ abs_fn }}(value){% else %}value * value{% endif %};
            }
{% if p == 2 %}
            acc = {{ sqrt_fn }}(acc);
{% endif %}
            for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
                {{ out_at_axis }} = {{ in_at_axis }} / acc;
            }
        }
    }
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
{# Lp-pooling kernel template for a 4-D input indexed as [n][c][h][w].

   For every output position the kernel sums |x|^p over the strided, dilated
   window (taps that land outside the input are skipped) and stores the p-th
   root of that sum. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    for (idx_t n = 0; n < {{ batch }}; ++n) {
        for (idx_t c = 0; c < {{ channels }}; ++c) {
            for (idx_t oh = 0; oh < {{ out_h }}; ++oh) {
                for (idx_t ow = 0; ow < {{ out_w }}; ++ow) {
                    {{ c_type }} acc = {{ zero_literal }};
{# Top-left input coordinate of the window, before dilation offsets. #}
                    const idx_t h_start = oh * {{ stride_h }} - {{ pad_top }};
                    const idx_t w_start = ow * {{ stride_w }} - {{ pad_left }};
                    for (idx_t kh = 0; kh < {{ kernel_h }}; ++kh) {
                        const idx_t in_h = h_start + kh * {{ dilation_h }};
{# A row outside the input contributes nothing for any kw. #}
                        if (in_h < 0 || in_h >= {{ in_h }}) {
                            continue;
                        }
                        for (idx_t kw = 0; kw < {{ kernel_w }}; ++kw) {
                            const idx_t in_w = w_start + kw * {{ dilation_w }};
                            if (in_w < 0 || in_w >= {{ in_w }}) {
                                continue;
                            }
                            {{ c_type }} value = {{ input0 }}[n][c][in_h][in_w];
                            acc += {{ pow_fn }}({{ abs_fn }}(value), ({{ c_type }}){{ p }});
                        }
                    }
                    {{ output }}[n][c][oh][ow] = {{ pow_fn }}(acc, ({{ c_type }})1.0 / ({{ c_type }}){{ p }});
                }
            }
        }
    }
}
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
{# Local-response-normalization kernel template.

   One loop is emitted per entry of `shape`. For each element the kernel sums
   the squares of the input across a window on the second axis, spanning
   `half` positions on either side and clamped to [0, channels - 1], and
   divides the element by (bias + alpha_div_size * sum) raised to beta. #}
{% set ch = loop_vars[1] %}
{% set elem_index %}{% for var in loop_vars %}[{{ var }}]{% endfor %}{% endset %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dim in shape %}
{% set lv = loop_vars[loop.index0] %}
    for (idx_t {{ lv }} = 0; {{ lv }} < {{ dim }}; ++{{ lv }}) {
{% endfor %}
{# The lower bound is chosen with a comparison so the subtraction is only
   evaluated when it cannot go below zero. #}
        idx_t channel_start = {{ ch }} > {{ half }} ? {{ ch }} - {{ half }} : 0;
        idx_t channel_end = {{ ch }} + {{ half }};
        if (channel_end >= {{ channels }}) {
            channel_end = {{ channels }} - 1;
        }
        {{ c_type }} sum = {{ zero_literal }};
        for (idx_t c = channel_start; c <= channel_end; ++c) {
            {{ c_type }} val = {{ input0 }}[{{ loop_vars[0] }}][c]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %};
            sum += val * val;
        }
        {{ c_type }} scale = {{ bias_literal }} + {{ alpha_div_size_literal }} * sum;
        {{ output }}{{ elem_index }} = {{ input0 }}{{ elem_index }} / {{ pow_fn }}(scale, {{ beta_literal }});
{% for _ in shape %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
{# LSTM kernel template.

   One C block is emitted per direction. Each block:
     1. seeds H_prev / C_prev from the optional initial-state inputs, or with
        zero_literal when they are absent;
     2. for every batch entry, runs `seq_limit` recurrent steps, where
        seq_limit is seq_length or the per-batch sequence_lens value clamped
        to [0, seq_length];
     3. zero-fills the Y rows at and beyond seq_limit;
     4. copies the final H_prev / C_prev into the optional Y_h / Y_c outputs.

   Weight rows are read in four chunks of hidden_size in the order
   i, o, f, c. The bias holds eight chunks: the first four are added to the
   matching last four. Peephole weights hold three chunks in the order
   i, o, f. `layout == 0` indexes tensors time-major, otherwise batch-major.

   Y is stored at the time index `t` that was read from X, not at the
   processing-order counter `step`. The two coincide for a forward pass; for
   a reverse pass this keeps Y aligned with X along the sequence axis. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dir in range(num_directions) %}
    {
        const int dir = {{ dir }};
        const int reverse = {% if direction == "reverse" or (direction == "bidirectional" and dir == 1) %}1{% else %}0{% endif %};
{# Three activation names per direction: gates, candidate cell, hidden output. #}
{% set act_f = activation_functions[dir * 3] %}
{% set act_g = activation_functions[dir * 3 + 1] %}
{% set act_h = activation_functions[dir * 3 + 2] %}
        {{ c_type }} H_prev[{{ batch_size }}][{{ hidden_size }}];
        {{ c_type }} C_prev[{{ batch_size }}][{{ hidden_size }}];
        for (int b = 0; b < {{ batch_size }}; ++b) {
            for (int h = 0; h < {{ hidden_size }}; ++h) {
{% if input_initial_h %}
{% if layout == 0 %}
                H_prev[b][h] = {{ input_initial_h }}[dir][b][h];
{% else %}
                H_prev[b][h] = {{ input_initial_h }}[b][dir][h];
{% endif %}
{% else %}
                H_prev[b][h] = {{ zero_literal }};
{% endif %}
{% if input_initial_c %}
{% if layout == 0 %}
                C_prev[b][h] = {{ input_initial_c }}[dir][b][h];
{% else %}
                C_prev[b][h] = {{ input_initial_c }}[b][dir][h];
{% endif %}
{% else %}
                C_prev[b][h] = {{ zero_literal }};
{% endif %}
            }
        }
        for (int b = 0; b < {{ batch_size }}; ++b) {
            int seq_limit = {{ seq_length }};
{% if input_sequence_lens %}
            seq_limit = (int){{ input_sequence_lens }}[b];
            if (seq_limit < 0) {
                seq_limit = 0;
            }
            if (seq_limit > {{ seq_length }}) {
                seq_limit = {{ seq_length }};
            }
{% endif %}
            for (int step = 0; step < seq_limit; ++step) {
{# A reverse pass starts at the last valid time index of this batch entry. #}
                int t = reverse ? (seq_limit - 1 - step) : step;
{# New states go to scratch buffers so every h reads the unmodified H_prev. #}
                {{ c_type }} H_next[{{ hidden_size }}];
                {{ c_type }} C_next[{{ hidden_size }}];
                for (int h = 0; h < {{ hidden_size }}; ++h) {
                    {{ c_type }} gate_i = {{ zero_literal }};
                    {{ c_type }} gate_o = {{ zero_literal }};
                    {{ c_type }} gate_f = {{ zero_literal }};
                    {{ c_type }} gate_c = {{ zero_literal }};
                    for (int i = 0; i < {{ input_size }}; ++i) {
{% if layout == 0 %}
                        {{ c_type }} x_val = {{ input_x }}[t][b][i];
{% else %}
                        {{ c_type }} x_val = {{ input_x }}[b][t][i];
{% endif %}
                        gate_i += x_val * {{ input_w }}[dir][h][i];
                        gate_o += x_val * {{ input_w }}[dir][{{ hidden_size }} + h][i];
                        gate_f += x_val * {{ input_w }}[dir][{{ hidden_size }} * 2 + h][i];
                        gate_c += x_val * {{ input_w }}[dir][{{ hidden_size }} * 3 + h][i];
                    }
                    for (int i = 0; i < {{ hidden_size }}; ++i) {
                        {{ c_type }} h_val = H_prev[b][i];
                        gate_i += h_val * {{ input_r }}[dir][h][i];
                        gate_o += h_val * {{ input_r }}[dir][{{ hidden_size }} + h][i];
                        gate_f += h_val * {{ input_r }}[dir][{{ hidden_size }} * 2 + h][i];
                        gate_c += h_val * {{ input_r }}[dir][{{ hidden_size }} * 3 + h][i];
                    }
{% if input_b %}
                    gate_i += {{ input_b }}[dir][h] + {{ input_b }}[dir][{{ hidden_size }} * 4 + h];
                    gate_o += {{ input_b }}[dir][{{ hidden_size }} + h] + {{ input_b }}[dir][{{ hidden_size }} * 5 + h];
                    gate_f += {{ input_b }}[dir][{{ hidden_size }} * 2 + h] + {{ input_b }}[dir][{{ hidden_size }} * 6 + h];
                    gate_c += {{ input_b }}[dir][{{ hidden_size }} * 3 + h] + {{ input_b }}[dir][{{ hidden_size }} * 7 + h];
{% endif %}
{% if use_clip %}
{# Clamp each pre-activation to [-clip, clip]; same order as the original:
   i, o, f, c. #}
{% for gate in ['gate_i', 'gate_o', 'gate_f', 'gate_c'] %}
                    if ({{ gate }} > {{ clip_literal }}) {
                        {{ gate }} = {{ clip_literal }};
                    } else if ({{ gate }} < -{{ clip_literal }}) {
                        {{ gate }} = -{{ clip_literal }};
                    }
{% endfor %}
{% endif %}
{# Input gate; the peephole term uses the previous cell state. #}
{% if input_p %}
                    {{ c_type }} i_gate = {{ act_f }}(
                        gate_i + {{ input_p }}[dir][h] * C_prev[b][h]);
{% else %}
                    {{ c_type }} i_gate = {{ act_f }}(gate_i);
{% endif %}
{# Forget gate; with input_forget it is derived from the input gate. #}
{% if input_forget %}
                    {{ c_type }} f_gate = ({{ c_type }}){{ one_literal }} - i_gate;
{% else %}
{% if input_p %}
                    {{ c_type }} f_gate = {{ act_f }}(
                        gate_f + {{ input_p }}[dir][{{ hidden_size }} * 2 + h] * C_prev[b][h]);
{% else %}
                    {{ c_type }} f_gate = {{ act_f }}(gate_f);
{% endif %}
{% endif %}
                    {{ c_type }} c_gate = {{ act_g }}(gate_c);
                    {{ c_type }} c_new = f_gate * C_prev[b][h] + i_gate * c_gate;
{# Output gate; its peephole term uses the freshly computed cell state. #}
{% if input_p %}
                    {{ c_type }} o_gate = {{ act_f }}(
                        gate_o + {{ input_p }}[dir][{{ hidden_size }} + h] * c_new);
{% else %}
                    {{ c_type }} o_gate = {{ act_f }}(gate_o);
{% endif %}
                    {{ c_type }} h_new = o_gate * {{ act_h }}(c_new);
                    C_next[h] = c_new;
                    H_next[h] = h_new;
{% if output_y %}
{# Store at the time index that was consumed from X (see header note). #}
{% if layout == 0 %}
                    {{ output_y }}[t][dir][b][h] = h_new;
{% else %}
                    {{ output_y }}[b][t][dir][h] = h_new;
{% endif %}
{% endif %}
                }
                for (int h = 0; h < {{ hidden_size }}; ++h) {
                    C_prev[b][h] = C_next[h];
                    H_prev[b][h] = H_next[h];
                }
            }
{% if output_y %}
{# Time steps past the valid length of this batch entry are zero-filled. #}
            for (int step = seq_limit; step < {{ seq_length }}; ++step) {
                for (int h = 0; h < {{ hidden_size }}; ++h) {
{% if layout == 0 %}
                    {{ output_y }}[step][dir][b][h] = {{ zero_literal }};
{% else %}
                    {{ output_y }}[b][step][dir][h] = {{ zero_literal }};
{% endif %}
                }
            }
{% endif %}
        }
{% if output_y_h %}
        for (int b = 0; b < {{ batch_size }}; ++b) {
            for (int h = 0; h < {{ hidden_size }}; ++h) {
{% if layout == 0 %}
                {{ output_y_h }}[dir][b][h] = H_prev[b][h];
{% else %}
                {{ output_y_h }}[b][dir][h] = H_prev[b][h];
{% endif %}
            }
        }
{% endif %}
{% if output_y_c %}
        for (int b = 0; b < {{ batch_size }}; ++b) {
            for (int h = 0; h < {{ hidden_size }}; ++h) {
{% if layout == 0 %}
                {{ output_y_c }}[dir][b][h] = C_prev[b][h];
{% else %}
                {{ output_y_c }}[b][dir][h] = C_prev[b][h];
{% endif %}
            }
        }
{% endif %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
{# Matrix-multiplication kernel template.

   One loop is emitted per entry of `output_loop_vars`, bounded by the
   matching entry of `output_loop_bounds`. Inside, an `acc_type` accumulator
   sums input0 * input1 over the contraction index k and is stored through
   `output_index_expr`. The index expressions are supplied fully formed by
   the caller. `pad` only controls the indentation of the generated text. #}
{% set depth = output_loop_vars | length %}
{% set pad = '    ' %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for idx in range(depth) %}
{{ pad * loop.index0 }}for (idx_t {{ output_loop_vars[idx] }} = 0; {{ output_loop_vars[idx] }} < {{ output_loop_bounds[idx] }}; ++{{ output_loop_vars[idx] }}) {
{% endfor %}
{{ pad * depth }}{{ acc_type }} acc = {{ zero_literal }};
{{ pad * depth }}for (idx_t k = 0; k < {{ k }}; ++k) {
{{ pad * (depth + 1) }}acc += {{ input0_index_expr }} * {{ input1_index_expr }};
{{ pad * depth }}}
{{ pad * depth }}{{ output_index_expr }} = acc;
{# One closing brace per output loop; indentation follows the loop counter,
   as in the original template. #}
{% for idx in range(depth) | reverse %}
{{ pad * loop.index0 }}}
{% endfor %}
}
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
{# Max-pooling kernel template for 1, 2 or 3 spatial axes.

   For every output position the kernel scans the strided, dilated window,
   skipping taps outside the input, and keeps the largest value. When
   `indices` is given it also records the flattened input offset of the
   winning element: row-major when storage_order == 0, otherwise with the
   spatial axes flattened first-axis-fastest. The recorded offset only
   changes on a strictly greater value, so the first maximum wins ties. #}
{% macro begin_window() %}
{# Declares the running maximum (and running index, if requested). #}
{{ c_type }} max_value = {{ min_literal }};
{% if indices %}
{{ indices_c_type }} max_index = 0;
{% endif %}
{% endmacro %}
{% macro update_max(flat_index) %}
{# Folds the C local `val` into max_value / max_index. #}
{% if indices %}
const {{ c_type }} prev_max = max_value;
max_value = {{ max_fn }}(max_value, val);
max_index = (val > prev_max)
    ? ({{ flat_index }})
    : max_index;
{% else %}
max_value = {{ max_fn }}(max_value, val);
{% endif %}
{% endmacro %}
{# (n * channels + c) cast to the index type; shared by every flat offset. #}
{% set nc_base %}{% if indices %}(({{ indices_c_type }})n * {{ channels }} + ({{ indices_c_type }})c){% endif %}{% endset %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    for (idx_t n = 0; n < {{ batch }}; ++n) {
        for (idx_t c = 0; c < {{ channels }}; ++c) {
{% if spatial_rank == 1 %}
{% set flat_index %}{% if indices %}({{ indices_c_type }})(({{ nc_base }} * {{ in_spatial[0] }}) + ({{ indices_c_type }})ix){% endif %}{% endset %}
            for (idx_t ox = 0; ox < {{ out_spatial[0] }}; ++ox) {
                {{ begin_window() }}
                for (idx_t kx = 0; kx < {{ kernel_shape[0] }}; ++kx) {
                    const idx_t ix = ox * {{ strides[0] }} + kx * {{ dilations[0] }} - {{ pads[0] }};
                    if (ix >= 0 && ix < {{ in_spatial[0] }}) {
                        {{ c_type }} val = {{ input0 }}[n][c][ix];
                        {{ update_max(flat_index) }}
                    }
                }
                {{ output }}[n][c][ox] = max_value;
{% if indices %}
                {{ indices }}[n][c][ox] = max_index;
{% endif %}
            }
{% elif spatial_rank == 2 %}
{% set flat_index %}{% if indices %}{% if storage_order == 0 %}({{ indices_c_type }})((({{ nc_base }} * {{ in_spatial[0] }} + ({{ indices_c_type }})ih) * {{ in_spatial[1] }}) + ({{ indices_c_type }})iw){% else %}({{ indices_c_type }})(({{ nc_base }} * {{ in_spatial[0] }} * {{ in_spatial[1] }}) + ({{ indices_c_type }})ih + ({{ indices_c_type }})iw * {{ in_spatial[0] }}){% endif %}{% endif %}{% endset %}
            for (idx_t oh = 0; oh < {{ out_spatial[0] }}; ++oh) {
                for (idx_t ow = 0; ow < {{ out_spatial[1] }}; ++ow) {
                    {{ begin_window() }}
                    for (idx_t kh = 0; kh < {{ kernel_shape[0] }}; ++kh) {
                        const idx_t ih = oh * {{ strides[0] }} + kh * {{ dilations[0] }} - {{ pads[0] }};
                        if (ih >= 0 && ih < {{ in_spatial[0] }}) {
                            for (idx_t kw = 0; kw < {{ kernel_shape[1] }}; ++kw) {
                                const idx_t iw = ow * {{ strides[1] }} + kw * {{ dilations[1] }} - {{ pads[1] }};
                                if (iw >= 0 && iw < {{ in_spatial[1] }}) {
                                    {{ c_type }} val = {{ input0 }}[n][c][ih][iw];
                                    {{ update_max(flat_index) }}
                                }
                            }
                        }
                    }
                    {{ output }}[n][c][oh][ow] = max_value;
{% if indices %}
                    {{ indices }}[n][c][oh][ow] = max_index;
{% endif %}
                }
            }
{% elif spatial_rank == 3 %}
{% set flat_index %}{% if indices %}{% if storage_order == 0 %}({{ indices_c_type }})(((({{ nc_base }} * {{ in_spatial[0] }} + ({{ indices_c_type }})id) * {{ in_spatial[1] }} + ({{ indices_c_type }})ih) * {{ in_spatial[2] }}) + ({{ indices_c_type }})iw){% else %}({{ indices_c_type }})(({{ nc_base }} * {{ in_spatial[0] }} * {{ in_spatial[1] }} * {{ in_spatial[2] }}) + ({{ indices_c_type }})id + ({{ indices_c_type }})ih * {{ in_spatial[0] }} + ({{ indices_c_type }})iw * {{ in_spatial[0] }} * {{ in_spatial[1] }}){% endif %}{% endif %}{% endset %}
            for (idx_t od = 0; od < {{ out_spatial[0] }}; ++od) {
                for (idx_t oh = 0; oh < {{ out_spatial[1] }}; ++oh) {
                    for (idx_t ow = 0; ow < {{ out_spatial[2] }}; ++ow) {
                        {{ begin_window() }}
                        for (idx_t kd = 0; kd < {{ kernel_shape[0] }}; ++kd) {
                            const idx_t id = od * {{ strides[0] }} + kd * {{ dilations[0] }} - {{ pads[0] }};
                            if (id >= 0 && id < {{ in_spatial[0] }}) {
                                for (idx_t kh = 0; kh < {{ kernel_shape[1] }}; ++kh) {
                                    const idx_t ih = oh * {{ strides[1] }} + kh * {{ dilations[1] }} - {{ pads[1] }};
                                    if (ih >= 0 && ih < {{ in_spatial[1] }}) {
                                        for (idx_t kw = 0; kw < {{ kernel_shape[2] }}; ++kw) {
                                            const idx_t iw = ow * {{ strides[2] }} + kw * {{ dilations[2] }} - {{ pads[2] }};
                                            if (iw >= 0 && iw < {{ in_spatial[2] }}) {
                                                {{ c_type }} val = {{ input0 }}[n][c][id][ih][iw];
                                                {{ update_max(flat_index) }}
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        {{ output }}[n][c][od][oh][ow] = max_value;
{% if indices %}
                        {{ indices }}[n][c][od][oh][ow] = max_index;
{% endif %}
                    }
                }
            }
{% endif %}
        }
    }
}
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
{#
  Emits a C function that normalizes input0 into output.

  For every index combination of the non_axes dimensions the generated code
  makes three passes over the axes dimensions:
    1. accumulate the sum of the elements and divide it by reduce_count to
       obtain the mean,
    2. accumulate the squared deviations from that mean,
    3. write (element - mean) / denom, where
       denom = sqrt_fn(var / reduce_count + epsilon_literal).

  Template variables read here: op_name, dim_args, params, non_axes, axes,
  loop_vars, shape, c_type, zero_literal, input0, output, reduce_count,
  sqrt_fn, epsilon_literal.

  NOTE(review): this looks like a mean/variance normalization kernel, and
  reduce_count is presumably the product of the sizes of the axes
  dimensions - confirm against the emitter that fills in the context.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for axis in non_axes %}
|
|
3
|
+
for (idx_t {{ loop_vars[axis] }} = 0; {{ loop_vars[axis] }} < {{ shape[axis] }}; ++{{ loop_vars[axis] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
{# Pass 1: sum over the normalized axes for the current outer index. #}
{{ c_type }} sum = {{ zero_literal }};
|
|
6
|
+
{% for axis in axes %}
|
|
7
|
+
for (idx_t {{ loop_vars[axis] }} = 0; {{ loop_vars[axis] }} < {{ shape[axis] }}; ++{{ loop_vars[axis] }}) {
|
|
8
|
+
{% endfor %}
|
|
9
|
+
sum += {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
|
|
10
|
+
{% for _ in axes %}
|
|
11
|
+
}
|
|
12
|
+
{% endfor %}
|
|
13
|
+
{{ c_type }} mean = sum / {{ reduce_count }};
|
|
14
|
+
{# Pass 2: accumulate squared deviations; the division by reduce_count happens when denom is computed. #}
{{ c_type }} var = {{ zero_literal }};
|
|
15
|
+
{% for axis in axes %}
|
|
16
|
+
for (idx_t {{ loop_vars[axis] }} = 0; {{ loop_vars[axis] }} < {{ shape[axis] }}; ++{{ loop_vars[axis] }}) {
|
|
17
|
+
{% endfor %}
|
|
18
|
+
{{ c_type }} diff = {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %} - mean;
|
|
19
|
+
var += diff * diff;
|
|
20
|
+
{% for _ in axes %}
|
|
21
|
+
}
|
|
22
|
+
{% endfor %}
|
|
23
|
+
{# epsilon_literal is added inside the square root, before sqrt_fn is applied. #}
{{ c_type }} denom = {{ sqrt_fn }}(var / {{ reduce_count }} + {{ epsilon_literal }});
|
|
24
|
+
{# Pass 3: write the normalized values for the current outer index. #}
{% for axis in axes %}
|
|
25
|
+
for (idx_t {{ loop_vars[axis] }} = 0; {{ loop_vars[axis] }} < {{ shape[axis] }}; ++{{ loop_vars[axis] }}) {
|
|
26
|
+
{% endfor %}
|
|
27
|
+
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %} - mean) / denom;
|
|
28
|
+
{% for _ in axes %}
|
|
29
|
+
}
|
|
30
|
+
{% endfor %}
|
|
31
|
+
{# Close one brace per outer (non_axes) loop opened at the top. #}
{% for _ in non_axes %}
|
|
32
|
+
}
|
|
33
|
+
{% endfor %}
|
|
34
|
+
}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
{#
  Emits a C function that combines several inputs element by element.

  One nested loop is generated per entry of shape, using loop_vars for the
  index names. Inside the innermost loop output_expr is first set to
  input_exprs[0] and then updated once for each remaining input expression.
  How the update is written depends on operator_kind:
    "func"  - operator(output_expr, expr)
    "expr"  - the pre-built operator_expr string
    other   - output_expr operator expr (infix form)
  When is_mean is truthy the result is finally divided by mean_scale.

  Template variables read here: op_name, dim_args, params, shape, loop_vars,
  output_expr, input_exprs, operator_kind, operator, operator_expr, is_mean,
  mean_scale.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for dim in shape %}
|
|
3
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
{{ output_expr }} = {{ input_exprs[0] }};
|
|
6
|
+
{% for expr in input_exprs[1:] %}
|
|
7
|
+
{# NOTE(review): the "expr" branch emits operator_expr unchanged on every iteration and does not use the loop variable; presumably the caller only selects it when operator_expr already names its operands - confirm. #}
{{ output_expr }} = {% if operator_kind == "func" %}{{ operator }}({{ output_expr }}, {{ expr }}){% elif operator_kind == "expr" %}{{ operator_expr }}{% else %}{{ output_expr }} {{ operator }} {{ expr }}{% endif %};
|
|
8
|
+
{% endfor %}
|
|
9
|
+
{% if is_mean %}
|
|
10
|
+
{{ output_expr }} = {{ output_expr }} / {{ mean_scale }};
|
|
11
|
+
{% endif %}
|
|
12
|
+
{# Close one brace per loop opened above. #}
{% for _ in shape %}
|
|
13
|
+
}
|
|
14
|
+
{% endfor %}
|
|
15
|
+
}
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
{#
  Emits a C function that computes a per-position loss from input0 and an
  integer target tensor, with optional per-class weights and a reduction.

  The tensors are accessed through flat pointers. For each (n_idx, d_idx)
  pair the target value is read at n_idx * d + d_idx:
    - if it equals ignore_index the position contributes nothing (and, for
      reduction "none", the output element is set to zero_literal);
    - otherwise the value is the negated input element at
      (n_idx * c + class_index) * d + d_idx, multiplied by
      weight[class_index] when weight is provided.
  reduction selects the output:
    "none" - one value per position, written at the target's flat index;
    "sum"  - output_flat[0] is the sum of all values;
    "mean" - output_flat[0] is the sum divided by the accumulated weights
             (a count of contributing positions when weight is absent), or
             zero_literal when that divisor is zero.

  Template variables read here: op_name, dim_args, params, c_type,
  target_c_type, acc_type, input0, target, output, weight, n, c, d,
  reduction, ignore_index, zero_literal, acc_zero_literal, acc_one_literal.

  NOTE(review): this matches the shape of a negative-log-likelihood loss
  with input laid out as (n, c, d) and target as (n, d) - confirm against
  the lowering code that supplies n, c and d.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
|
|
3
|
+
const {{ target_c_type }} *target_flat = (const {{ target_c_type }} *){{ target }};
|
|
4
|
+
{{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
|
|
5
|
+
const idx_t n = {{ n }};
|
|
6
|
+
const idx_t c = {{ c }};
|
|
7
|
+
const idx_t d = {{ d }};
|
|
8
|
+
{# Accumulators are only declared for the reducing modes; weight_sum only for "mean". #}
{% if reduction != "none" %}
|
|
9
|
+
{{ acc_type }} loss_sum = {{ acc_zero_literal }};
|
|
10
|
+
{% if reduction == "mean" %}
|
|
11
|
+
{{ acc_type }} weight_sum = {{ acc_zero_literal }};
|
|
12
|
+
{% endif %}
|
|
13
|
+
{% endif %}
|
|
14
|
+
for (idx_t n_idx = 0; n_idx < n; ++n_idx) {
|
|
15
|
+
for (idx_t d_idx = 0; d_idx < d; ++d_idx) {
|
|
16
|
+
idx_t target_index = n_idx * d + d_idx;
|
|
17
|
+
{{ target_c_type }} target_value = target_flat[target_index];
|
|
18
|
+
{# Ignored positions add nothing to loss_sum or weight_sum. #}
if ((int64_t)target_value == {{ ignore_index }}) {
|
|
19
|
+
{% if reduction == "none" %}
|
|
20
|
+
output_flat[target_index] = {{ zero_literal }};
|
|
21
|
+
{% endif %}
|
|
22
|
+
} else {
|
|
23
|
+
{# NOTE(review): class_index is not range-checked against c before it indexes input_flat and weight; presumably targets are validated upstream - confirm. #}
idx_t class_index = (idx_t)target_value;
|
|
24
|
+
idx_t input_index = (n_idx * c + class_index) * d + d_idx;
|
|
25
|
+
{{ acc_type }} value = -({{ acc_type }})input_flat[input_index];
|
|
26
|
+
{% if weight %}
|
|
27
|
+
{{ acc_type }} sample_weight = {{ weight }}[class_index];
|
|
28
|
+
value *= sample_weight;
|
|
29
|
+
{% endif %}
|
|
30
|
+
{% if reduction == "none" %}
|
|
31
|
+
output_flat[target_index] = value;
|
|
32
|
+
{% else %}
|
|
33
|
+
loss_sum += value;
|
|
34
|
+
{% if reduction == "mean" %}
|
|
35
|
+
{% if weight %}
|
|
36
|
+
weight_sum += sample_weight;
|
|
37
|
+
{% else %}
|
|
38
|
+
weight_sum += {{ acc_one_literal }};
|
|
39
|
+
{% endif %}
|
|
40
|
+
{% endif %}
|
|
41
|
+
{% endif %}
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
}
|
|
45
|
+
{% if reduction == "mean" %}
|
|
46
|
+
{# Guard the division: weight_sum stays zero when no position contributed or all contributing weights sum to zero. #}
if (weight_sum == {{ acc_zero_literal }}) {
|
|
47
|
+
output_flat[0] = {{ zero_literal }};
|
|
48
|
+
} else {
|
|
49
|
+
output_flat[0] = loss_sum / weight_sum;
|
|
50
|
+
}
|
|
51
|
+
{% elif reduction == "sum" %}
|
|
52
|
+
output_flat[0] = loss_sum;
|
|
53
|
+
{% endif %}
|
|
54
|
+
}
|