emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
{# Element-wise dequantization kernel template.
   Emits one nested `for` loop per entry of `shape`; for each element the
   value (input - zero point) * scale is computed in `compute_type` and then
   cast to `output_c_type` before being stored. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for extent in shape %}
{% set axis_var = loop_vars[loop.index0] %}
    for (idx_t {{ axis_var }} = 0; {{ axis_var }} < {{ extent }}; ++{{ axis_var }}) {
{% endfor %}
    {# Each operand is cast to compute_type first, so both the subtraction
       and the multiplication are carried out in that type. #}
    {{ compute_type }} dequantized = (({{ compute_type }}){{ input_expr }} - ({{ compute_type }}){{ zero_expr }}) * ({{ compute_type }}){{ scale_expr }};
    {{ output_expr }} = ({{ output_c_type }})dequantized;
{% for _ in shape %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
{# Einsum kernel template.
   `kind` selects one of the specialised loop nests below; a `kind` that
   matches none of the branches yields an empty function body. #}
{# Emits `depth` single-space pads (used as the per-line indent). #}
{% macro einsum_pad(depth) %}{% for _ in range(depth) %} {% endfor %}{% endmacro %}
{# Opens one `for` loop per name; line N is padded by N levels. #}
{% macro einsum_open_loops(names, bounds) %}
{% for name in names %}
{{ einsum_pad(loop.index0) }}for (idx_t {{ name }} = 0; {{ name }} < {{ bounds[loop.index0] }}; ++{{ name }}) {
{% endfor %}
{% endmacro %}
{# Emits one closing brace per name; the pad grows with each brace emitted. #}
{% macro einsum_close_loops(names) %}
{% for _ in names %}
{{ einsum_pad(loop.index0) }}}
{% endfor %}
{% endmacro %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% if kind == "reduce_all" %}
{# Sum every input element into a single accumulator. #}
{{ acc_type }} acc = {{ zero_literal }};
{{ einsum_open_loops(input_loop_vars, input_loop_bounds) }}
{{ einsum_pad(input_loop_vars | length) }}acc += {{ input_expr }};
{{ einsum_close_loops(input_loop_vars) }}
{{ output_expr }} = acc;
{% elif kind == "sum_j" %}
{# One output loop; each output element is a sum over the reduce axis. #}
{% set out_var = output_loop_vars[0] %}
for (idx_t {{ out_var }} = 0; {{ out_var }} < {{ output_loop_bounds[0] }}; ++{{ out_var }}) {
    {{ acc_type }} acc = {{ zero_literal }};
    for (idx_t {{ reduce_loop_var }} = 0; {{ reduce_loop_var }} < {{ reduce_loop_bound }}; ++{{ reduce_loop_var }}) {
        acc += {{ input_expr }};
    }
    {{ output_expr }} = acc;
}
{% elif kind in ["transpose", "batch_diagonal"] %}
{# Both kinds are a plain element copy over the output loop nest; only the
   caller-supplied index expressions differ. #}
{{ einsum_open_loops(output_loop_vars, output_loop_bounds) }}
{{ einsum_pad(output_loop_vars | length) }}{{ output_expr }} = {{ input_expr }};
{{ einsum_close_loops(output_loop_vars) }}
{% elif kind == "dot" %}
{# Single accumulator over the products of the two inputs. #}
{{ acc_type }} acc = {{ zero_literal }};
for (idx_t {{ reduce_loop_var }} = 0; {{ reduce_loop_var }} < {{ reduce_loop_bound }}; ++{{ reduce_loop_var }}) {
    acc += {{ input0_expr }} * {{ input1_expr }};
}
{{ output_expr }} = acc;
{% elif kind == "batch_matmul" %}
{# Output loop nest with an inner reduction per output element. #}
{% set out_depth = output_loop_vars | length %}
{{ einsum_open_loops(output_loop_vars, output_loop_bounds) }}
{{ einsum_pad(out_depth) }}{{ acc_type }} acc = {{ zero_literal }};
{{ einsum_pad(out_depth) }}for (idx_t {{ reduce_loop_var }} = 0; {{ reduce_loop_var }} < {{ reduce_loop_bound }}; ++{{ reduce_loop_var }}) {
{{ einsum_pad(out_depth + 1) }}acc += {{ input0_expr }} * {{ input1_expr }};
{{ einsum_pad(out_depth) }}}
{{ einsum_pad(out_depth) }}{{ output_expr }} = acc;
{{ einsum_close_loops(output_loop_vars) }}
{% endif %}
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
{# Expand kernel template.
   Walks the output index space (one loop per entry of `output_shape`) and
   copies, for each output element, the input element located at
   `input_index_expr`. Both tensors are accessed through flat element
   pointers; the output pointer index simply advances once per iteration. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
    const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
    idx_t output_index = 0;
{% for extent in output_shape %}
{% set axis_var = loop_vars[loop.index0] %}
    for (idx_t {{ axis_var }} = 0; {{ axis_var }} < {{ extent }}; ++{{ axis_var }}) {
{% endfor %}
    const idx_t input_index = {{ input_index_expr }};
    output_data[output_index] = input_data[input_index];
    ++output_index;
{% for _ in output_shape %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
{# EyeLike kernel template.
   The input tensor is not read (it is only referenced to silence
   unused-parameter warnings). The output, viewed as `batch_size` matrices of
   `rows` x `cols`, is first filled with `zero_literal`; then `one_literal`
   is written along the k-th diagonal of every matrix. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    (void){{ input0 }};
    {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
    idx_t total = (idx_t){{ batch_size }} * {{ rows }} * {{ cols }};
    for (idx_t index = 0; index < total; ++index) {
        output_data[index] = {{ zero_literal }};
    }
    idx_t k = {{ k }};
    idx_t rows = {{ rows }};
    idx_t cols = {{ cols }};
    {# k >= 0 shifts the diagonal to the right, k < 0 shifts it downwards. #}
    idx_t row_start = 0;
    idx_t col_start = 0;
    if (k >= 0) {
        col_start = k;
    } else {
        row_start = -k;
    }
    {# Nothing to write when the diagonal starts outside the matrix. #}
    if (row_start < rows && col_start < cols) {
        idx_t max_rows = rows - row_start;
        idx_t max_cols = cols - col_start;
        idx_t diag_len = max_cols;
        if (max_rows < max_cols) {
            diag_len = max_rows;
        }
        for (idx_t batch = 0; batch < {{ batch_size }}; ++batch) {
            idx_t base = batch * rows * cols;
            for (idx_t diag = 0; diag < diag_len; ++diag) {
                output_data[base + (row_start + diag) * cols + (col_start + diag)] = {{ one_literal }};
            }
        }
    }
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
{# GatherElements kernel template.
   Iterates over the output index space. For each position the index tensor
   is read at the same position; a negative value is wrapped by adding
   `axis_dim`. The C local must keep the name `gather_index`, presumably
   because `data_indices` refers to it — confirm against the emitter. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for extent in output_shape %}
{% set axis_var = loop_vars[loop.index0] %}
    for (idx_t {{ axis_var }} = 0; {{ axis_var }} < {{ extent }}; ++{{ axis_var }}) {
{% endfor %}
{# Output and indices share the same subscript: one `[var]` per loop var. #}
{% set position %}{% for var in loop_vars %}[{{ var }}]{% endfor %}{% endset %}
    idx_t gather_index = {{ indices }}{{ position }};
    if (gather_index < 0) {
        gather_index += {{ axis_dim }};
    }
    {{ output }}{{ position }} = {{ data }}{% for idx in data_indices %}[{{ idx }}]{% endfor %};
{% for _ in output_shape %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
{# GatherND kernel template.
   Outer loops run over `indices_prefix_shape`. For each prefix position,
   `index_depth` index components are read from the last axis of the index
   tensor (negative components are wrapped by the matching `data_shape`
   entry), then inner loops over `tail_shape` copy `data_index_expr` into
   `output_index_expr`. A falsy prefix/tail shape emits no loops. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for extent in indices_prefix_shape or [] %}
{% set prefix_var = indices_prefix_loop_vars[loop.index0] %}
    for (idx_t {{ prefix_var }} = 0; {{ prefix_var }} < {{ extent }}; ++{{ prefix_var }}) {
{% endfor %}
{% for component in range(index_depth) %}
{% set prefix_subscript %}{% for var in indices_prefix_loop_vars %}[{{ var }}]{% endfor %}{% endset %}
    idx_t index{{ component }} = {{ indices }}{{ prefix_subscript }}[{{ component }}];
    if (index{{ component }} < 0) {
        index{{ component }} += {{ data_shape[batch_dims + component] }};
    }
{% endfor %}
{% for extent in tail_shape or [] %}
{% set tail_var = tail_loop_vars[loop.index0] %}
    for (idx_t {{ tail_var }} = 0; {{ tail_var }} < {{ extent }}; ++{{ tail_var }}) {
{% endfor %}
    {{ output_index_expr }} = {{ data_index_expr }};
{% for _ in tail_shape or [] %}
    }
{% endfor %}
{% for _ in indices_prefix_shape or [] %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
{# Gather kernel template.
   Iterates over the output index space. The gather index is read from the
   index tensor at the subscripts listed in `indices_indices`; a negative
   value is wrapped by adding `axis_dim`. The C local must keep the name
   `gather_index`, presumably because `data_indices` refers to it — confirm
   against the emitter. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for extent in output_shape %}
{% set axis_var = loop_vars[loop.index0] %}
    for (idx_t {{ axis_var }} = 0; {{ axis_var }} < {{ extent }}; ++{{ axis_var }}) {
{% endfor %}
{% set index_subscript %}{% for var in indices_indices %}[{{ var }}]{% endfor %}{% endset %}
{% set output_subscript %}{% for var in loop_vars %}[{{ var }}]{% endfor %}{% endset %}
    idx_t gather_index = {{ indices }}{{ index_subscript }};
    if (gather_index < 0) {
        gather_index += {{ axis_dim }};
    }
    {{ output }}{{ output_subscript }} = {{ data }}{% for idx in data_indices %}[{{ idx }}]{% endfor %};
{% for _ in output_shape %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
{# Gemm kernel template: output[i][j] = alpha * sum_k(A * B) (+ beta * C).
   `trans_a` / `trans_b` swap the subscripts used to read A and B. When
   `input_c` is set, the bias is read according to `c_rank`: rank 2 uses
   index 0 along any axis whose extent (`c_dim0` / `c_dim1`) is 1, rank 1 is
   indexed by column, and any other rank reads element 0. #}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    for (idx_t i = 0; i < {{ m }}; ++i) {
        for (idx_t j = 0; j < {{ n }}; ++j) {
            {{ acc_type }} acc = {{ zero_literal }};
            for (idx_t k = 0; k < {{ k }}; ++k) {
                const {{ c_type }} a_val = {{ input_a }}{% if trans_a %}[k][i]{% else %}[i][k]{% endif %};
                const {{ c_type }} b_val = {{ input_b }}{% if trans_b %}[j][k]{% else %}[k][j]{% endif %};
                acc += a_val * b_val;
            }
{% if not input_c %}
            {{ output }}[i][j] = acc * {{ alpha_literal }};
{% else %}
{% if c_rank == 2 %}
            idx_t c_i = {{ "0" if c_dim0 == 1 else "i" }};
            idx_t c_j = {{ "0" if c_dim1 == 1 else "j" }};
            const {{ c_type }} bias = {{ input_c }}[c_i][c_j];
{% elif c_rank == 1 %}
            idx_t c_j = {{ "0" if c_dim1 == 1 else "j" }};
            const {{ c_type }} bias = {{ input_c }}[c_j];
{% else %}
            const {{ c_type }} bias = {{ input_c }}[0];
{% endif %}
            {{ output }}[i][j] = acc * {{ alpha_literal }} + bias * {{ beta_literal }};
{% endif %}
        }
    }
}
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
{# Reflection helper for the grid-sample kernel.
   Folds `value` back into [x_min, x_max] by mirroring it at the interval
   edges. Values already inside the interval are returned unchanged, and a
   zero-width interval always yields x_min. #}
static inline double {{ op_name }}_reflect(double value, double x_min, double x_max) {
    const double range = x_max - x_min;
    if (range == 0.0) {
        return x_min;
    }
    double reflected = value;
    if (value < x_min) {
        {# Distance below the lower edge, split into whole interval widths
           (`folds`) plus a remainder; the parity of `folds` decides which
           edge the remainder is measured from. #}
        const double overshoot = x_min - value;
        const int folds = (int)(overshoot / range);
        const double rest = overshoot - (double)folds * range;
        if (folds % 2 == 0) {
            reflected = x_min + rest;
        } else {
            reflected = x_max - rest;
        }
    } else if (value > x_max) {
        {# Mirror image of the branch above for values past the upper edge. #}
        const double overshoot = value - x_max;
        const int folds = (int)(overshoot / range);
        const double rest = overshoot - (double)folds * range;
        if (folds % 2 == 0) {
            reflected = x_max - rest;
        } else {
            reflected = x_min + rest;
        }
    }
    return reflected;
}
|
|
20
|
+
{% if mode == "cubic" %}

{# Cubic interpolation weights, emitted only when `mode` is "cubic".
   Fills coeffs[0..3] with the weights of the four taps around a sample,
   where `x` is the offset of the sample from its floor tap. The two outer
   taps use the polynomial in (|x| + 1) and (2 - |x|), the two inner taps the
   polynomial in |x| and (1 - |x|); the kernel parameter is fixed at -0.75. #}
static inline void {{ op_name }}_cubic_coeffs(double x, double coeffs[4]) {
    const double a = -0.75;
    const double t = fabs(x);
    const double one_minus_t = 1.0 - t;
    const double two_minus_t = 2.0 - t;
    coeffs[0] = ((a * (t + 1.0) - 5.0 * a) * (t + 1.0) + 8.0 * a) * (t + 1.0) - 4.0 * a;
    coeffs[1] = ((a + 2.0) * t - (a + 3.0)) * t * t + 1.0;
    coeffs[2] = ((a + 2.0) * one_minus_t - (a + 3.0)) * one_minus_t * one_minus_t + 1.0;
    coeffs[3] = ((a * two_minus_t - 5.0 * a) * two_minus_t + 8.0 * a) * two_minus_t - 4.0 * a;
}
{% endif %}
|
|
33
|
+
|
|
34
|
+
{# GridSample kernel template.
   Loops over output[n][c][spatial...]. For every output position the grid
   entry is read (its last axis is consumed in reverse order) and mapped to
   input coordinates; `padding_mode` and `mode` then decide how the input is
   sampled:
     - "nearest": one tap at the rounded coordinate;
     - "linear":  one tap per entry of `linear_offsets`, weighted by w0/w1;
     - otherwise: one tap per entry of `cubic_offsets`, weighted by the
       coefficients from {{ '{{ op_name }}' }}_cubic_coeffs.
   With padding_mode "zeros", out-of-range taps contribute nothing. #}
{# Handles an integer tap index `idx<axis>` that may lie outside the input:
   "zeros" clears `in_bounds`, "border" clamps the index, and any other
   padding mode reflects it between border_min and border_max. #}
{% macro grid_sample_fix_index(axis) %}
{% if padding_mode == "zeros" %}
                if (idx{{ axis }} < 0 || idx{{ axis }} >= input_spatial[{{ axis }}]) {
                    in_bounds = 0;
                }
{% elif padding_mode == "border" %}
                if (idx{{ axis }} < 0) {
                    idx{{ axis }} = 0;
                } else if (idx{{ axis }} >= input_spatial[{{ axis }}]) {
                    idx{{ axis }} = input_spatial[{{ axis }}] - 1;
                }
{% else %}
                idx{{ axis }} = (int){{ op_name }}_reflect((double)idx{{ axis }}, border_min[{{ axis }}], border_max[{{ axis }}]);
{% endif %}
{% endmacro %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const int input_spatial[{{ spatial_rank }}] = { {{ input_spatial | join(', ') }} };
{% if padding_mode != "zeros" %}
{# Valid coordinate range per axis: [0, size - 1] with align_corners,
   otherwise [-0.5, size - 0.5]. #}
{% set lower_bound = "0.0" if align_corners else "-0.5" %}
    const double border_min[{{ spatial_rank }}] = { {% for _ in input_spatial %}{{ lower_bound }}{% if not loop.last %}, {% endif %}{% endfor %} };
    const double border_max[{{ spatial_rank }}] = { {% for size in input_spatial %}{% if align_corners %}{{ size - 1 }}.0{% else %}{{ size - 0.5 }}{% endif %}{% if not loop.last %}, {% endif %}{% endfor %} };
{% endif %}
{% set n_var = output_loop_vars[0] %}
{% set c_var = output_loop_vars[1] %}
{% set spatial_vars = output_loop_vars[2:] %}
{# Element expressions that are reused by every sampling mode. #}
{% set out_elem %}{{ output }}[{{ n_var }}][{{ c_var }}]{% for var in spatial_vars %}[{{ var }}]{% endfor %}{% endset %}
{% set in_elem %}{{ input0 }}[{{ n_var }}][{{ c_var }}]{% for axis in range(spatial_rank) %}[idx{{ axis }}]{% endfor %}{% endset %}
{% set grid_cell %}{{ grid }}[{{ n_var }}]{% for var in spatial_vars %}[{{ var }}]{% endfor %}{% endset %}
    for (idx_t {{ n_var }} = 0; {{ n_var }} < {{ output_shape[0] }}; ++{{ n_var }}) {
        for (idx_t {{ c_var }} = 0; {{ c_var }} < {{ output_shape[1] }}; ++{{ c_var }}) {
{% for axis in range(spatial_rank) %}
{% set spatial_var = spatial_vars[axis] %}
            for (idx_t {{ spatial_var }} = 0; {{ spatial_var }} < {{ output_spatial[axis] }}; ++{{ spatial_var }}) {
{% endfor %}
            double coords[{{ spatial_rank }}];
{% for axis in range(spatial_rank) %}
            const double grid_{{ axis }} = (double){{ grid_cell }}[{{ spatial_rank - 1 - axis }}];
{% if align_corners %}
            coords[{{ axis }}] = (grid_{{ axis }} + 1.0) * (double)(input_spatial[{{ axis }}] - 1) / 2.0;
{% else %}
            coords[{{ axis }}] = ((grid_{{ axis }} + 1.0) * (double)input_spatial[{{ axis }}] - 1.0) / 2.0;
{% endif %}
{% endfor %}
{% if padding_mode != "zeros" %}
{# Pull coordinates that left the valid range back in before sampling. #}
{% for axis in range(spatial_rank) %}
            if (coords[{{ axis }}] < border_min[{{ axis }}] || coords[{{ axis }}] > border_max[{{ axis }}]) {
{% if padding_mode == "border" %}
                if (coords[{{ axis }}] < 0.0) {
                    coords[{{ axis }}] = 0.0;
                } else if (coords[{{ axis }}] > (double)(input_spatial[{{ axis }}] - 1)) {
                    coords[{{ axis }}] = (double)(input_spatial[{{ axis }}] - 1);
                }
{% else %}
                coords[{{ axis }}] = {{ op_name }}_reflect(coords[{{ axis }}], border_min[{{ axis }}], border_max[{{ axis }}]);
{% endif %}
            }
{% endfor %}
{% endif %}
{% if mode == "nearest" %}
            int in_bounds = 1;
{% for axis in range(spatial_rank) %}
            int idx{{ axis }} = (int)nearbyint(coords[{{ axis }}]);
{{ grid_sample_fix_index(axis) }}
{% endfor %}
            if (in_bounds) {
                {{ out_elem }} = {{ in_elem }};
            } else {
                {{ out_elem }} = ({{ c_type }})0;
            }
{% elif mode == "linear" %}
            int base[{{ spatial_rank }}];
            int upper[{{ spatial_rank }}];
            double w0[{{ spatial_rank }}];
            double w1[{{ spatial_rank }}];
{% for axis in range(spatial_rank) %}
            base[{{ axis }}] = (int)floor(coords[{{ axis }}]);
            upper[{{ axis }}] = base[{{ axis }}] + 1;
            w1[{{ axis }}] = coords[{{ axis }}] - (double)base[{{ axis }}];
            w0[{{ axis }}] = 1.0 - w1[{{ axis }}];
{% endfor %}
            double acc = 0.0;
{% for offsets in linear_offsets %}
            {
                double weight = 1.0;
{% for axis in range(spatial_rank) %}
{# Offset 0 selects the floor tap (weight w0), anything else the upper tap. #}
                int idx{{ axis }} = {% if offsets[axis] == 0 %}base[{{ axis }}]{% else %}upper[{{ axis }}]{% endif %};
                weight *= {% if offsets[axis] == 0 %}w0[{{ axis }}]{% else %}w1[{{ axis }}]{% endif %};
{% endfor %}
                int in_bounds = 1;
{% for axis in range(spatial_rank) %}
{{ grid_sample_fix_index(axis) }}
{% endfor %}
                if (in_bounds) {
                    acc += weight * (double){{ in_elem }};
                }
            }
{% endfor %}
            {{ out_elem }} = ({{ c_type }})acc;
{% else %}
            int base[{{ spatial_rank }}];
            int idxs[{{ spatial_rank }}][4];
            double coeffs[{{ spatial_rank }}][4];
{% for axis in range(spatial_rank) %}
{# Four taps per axis: floor - 1, floor, floor + 1, floor + 2. #}
            base[{{ axis }}] = (int)floor(coords[{{ axis }}]);
            idxs[{{ axis }}][0] = base[{{ axis }}] - 1;
            idxs[{{ axis }}][1] = base[{{ axis }}];
            idxs[{{ axis }}][2] = base[{{ axis }}] + 1;
            idxs[{{ axis }}][3] = base[{{ axis }}] + 2;
            {{ op_name }}_cubic_coeffs(coords[{{ axis }}] - (double)base[{{ axis }}], coeffs[{{ axis }}]);
{% endfor %}
            double acc = 0.0;
{% for offsets in cubic_offsets %}
            {
                double weight = 1.0;
{% for axis in range(spatial_rank) %}
                int idx{{ axis }} = idxs[{{ axis }}][{{ offsets[axis] }}];
                weight *= coeffs[{{ axis }}][{{ offsets[axis] }}];
{% endfor %}
                int in_bounds = 1;
{% for axis in range(spatial_rank) %}
{{ grid_sample_fix_index(axis) }}
{% endfor %}
                if (in_bounds) {
                    acc += weight * (double){{ in_elem }};
                }
            }
{% endfor %}
            {{ out_elem }} = ({{ c_type }})acc;
{% endif %}
{% for _ in range(spatial_rank) %}
            }
{% endfor %}
        }
    }
}
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
{# GroupNormalization kernel template.
   For every index of the first axis and every group `g`, three passes run
   over the group's `group_size` channels and all trailing (spatial) axes:
     1. accumulate the sum and derive the mean,
     2. accumulate the squared deviations from the mean,
     3. write (x - mean) / sqrt(var / count + epsilon) * scale[c] + bias[c].
   `count` is group_size * spatial_size; all arithmetic is done in c_type. #}
{# Opens one loop per trailing axis (shape[2:]). #}
{% macro group_norm_open_spatial() %}
{% for extent in shape[2:] %}
{% set axis_var = loop_vars[loop.index0 + 2] %}
                for (idx_t {{ axis_var }} = 0; {{ axis_var }} < {{ extent }}; ++{{ axis_var }}) {
{% endfor %}
{% endmacro %}
{# Closes the loops opened by group_norm_open_spatial. #}
{% macro group_norm_close_spatial() %}
{% for _ in shape[2:] %}
                }
{% endfor %}
{% endmacro %}
{# Subscript of the current element: [first-axis var][c][trailing vars...]. #}
{% set element %}[{{ loop_vars[0] }}][c]{% for var in loop_vars[2:] %}[{{ var }}]{% endfor %}{% endset %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for batch_extent in shape[:1] %}
    for (idx_t {{ loop_vars[0] }} = 0; {{ loop_vars[0] }} < {{ batch_extent }}; ++{{ loop_vars[0] }}) {
{% endfor %}
        for (idx_t g = 0; g < {{ num_groups }}; ++g) {
            {{ c_type }} sum = {{ zero_literal }};
            for (idx_t c_in_group = 0; c_in_group < {{ group_size }}; ++c_in_group) {
                idx_t c = g * {{ group_size }} + c_in_group;
{{ group_norm_open_spatial() }}
                sum += {{ input0 }}{{ element }};
{{ group_norm_close_spatial() }}
            }
            {{ c_type }} mean = sum / ({{ group_size }} * {{ spatial_size }});
            {{ c_type }} var = {{ zero_literal }};
            for (idx_t c_in_group = 0; c_in_group < {{ group_size }}; ++c_in_group) {
                idx_t c = g * {{ group_size }} + c_in_group;
{{ group_norm_open_spatial() }}
                {{ c_type }} diff = {{ input0 }}{{ element }} - mean;
                var += diff * diff;
{{ group_norm_close_spatial() }}
            }
            {{ c_type }} denom = {{ sqrt_fn }}(var / ({{ group_size }} * {{ spatial_size }}) + {{ epsilon_literal }});
            for (idx_t c_in_group = 0; c_in_group < {{ group_size }}; ++c_in_group) {
                idx_t c = g * {{ group_size }} + c_in_group;
{{ group_norm_open_spatial() }}
                {{ output }}{{ element }} =
                    ({{ input0 }}{{ element }} - mean) / denom * {{ scale }}[c] + {{ bias }}[c];
{{ group_norm_close_spatial() }}
            }
        }
{% for _ in shape[:1] %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
{#
  GRU kernel template (renders one C function).

  Template inputs read below:
    op_name, dim_args, params      - function name and parameter list
    num_directions, direction      - one block is emitted per direction; the
                                     block runs time-reversed when direction is
                                     "reverse", or for dir == 1 when direction
                                     is "bidirectional"
    activation_functions           - two entries per direction: [2*dir] is
                                     applied to the update/reset gates,
                                     [2*dir + 1] to the candidate state
    layout                         - 0: X[t][b][i], Y[t][dir][b][h],
                                        initial_h / Y_h [dir][b][h]
                                     otherwise: X[b][t][i], Y[b][t][dir][h],
                                        initial_h / Y_h [b][dir][h]
    input_w / input_r / input_b    - gate rows are stacked z, r, h in blocks of
                                     hidden_size; the bias holds the three
                                     input-side blocks followed by the three
                                     recurrent-side blocks
    input_initial_h, input_sequence_lens, output_y, output_y_h
                                   - optional tensors; falsy when absent
    use_clip / clip_literal        - optional symmetric clamp of the gate
                                     pre-activations
    linear_before_reset            - selects where the reset gate is applied in
                                     the candidate computation

  The sequence output Y is indexed by the time step `t` (not by the processing
  counter `step`), so that in the reverse direction each Y entry lines up with
  the input time step it was computed from. Entries at or beyond the per-batch
  sequence length are zero-filled.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dir in range(num_directions) %}
    {
        const int dir = {{ dir }};
        const int reverse = {% if direction == "reverse" %}1{% elif direction == "bidirectional" and dir == 1 %}1{% else %}0{% endif %};
        {% set act_f = activation_functions[dir * 2] %}
        {% set act_g = activation_functions[dir * 2 + 1] %}
        {# Running hidden state for every batch entry of this direction. #}
        {{ c_type }} H_prev[{{ batch_size }}][{{ hidden_size }}];
        for (int b = 0; b < {{ batch_size }}; ++b) {
            for (int h = 0; h < {{ hidden_size }}; ++h) {
                {% if input_initial_h %}
                {% if layout == 0 %}
                H_prev[b][h] = {{ input_initial_h }}[dir][b][h];
                {% else %}
                H_prev[b][h] = {{ input_initial_h }}[b][dir][h];
                {% endif %}
                {% else %}
                H_prev[b][h] = {{ zero_literal }};
                {% endif %}
            }
        }
        for (int b = 0; b < {{ batch_size }}; ++b) {
            int seq_limit = {{ seq_length }};
            {% if input_sequence_lens %}
            {# Per-batch sequence length, clamped into [0, seq_length]. #}
            seq_limit = (int){{ input_sequence_lens }}[b];
            if (seq_limit < 0) {
                seq_limit = 0;
            }
            if (seq_limit > {{ seq_length }}) {
                seq_limit = {{ seq_length }};
            }
            {% endif %}
            for (int step = 0; step < seq_limit; ++step) {
                {# t is the input time index; it runs backwards within the valid range when reversed. #}
                int t = reverse ? (seq_limit - 1 - step) : step;
                {{ c_type }} Z_gate[{{ hidden_size }}];
                {{ c_type }} R_gate[{{ hidden_size }}];
                {{ c_type }} H_next[{{ hidden_size }}];
                {# Pass 1: update (z) and reset (r) gates for every hidden unit. #}
                for (int h = 0; h < {{ hidden_size }}; ++h) {
                    {{ c_type }} gate_z = {{ zero_literal }};
                    {{ c_type }} gate_r = {{ zero_literal }};
                    for (int i = 0; i < {{ input_size }}; ++i) {
                        {% if layout == 0 %}
                        {{ c_type }} x_val = {{ input_x }}[t][b][i];
                        {% else %}
                        {{ c_type }} x_val = {{ input_x }}[b][t][i];
                        {% endif %}
                        gate_z += x_val * {{ input_w }}[dir][h][i];
                        gate_r += x_val * {{ input_w }}[dir][{{ hidden_size }} + h][i];
                    }
                    for (int i = 0; i < {{ hidden_size }}; ++i) {
                        {{ c_type }} h_val = H_prev[b][i];
                        gate_z += h_val * {{ input_r }}[dir][h][i];
                        gate_r += h_val * {{ input_r }}[dir][{{ hidden_size }} + h][i];
                    }
                    {% if input_b %}
                    gate_z += {{ input_b }}[dir][h] + {{ input_b }}[dir][{{ hidden_size }} * 3 + h];
                    gate_r += {{ input_b }}[dir][{{ hidden_size }} + h] + {{ input_b }}[dir][{{ hidden_size }} * 4 + h];
                    {% endif %}
                    {% if use_clip %}
                    if (gate_z > {{ clip_literal }}) {
                        gate_z = {{ clip_literal }};
                    } else if (gate_z < -{{ clip_literal }}) {
                        gate_z = -{{ clip_literal }};
                    }
                    if (gate_r > {{ clip_literal }}) {
                        gate_r = {{ clip_literal }};
                    } else if (gate_r < -{{ clip_literal }}) {
                        gate_r = -{{ clip_literal }};
                    }
                    {% endif %}
                    Z_gate[h] = {{ act_f }}(gate_z);
                    R_gate[h] = {{ act_f }}(gate_r);
                }
                {# Pass 2: candidate state and blended new hidden state. All reset gates must be known first. #}
                for (int h = 0; h < {{ hidden_size }}; ++h) {
                    {{ c_type }} gate_h = {{ zero_literal }};
                    for (int i = 0; i < {{ input_size }}; ++i) {
                        {% if layout == 0 %}
                        {{ c_type }} x_val = {{ input_x }}[t][b][i];
                        {% else %}
                        {{ c_type }} x_val = {{ input_x }}[b][t][i];
                        {% endif %}
                        gate_h += x_val * {{ input_w }}[dir][{{ hidden_size }} * 2 + h][i];
                    }
                    {% if linear_before_reset %}
                    {# Reset gate scales the already-projected recurrent term (including its bias). #}
                    {{ c_type }} recur = {{ zero_literal }};
                    for (int i = 0; i < {{ hidden_size }}; ++i) {
                        recur += H_prev[b][i] * {{ input_r }}[dir][{{ hidden_size }} * 2 + h][i];
                    }
                    {% if input_b %}
                    recur += {{ input_b }}[dir][{{ hidden_size }} * 5 + h];
                    {% endif %}
                    gate_h += R_gate[h] * recur;
                    {% if input_b %}
                    gate_h += {{ input_b }}[dir][{{ hidden_size }} * 2 + h];
                    {% endif %}
                    {% else %}
                    {# Reset gate scales the previous hidden state before the recurrent projection. #}
                    for (int i = 0; i < {{ hidden_size }}; ++i) {
                        gate_h += (R_gate[i] * H_prev[b][i]) * {{ input_r }}[dir][{{ hidden_size }} * 2 + h][i];
                    }
                    {% if input_b %}
                    gate_h += {{ input_b }}[dir][{{ hidden_size }} * 2 + h] + {{ input_b }}[dir][{{ hidden_size }} * 5 + h];
                    {% endif %}
                    {% endif %}
                    {% if use_clip %}
                    if (gate_h > {{ clip_literal }}) {
                        gate_h = {{ clip_literal }};
                    } else if (gate_h < -{{ clip_literal }}) {
                        gate_h = -{{ clip_literal }};
                    }
                    {% endif %}
                    {{ c_type }} h_candidate = {{ act_g }}(gate_h);
                    {{ c_type }} h_prev = H_prev[b][h];
                    {{ c_type }} h_new = ({{ one_literal }} - Z_gate[h]) * h_candidate + Z_gate[h] * h_prev;
                    {# Staged in H_next so later hidden units still read the old state. #}
                    H_next[h] = h_new;
                    {% if output_y %}
                    {# Store at the time index t so reverse-direction outputs stay aligned with their inputs. #}
                    {% if layout == 0 %}
                    {{ output_y }}[t][dir][b][h] = h_new;
                    {% else %}
                    {{ output_y }}[b][t][dir][h] = h_new;
                    {% endif %}
                    {% endif %}
                }
                for (int h = 0; h < {{ hidden_size }}; ++h) {
                    H_prev[b][h] = H_next[h];
                }
            }
            {% if output_y %}
            {# Zero-fill time steps past this batch entry's sequence length. #}
            for (int step = seq_limit; step < {{ seq_length }}; ++step) {
                for (int h = 0; h < {{ hidden_size }}; ++h) {
                    {% if layout == 0 %}
                    {{ output_y }}[step][dir][b][h] = {{ zero_literal }};
                    {% else %}
                    {{ output_y }}[b][step][dir][h] = {{ zero_literal }};
                    {% endif %}
                }
            }
            {% endif %}
        }
        {% if output_y_h %}
        {# Final hidden state of this direction. #}
        for (int b = 0; b < {{ batch_size }}; ++b) {
            for (int h = 0; h < {{ hidden_size }}; ++h) {
                {% if layout == 0 %}
                {{ output_y_h }}[dir][b][h] = H_prev[b][h];
                {% else %}
                {{ output_y_h }}[b][dir][h] = H_prev[b][h];
                {% endif %}
            }
        }
        {% endif %}
    }
{% endfor %}
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
{#
  Hamming window template (renders one C function).

  Reads the window size from element 0 of the `size` tensor and fills
  `length` output samples with  a0 - a1 * cos(2*pi*k / period), where
  a0 = 25/46, a1 = 1 - a0, and period is the size when periodic_literal is
  true and size - 1 otherwise. Math is done in double and cast to c_type on
  store.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const double window_two_pi = 2.0 * 3.14159265358979323846;
    const double window_a0 = 25.0 / 46.0;
    const double window_a1 = 1.0 - window_a0;
    const double window_size = (double){{ size }}[0];
    double window_period = window_size - 1.0;
    if ({{ periodic_literal }}) {
        window_period = window_size;
    }
    for (idx_t sample_idx = 0; sample_idx < {{ length }}; ++sample_idx) {
        const double angle = (window_two_pi * (double)sample_idx) / window_period;
        {{ output }}[sample_idx] = ({{ c_type }})(window_a0 - window_a1 * cos(angle));
    }
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
{#
  Hardmax template (renders one C function).

  The tensor is addressed as a flat buffer viewed as [outer][axis_size][inner].
  For every (outer, inner) pair the kernel scans the axis, remembers the index
  of the first strictly-greatest element, and then writes one_literal at that
  index and zero_literal everywhere else along the axis.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const idx_t outer = {{ outer }};
    const idx_t axis_size = {{ axis_size }};
    const idx_t inner = {{ inner }};
    const {{ c_type }} *in_ptr = (const {{ c_type }} *){{ input0 }};
    {{ c_type }} *out_ptr = ({{ c_type }} *){{ output }};
    for (idx_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
        for (idx_t inner_idx = 0; inner_idx < inner; ++inner_idx) {
            /* Offset of axis element 0 for this (outer, inner) pair. */
            const idx_t slice_base = (outer_idx * axis_size * inner) + inner_idx;
            idx_t best_index = 0;
            {{ c_type }} best_value = in_ptr[slice_base];
            for (idx_t axis_idx = 1; axis_idx < axis_size; ++axis_idx) {
                const {{ c_type }} candidate = in_ptr[slice_base + axis_idx * inner];
                /* Compare against the running maximum before it is updated;
                   strict '>' keeps the first occurrence on ties. */
                if (candidate > best_value) {
                    best_index = axis_idx;
                }
                best_value = {{ max_fn }}(best_value, candidate);
            }
            for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
                if (axis_idx == best_index) {
                    out_ptr[slice_base + axis_idx * inner] = {{ one_literal }};
                } else {
                    out_ptr[slice_base + axis_idx * inner] = {{ zero_literal }};
                }
            }
        }
    }
}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
{#
  Element-wise copy template (renders one C function).

  Emits one nested loop per entry of `shape`, using the matching name from
  `loop_vars` as the loop counter, and assigns input0 to output at the same
  multi-dimensional index.
#}
{% macro element_subscript() %}{% for var in loop_vars %}[{{ var }}]{% endfor %}{% endmacro %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% for dim in shape %}
{% set counter = loop_vars[loop.index0] %}
    for (idx_t {{ counter }} = 0; {{ counter }} < {{ dim }}; ++{{ counter }}) {
{% endfor %}
        {{ output }}{{ element_subscript() }} = {{ input0 }}{{ element_subscript() }};
{% for _ in shape %}
    }
{% endfor %}
}
|