emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of emx-onnx-cgen might be problematic.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/templates/resize_op.c.j2
@@ -0,0 +1,277 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const int64_t input_shape[{{ rank }}] = { {% for dim in input_shape %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
    const int64_t output_shape[{{ rank }}] EMX_UNUSED = { {% for dim in output_shape %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
    double scales[{{ rank }}] EMX_UNUSED;
    {% if roi_input %}
    double roi[{{ rank * 2 }}] EMX_UNUSED;
    {% if roi_axes %}
    for (idx_t r = 0; r < {{ rank }}; ++r) {
        roi[r] = 0.0;
        roi[r + {{ rank }}] = 1.0;
    }
    {% for axis in axes %}
    roi[{{ axis }}] = (double){{ roi_input }}[{{ roi_axis_map[loop.index0] }}];
    roi[{{ axis + rank }}] = (double){{ roi_input }}[{{ roi_axis_map[loop.index0] + (axes | length) }}];
    {% endfor %}
    {% else %}
    for (idx_t r = 0; r < {{ rank * 2 }}; ++r) {
        roi[r] = (double){{ roi_input }}[r];
    }
    {% endif %}
    {% else %}
    const double roi[{{ rank * 2 }}] EMX_UNUSED = {
    {% for _ in range(rank) %} 0.0{% if not loop.last %},{% endif %}
    {% endfor %}{% if rank > 0 %},{% endif %}
    {% for _ in range(rank) %} 1.0{% if not loop.last %},{% endif %}
    {% endfor %}
    };
    {% endif %}
    {% if scales_input %}
    {% if scales_axes %}
    for (idx_t r = 0; r < {{ rank }}; ++r) {
        scales[r] = 1.0;
    }
    {% for axis in axes %}
    scales[{{ axis }}] = (double){{ scales_input }}[{{ scales_axis_map[loop.index0] }}];
    {% endfor %}
    {% else %}
    {% for axis in range(rank) %}
    scales[{{ axis }}] = (double){{ scales_input }}[{{ axis }}];
    {% endfor %}
    {% endif %}
    {% elif sizes_input %}
    {% if keep_aspect_ratio_policy != "stretch" %}
    {% if keep_aspect_ratio_policy == "not_larger" %}
    double scale_value = INFINITY;
    {% else %}
    double scale_value = 0.0;
    {% endif %}
    {% for axis in axes %}
    {
        const double size_axis = (double){{ sizes_input }}[{{ sizes_axis_map[loop.index0] }}];
        const double axis_scale = size_axis / (double)input_shape[{{ axis }}];
        {% if keep_aspect_ratio_policy == "not_larger" %}
        if (axis_scale < scale_value) {
            scale_value = axis_scale;
        }
        {% else %}
        if (axis_scale > scale_value) {
            scale_value = axis_scale;
        }
        {% endif %}
    }
    {% endfor %}
    for (idx_t r = 0; r < {{ rank }}; ++r) {
        scales[r] = 1.0;
    }
    {% for axis in axes %}
    scales[{{ axis }}] = scale_value;
    {% endfor %}
    {% else %}
    for (idx_t r = 0; r < {{ rank }}; ++r) {
        scales[r] = 1.0;
    }
    {% for axis in axes %}
    scales[{{ axis }}] = (double){{ sizes_input }}[{{ sizes_axis_map[loop.index0] }}] / (double)input_shape[{{ axis }}];
    {% endfor %}
    {% endif %}
    {% else %}
    {% for axis in range(rank) %}
    scales[{{ axis }}] = {{ scales[axis] }};
    {% endfor %}
    {% endif %}
    {% for dim in output_shape %}
    for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
    {% endfor %}
        int use_extrapolation = 0;
        {% for dim in range(rank) %}
        double x_orig{{ dim }};
        {% if coordinate_transformation_mode == "align_corners" %}
        if (output_shape[{{ dim }}] == 1) {
            x_orig{{ dim }} = 0.0;
        } else {
            x_orig{{ dim }} = (double){{ loop_vars[dim] }} * (input_shape[{{ dim }}] - 1) / (double)(output_shape[{{ dim }}] - 1);
        }
        {% elif coordinate_transformation_mode == "asymmetric" %}
        x_orig{{ dim }} = (double){{ loop_vars[dim] }} / scales[{{ dim }}];
        {% elif coordinate_transformation_mode == "tf_crop_and_resize" %}
        {
            const double roi_start = roi[{{ dim }}];
            const double roi_end = roi[{{ dim + rank }}];
            if (output_shape[{{ dim }}] == 1) {
                x_orig{{ dim }} = (roi_end - roi_start) * (input_shape[{{ dim }}] - 1) / 2.0;
            } else {
                x_orig{{ dim }} = (double){{ loop_vars[dim] }} * (roi_end - roi_start) * (input_shape[{{ dim }}] - 1)
                    / (double)(output_shape[{{ dim }}] - 1);
            }
            x_orig{{ dim }} += roi_start * (input_shape[{{ dim }}] - 1);
            if (x_orig{{ dim }} < 0.0 || x_orig{{ dim }} > (double)(input_shape[{{ dim }}] - 1)) {
                use_extrapolation = 1;
            }
        }
        {% elif coordinate_transformation_mode == "pytorch_half_pixel" %}
        if (output_shape[{{ dim }}] == 1) {
            x_orig{{ dim }} = -0.5;
        } else {
            x_orig{{ dim }} = ((double){{ loop_vars[dim] }} + 0.5) / scales[{{ dim }}] - 0.5;
        }
        {% elif coordinate_transformation_mode == "half_pixel_symmetric" %}
        {
            const double output_width = scales[{{ dim }}] * (double)input_shape[{{ dim }}];
            const double adjustment = (double)output_shape[{{ dim }}] / output_width;
            const double center = (double)input_shape[{{ dim }}] / 2.0;
            const double offset = center * (1.0 - adjustment);
            x_orig{{ dim }} = offset + ((double){{ loop_vars[dim] }} + 0.5) / scales[{{ dim }}] - 0.5;
        }
        {% else %}
        x_orig{{ dim }} = ((double){{ loop_vars[dim] }} + 0.5) / scales[{{ dim }}] - 0.5;
        {% endif %}
        {% endfor %}
        if (use_extrapolation) {
            {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ c_type }}){{ extrapolation_value }};
        } else {
        {% if mode == "nearest" %}
            {% for dim in range(rank) %}
            const double x_val{{ dim }} = x_orig{{ dim }};
            const double x_floor{{ dim }} EMX_UNUSED = floor(x_val{{ dim }});
            const double x_ceil{{ dim }} EMX_UNUSED = ceil(x_val{{ dim }});
            int idx{{ dim }};
            {% if nearest_mode == "round_prefer_floor" %}
            idx{{ dim }} = (x_val{{ dim }} - x_floor{{ dim }} <= 0.5) ? (int)x_floor{{ dim }} : (int)x_ceil{{ dim }};
            {% elif nearest_mode == "round_prefer_ceil" %}
            idx{{ dim }} = (x_val{{ dim }} - x_floor{{ dim }} < 0.5) ? (int)x_floor{{ dim }} : (int)x_ceil{{ dim }};
            {% elif nearest_mode == "floor" %}
            idx{{ dim }} = (int)x_floor{{ dim }};
            {% else %}
            idx{{ dim }} = (int)x_ceil{{ dim }};
            {% endif %}
            if (idx{{ dim }} < 0) {
                idx{{ dim }} = 0;
            } else if (idx{{ dim }} >= input_shape[{{ dim }}]) {
                idx{{ dim }} = (int)input_shape[{{ dim }}] - 1;
            }
            {% endfor %}
            {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for dim in range(rank) %}[idx{{ dim }}]{% endfor %};
        {% else %}
            double acc = 0.0;
            {% for dim in range(rank) %}
            double x_floor{{ dim }} = floor(x_orig{{ dim }});
            double ratio{{ dim }} = x_orig{{ dim }} - x_floor{{ dim }};
            const int is_integer{{ dim }} = ratio{{ dim }} == 0.0;
            if (is_integer{{ dim }}) {
                ratio{{ dim }} = 1.0;
            }
            int count{{ dim }};
            {% if antialias %}
            double scale_clamped{{ dim }} = scales[{{ dim }}] < 1.0 ? scales[{{ dim }}] : 1.0;
            {% if mode == "linear" %}
            int coeff_start{{ dim }} = (int)floor(-1.0 / scale_clamped{{ dim }}) + 1;
            count{{ dim }} = 2 - 2 * coeff_start{{ dim }};
            {% else %}
            int coeff_start{{ dim }} = (int)floor(-2.0 / scale_clamped{{ dim }}) + 1;
            int coeff_end{{ dim }} = 2 - coeff_start{{ dim }};
            count{{ dim }} = coeff_end{{ dim }} - coeff_start{{ dim }};
            {% endif %}
            {% else %}
            {% if mode == "linear" %}
            count{{ dim }} = 2;
            {% else %}
            count{{ dim }} = 4;
            {% endif %}
            {% endif %}
            int start_index{{ dim }} = (int)x_floor{{ dim }} - (count{{ dim }} / 2);
            if (!is_integer{{ dim }}) {
                start_index{{ dim }} += 1;
            }
            int idx{{ dim }}_values[count{{ dim }}];
            double coeff{{ dim }}[count{{ dim }}];
            {% if antialias %}
            double coeff_sum{{ dim }} = 0.0;
            for (int c = 0; c < count{{ dim }}; ++c) {
                idx{{ dim }}_values[c] = start_index{{ dim }} + c;
                {% if mode == "linear" %}
                double arg = (coeff_start{{ dim }} + c - ratio{{ dim }}) * scale_clamped{{ dim }};
                double coeff = 1.0 - fabs(arg);
                if (coeff < 0.0) {
                    coeff = 0.0;
                }
                {% else %}
                double x = scale_clamped{{ dim }} * (coeff_start{{ dim }} + c - ratio{{ dim }});
                double ax = fabs(x);
                double coeff;
                if (ax <= 1.0) {
                    coeff = ({{ cubic_coeff_a }} + 2.0) * ax * ax * ax - ({{ cubic_coeff_a }} + 3.0) * ax * ax + 1.0;
                } else if (ax < 2.0) {
                    coeff = {{ cubic_coeff_a }} * ax * ax * ax - 5.0 * {{ cubic_coeff_a }} * ax * ax + 8.0 * {{ cubic_coeff_a }} * ax - 4.0 * {{ cubic_coeff_a }};
                } else {
                    coeff = 0.0;
                }
                {% endif %}
                coeff{{ dim }}[c] = coeff;
                coeff_sum{{ dim }} += coeff;
            }
            if (coeff_sum{{ dim }} > 0.0) {
                for (int c = 0; c < count{{ dim }}; ++c) {
                    coeff{{ dim }}[c] /= coeff_sum{{ dim }};
                }
            }
            {% else %}
            for (int c = 0; c < count{{ dim }}; ++c) {
                idx{{ dim }}_values[c] = start_index{{ dim }} + c;
            }
            {% if mode == "linear" %}
            coeff{{ dim }}[0] = 1.0 - ratio{{ dim }};
            coeff{{ dim }}[1] = ratio{{ dim }};
            {% else %}
            {
                const double t = 1.0 - ratio{{ dim }};
                coeff{{ dim }}[0] = (({{ cubic_coeff_a }} * (ratio{{ dim }} + 1.0) - 5.0 * {{ cubic_coeff_a }}) * (ratio{{ dim }} + 1.0) + 8.0 * {{ cubic_coeff_a }}) * (ratio{{ dim }} + 1.0) - 4.0 * {{ cubic_coeff_a }};
                coeff{{ dim }}[1] = (({{ cubic_coeff_a }} + 2.0) * ratio{{ dim }} - ({{ cubic_coeff_a }} + 3.0)) * ratio{{ dim }} * ratio{{ dim }} + 1.0;
                coeff{{ dim }}[2] = (({{ cubic_coeff_a }} + 2.0) * t - ({{ cubic_coeff_a }} + 3.0)) * t * t + 1.0;
                coeff{{ dim }}[3] = (({{ cubic_coeff_a }} * (t + 1.0) - 5.0 * {{ cubic_coeff_a }}) * (t + 1.0) + 8.0 * {{ cubic_coeff_a }}) * (t + 1.0) - 4.0 * {{ cubic_coeff_a }};
            }
            {% endif %}
            {% endif %}
            {% if exclude_outside %}
            {
                double coeff_sum{{ dim }} = 0.0;
                for (int c = 0; c < count{{ dim }}; ++c) {
                    if (idx{{ dim }}_values[c] < 0 || idx{{ dim }}_values[c] >= input_shape[{{ dim }}]) {
                        coeff{{ dim }}[c] = 0.0;
                    }
                    coeff_sum{{ dim }} += coeff{{ dim }}[c];
                }
                if (coeff_sum{{ dim }} > 0.0) {
                    for (int c = 0; c < count{{ dim }}; ++c) {
                        coeff{{ dim }}[c] /= coeff_sum{{ dim }};
                    }
                }
            }
            {% endif %}
            {% endfor %}
            {% for dim in range(rank) %}
            for (int n{{ dim }} = 0; n{{ dim }} < count{{ dim }}; ++n{{ dim }}) {
            {% endfor %}
                double weight = 1.0;
                {% for dim in range(rank) %}
                weight *= coeff{{ dim }}[n{{ dim }}];
                {% endfor %}
                {% for dim in range(rank) %}
                int idx{{ dim }} = idx{{ dim }}_values[n{{ dim }}];
                if (idx{{ dim }} < 0) {
                    idx{{ dim }} = 0;
                } else if (idx{{ dim }} >= input_shape[{{ dim }}]) {
                    idx{{ dim }} = (int)input_shape[{{ dim }}] - 1;
                }
                {% endfor %}
                acc += weight * (double){{ input0 }}{% for dim in range(rank) %}[idx{{ dim }}]{% endfor %};
            {% for dim in range(rank) %}
            }
            {% endfor %}
            {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = ({{ c_type }})acc;
        {% endif %}
        }
    {% for _ in output_shape %}
    }
    {% endfor %}
}
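Note: most of resize_op.c.j2 is coordinate bookkeeping around a single idea: each output index is mapped back to a real-valued input coordinate, and neighbouring input samples are blended. Under the default half_pixel transform the mapping is x_orig = (x_out + 0.5) / scale - 0.5. A minimal hand-written sketch of the 1-D linear case (illustrative names and shapes, not the generated code):

    /* Sketch: 1-D linear Resize with half_pixel coordinates, the case the
       template above generalises to any rank and coordinate mode. */
    #include <math.h>
    #include <stdio.h>

    static void resize_linear_1d(const float *in, int in_len, float *out, int out_len) {
        const double scale = (double)out_len / (double)in_len;
        for (int i = 0; i < out_len; ++i) {
            double x = ((double)i + 0.5) / scale - 0.5;   /* half_pixel mapping */
            double xf = floor(x);
            double ratio = x - xf;
            int i0 = (int)xf, i1 = (int)xf + 1;
            if (i0 < 0) i0 = 0; else if (i0 >= in_len) i0 = in_len - 1;  /* clamp to edges */
            if (i1 < 0) i1 = 0; else if (i1 >= in_len) i1 = in_len - 1;
            out[i] = (float)((1.0 - ratio) * in[i0] + ratio * in[i1]);
        }
    }

    int main(void) {
        const float in[4] = {0.0f, 1.0f, 2.0f, 3.0f};
        float out[8];
        resize_linear_1d(in, 4, out, 8);   /* 2x upsample */
        for (int i = 0; i < 8; ++i) printf("%g ", out[i]);
        printf("\n");
        return 0;
    }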
emx_onnx_cgen/templates/rms_normalization_op.c.j2
@@ -0,0 +1,28 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {% for dim in prefix_shape %}
    for (idx_t {{ prefix_loop_vars[loop.index0] }} = 0; {{ prefix_loop_vars[loop.index0] }} < {{ dim }}; ++{{ prefix_loop_vars[loop.index0] }}) {
    {% endfor %}
        {{ c_type }} sum = {{ zero_literal }};
        {% for dim in norm_shape %}
        for (idx_t {{ norm_loop_vars[loop.index0] }} = 0; {{ norm_loop_vars[loop.index0] }} < {{ dim }}; ++{{ norm_loop_vars[loop.index0] }}) {
        {% endfor %}
            {{ c_type }} value = {{ input0 }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %};
            sum += value * value;
        {% for _ in norm_shape %}
        }
        {% endfor %}
        {{ c_type }} mean_square = sum / {{ inner }};
        {{ c_type }} denom = {{ sqrt_fn }}(mean_square + {{ epsilon_literal }});
        {% for dim in norm_shape %}
        for (idx_t {{ norm_loop_vars[loop.index0] }} = 0; {{ norm_loop_vars[loop.index0] }} < {{ dim }}; ++{{ norm_loop_vars[loop.index0] }}) {
        {% endfor %}
            {{ c_type }} value = {{ input0 }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %} / denom;
            value = value * {{ scale }}{% for var in scale_index_vars %}[{{ var }}]{% endfor %};
            {{ output }}{% for var in prefix_loop_vars %}[{{ var }}]{% endfor %}{% for var in norm_loop_vars %}[{{ var }}]{% endfor %} = value;
        {% for _ in norm_shape %}
        }
        {% endfor %}
    {% for _ in prefix_shape %}
    }
    {% endfor %}
}
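The template above unrolls RMS normalization, y = x / sqrt(mean(x^2) + epsilon) * scale, over arbitrary prefix and normalized shapes. Collapsed to a single flat axis, the same computation reads as follows (hypothetical helper, not generated output):

    #include <math.h>

    /* RMS normalization over one contiguous axis of length n. */
    static void rms_norm(const float *x, const float *scale, float *y, int n, float eps) {
        float sum = 0.0f;
        for (int i = 0; i < n; ++i) sum += x[i] * x[i];   /* sum of squares */
        float denom = sqrtf(sum / (float)n + eps);        /* sqrt(mean_square + eps) */
        for (int i = 0; i < n; ++i) y[i] = x[i] / denom * scale[i];
    }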
emx_onnx_cgen/templates/rotary_embedding_op.c.j2
@@ -0,0 +1,66 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {% if input_rank == 3 %}
    for (idx_t b = 0; b < {{ batch }}; ++b) {
        for (idx_t s = 0; s < {{ seq_len }}; ++s) {
            for (idx_t h = 0; h < {{ num_heads }}; ++h) {
                idx_t head_offset = h * {{ head_size }};
                for (idx_t d = 0; d < {{ rotary_dim_half }}; ++d) {
                    {% if interleaved %}
                    idx_t idx1 = 2 * d;
                    idx_t idx2 = 2 * d + 1;
                    {% else %}
                    idx_t idx1 = d;
                    idx_t idx2 = d + {{ rotary_dim_half }};
                    {% endif %}
                    {% if has_position_ids %}
                    idx_t pos = (idx_t){{ position_ids }}[b][s];
                    {{ c_type }} cos_val = {{ cos_cache }}[pos][d];
                    {{ c_type }} sin_val = {{ sin_cache }}[pos][d];
                    {% else %}
                    {{ c_type }} cos_val = {{ cos_cache }}[b][s][d];
                    {{ c_type }} sin_val = {{ sin_cache }}[b][s][d];
                    {% endif %}
                    {{ c_type }} x1 = {{ input0 }}[b][s][head_offset + idx1];
                    {{ c_type }} x2 = {{ input0 }}[b][s][head_offset + idx2];
                    {{ output }}[b][s][head_offset + idx1] = (cos_val * x1) - (sin_val * x2);
                    {{ output }}[b][s][head_offset + idx2] = (sin_val * x1) + (cos_val * x2);
                }
                for (idx_t d = {{ rotary_dim }}; d < {{ head_size }}; ++d) {
                    {{ output }}[b][s][head_offset + d] = {{ input0 }}[b][s][head_offset + d];
                }
            }
        }
    }
    {% else %}
    for (idx_t b = 0; b < {{ batch }}; ++b) {
        for (idx_t h = 0; h < {{ num_heads }}; ++h) {
            for (idx_t s = 0; s < {{ seq_len }}; ++s) {
                for (idx_t d = 0; d < {{ rotary_dim_half }}; ++d) {
                    {% if interleaved %}
                    idx_t idx1 = 2 * d;
                    idx_t idx2 = 2 * d + 1;
                    {% else %}
                    idx_t idx1 = d;
                    idx_t idx2 = d + {{ rotary_dim_half }};
                    {% endif %}
                    {% if has_position_ids %}
                    idx_t pos = (idx_t){{ position_ids }}[b][s];
                    {{ c_type }} cos_val = {{ cos_cache }}[pos][d];
                    {{ c_type }} sin_val = {{ sin_cache }}[pos][d];
                    {% else %}
                    {{ c_type }} cos_val = {{ cos_cache }}[b][s][d];
                    {{ c_type }} sin_val = {{ sin_cache }}[b][s][d];
                    {% endif %}
                    {{ c_type }} x1 = {{ input0 }}[b][h][s][idx1];
                    {{ c_type }} x2 = {{ input0 }}[b][h][s][idx2];
                    {{ output }}[b][h][s][idx1] = (cos_val * x1) - (sin_val * x2);
                    {{ output }}[b][h][s][idx2] = (sin_val * x1) + (cos_val * x2);
                }
                for (idx_t d = {{ rotary_dim }}; d < {{ head_size }}; ++d) {
                    {{ output }}[b][h][s][d] = {{ input0 }}[b][h][s][d];
                }
            }
        }
    }
    {% endif %}
}
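Both branches of the rotary embedding template (3-D [batch, seq, hidden] and 4-D [batch, heads, seq, head_size] layouts) apply the same 2-D rotation per element pair; only the indexing differs. The core step, restated as a standalone sketch (illustrative names):

    /* One rotary step: rotate the pair (x1, x2) by the cached angle. */
    static void rotate_pair(float x1, float x2, float cos_val, float sin_val,
                            float *y1, float *y2) {
        *y1 = cos_val * x1 - sin_val * x2;
        *y2 = sin_val * x1 + cos_val * x2;
    }

Pairs are taken interleaved (2d, 2d+1) or split across halves (d, d + rotary_dim_half) depending on the interleaved flag, and dimensions at or past rotary_dim are copied through unchanged.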
emx_onnx_cgen/templates/scatter_nd_op.c.j2
@@ -0,0 +1,52 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {% for dim in output_shape %}
    for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
    {% endfor %}
        {{ output }}{% for var in output_loop_vars %}[{{ var }}]{% endfor %} = {{ data }}{% for var in output_loop_vars %}[{{ var }}]{% endfor %};
    {% for _ in output_shape %}
    }
    {% endfor %}

    {% if indices_prefix_shape %}
    {% for dim in indices_prefix_shape %}
    for (idx_t {{ indices_prefix_loop_vars[loop.index0] }} = 0; {{ indices_prefix_loop_vars[loop.index0] }} < {{ dim }}; ++{{ indices_prefix_loop_vars[loop.index0] }}) {
    {% endfor %}
    {% endif %}
    {% for idx in range(index_depth) %}
    idx_t index{{ idx }} = {{ indices }}{% for var in indices_prefix_loop_vars %}[{{ var }}]{% endfor %}[{{ idx }}];
    if (index{{ idx }} < 0) {
        index{{ idx }} += {{ data_shape[idx] }};
    }
    {% endfor %}
    {% if tail_shape %}
    {% for dim in tail_shape %}
    for (idx_t {{ tail_loop_vars[loop.index0] }} = 0; {{ tail_loop_vars[loop.index0] }} < {{ dim }}; ++{{ tail_loop_vars[loop.index0] }}) {
    {% endfor %}
    {% endif %}
    {{ c_type }} update_val = {{ updates_index_expr }};
    {% if reduction == "none" %}
    {{ output_index_expr }} = update_val;
    {% elif reduction == "add" %}
    {{ output_index_expr }} += update_val;
    {% elif reduction == "mul" %}
    {{ output_index_expr }} *= update_val;
    {% elif reduction == "max" %}
    if (update_val > {{ output_index_expr }}) {
        {{ output_index_expr }} = update_val;
    }
    {% elif reduction == "min" %}
    if (update_val < {{ output_index_expr }}) {
        {{ output_index_expr }} = update_val;
    }
    {% endif %}
    {% if tail_shape %}
    {% for _ in tail_shape %}
    }
    {% endfor %}
    {% endif %}
    {% if indices_prefix_shape %}
    {% for _ in indices_prefix_shape %}
    }
    {% endfor %}
    {% endif %}
}
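The generated ScatterND first copies data through to the output, then walks the index tuples and applies the configured reduction, wrapping negative indices once. A 1-D sketch with reduction "add" (hypothetical shapes; the template emits these loops statically instead):

    /* ScatterND, rank-1 data, reduction="add": out starts as a copy of data,
       then each update is accumulated at its (possibly negative) index. */
    static void scatter_nd_add_1d(const float *data, long n,
                                  const long *indices, const float *updates, long k,
                                  float *out) {
        for (long i = 0; i < n; ++i) out[i] = data[i];  /* copy-through */
        for (long j = 0; j < k; ++j) {
            long idx = indices[j];
            if (idx < 0) idx += n;                      /* wrap negative index */
            out[idx] += updates[j];
        }
    }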
emx_onnx_cgen/templates/slice_op.c.j2
@@ -0,0 +1,9 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {% for dim in output_shape %}
    for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
    {% endfor %}
        {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for idx in input_indices %}[{{ idx }}]{% endfor %};
    {% for _ in output_shape %}
    }
    {% endfor %}
}
emx_onnx_cgen/templates/slice_op_dynamic.c.j2
@@ -0,0 +1,70 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const idx_t input_rank = {{ input_rank }};
    const idx_t input_dims[] = { {% for dim in input_shape %}{{ dim }}{% if not loop.last %}, {% endif %}{% endfor %} };
    idx_t start_indices[{{ input_rank }}];
    idx_t step_values[{{ input_rank }}];
    for (idx_t idx = 0; idx < input_rank; ++idx) {
        start_indices[idx] = 0;
        step_values[idx] = 1;
    }
    for (idx_t i = 0; i < {{ starts_len }}; ++i) {
        {% if axes_input %}
        idx_t axis_value = {{ axes_input }}[i];
        {% else %}
        idx_t axis_value = i;
        {% endif %}
        if (axis_value < 0) {
            axis_value += input_rank;
        }
        idx_t axis = axis_value;
        idx_t dim = input_dims[axis];
        idx_t start_value = {{ starts_input }}[i];
        idx_t step_value = 1;
        {% if steps_input %}
        step_value = {{ steps_input }}[i];
        {% endif %}
        idx_t end_value = {{ ends_input }}[i];
        if (start_value < 0) {
            start_value += dim;
        }
        if (end_value < 0) {
            end_value += dim;
        }
        if (step_value > 0) {
            if (start_value < 0) {
                start_value = 0;
            }
            if (start_value > dim) {
                start_value = dim;
            }
            if (end_value < 0) {
                end_value = 0;
            }
            if (end_value > dim) {
                end_value = dim;
            }
        } else {
            if (start_value < 0) {
                start_value = -1;
            }
            if (start_value >= dim) {
                start_value = dim - 1;
            }
            if (end_value < -1) {
                end_value = -1;
            }
            if (end_value >= dim) {
                end_value = dim - 1;
            }
        }
        start_indices[axis] = start_value;
        step_values[axis] = step_value;
    }
    {% for dim in output_shape %}
    for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
    {% endfor %}
        {{ output }}{% for var in output_loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for var in output_loop_vars %}[start_indices[{{ loop.index0 }}] + ({{ var }} * step_values[{{ loop.index0 }}])]{% endfor %};
    {% for _ in output_shape %}
    }
    {% endfor %}
}
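The clamping above follows ONNX Slice semantics: negative starts and ends first wrap once around the dimension; then positive steps clamp both to [0, dim], while negative steps clamp to [-1, dim - 1] so a backward walk can stop just before index 0. The same normalization as a one-axis helper (sketch, not the generated code):

    /* Normalize one axis of an ONNX Slice: wrap negatives, then clamp
       according to the sign of the step. */
    static void normalize_slice_axis(long dim, long *start, long *end, long step) {
        if (*start < 0) *start += dim;
        if (*end < 0) *end += dim;
        if (step > 0) {
            if (*start < 0) *start = 0;
            if (*start > dim) *start = dim;
            if (*end < 0) *end = 0;
            if (*end > dim) *end = dim;
        } else {
            if (*start < 0) *start = -1;
            if (*start >= dim) *start = dim - 1;
            if (*end < -1) *end = -1;
            if (*end >= dim) *end = dim - 1;
        }
    }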
emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2
@@ -0,0 +1,105 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
    const {{ target_c_type }} *target_flat = (const {{ target_c_type }} *){{ target }};
    {{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
    {% if log_prob %}
    {{ c_type }} *log_prob_flat = ({{ c_type }} *){{ log_prob }};
    {% endif %}
    const idx_t n = {{ n }};
    const idx_t c = {{ c }};
    const idx_t d = {{ d }};
    {% if reduction != "none" %}
    {{ acc_type }} loss_sum = {{ acc_zero_literal }};
    {% if reduction == "mean" and (weight or use_ignore_index) %}
    {{ acc_type }} weight_sum = {{ acc_zero_literal }};
    {% endif %}
    {% endif %}
    for (idx_t n_idx = 0; n_idx < n; ++n_idx) {
        for (idx_t d_idx = 0; d_idx < d; ++d_idx) {
            idx_t target_index = n_idx * d + d_idx;
            {{ target_c_type }} target_value = target_flat[target_index];
            {% if use_ignore_index %}
            bool ignored = ((int64_t)target_value == {{ ignore_index }});
            {% endif %}
            idx_t class_index = (idx_t)target_value;
            idx_t base = (n_idx * c * d) + d_idx;
            {{ acc_type }} max_value = ({{ acc_type }})input_flat[base];
            for (idx_t c_idx = 1; c_idx < c; ++c_idx) {
                {{ acc_type }} value = ({{ acc_type }})input_flat[base + c_idx * d];
                max_value = {{ max_fn }}(max_value, value);
            }
            {{ acc_type }} sum = {{ acc_zero_literal }};
            for (idx_t c_idx = 0; c_idx < c; ++c_idx) {
                {{ acc_type }} value = ({{ acc_type }})input_flat[base + c_idx * d] - max_value;
                sum += {{ acc_exp_fn }}(value);
            }
            {{ acc_type }} logsum = {{ acc_log_fn }}(sum);
            {% if log_prob %}
            {{ acc_type }} loss_value = {{ acc_zero_literal }};
            for (idx_t c_idx = 0; c_idx < c; ++c_idx) {
                {{ acc_type }} log_prob_value = ({{ acc_type }})input_flat[base + c_idx * d] - max_value - logsum;
                log_prob_flat[base + c_idx * d] = ({{ c_type }})log_prob_value;
                if (c_idx == class_index) {
                    loss_value = -log_prob_value;
                }
            }
            {% if use_ignore_index %}
            if (ignored) {
                {% if reduction == "none" %}
                output_flat[target_index] = {{ zero_literal }};
                {% endif %}
            }
            {% endif %}
            {% else %}
            {{ acc_type }} loss_value = {{ acc_zero_literal }};
            {% if use_ignore_index %}
            if (ignored) {
                {% if reduction == "none" %}
                output_flat[target_index] = {{ zero_literal }};
                {% endif %}
            } else {
            {% endif %}
            {{ acc_type }} log_prob_value = ({{ acc_type }})input_flat[base + class_index * d] - max_value - logsum;
            loss_value = -log_prob_value;
            {% if use_ignore_index %}
            }
            {% endif %}
            {% endif %}
            {% if use_ignore_index %}
            if (!ignored) {
            {% endif %}
            {% if weight %}
            {{ acc_type }} sample_weight = {{ weight }}[class_index];
            loss_value *= sample_weight;
            {% endif %}
            {% if reduction == "none" %}
            output_flat[target_index] = loss_value;
            {% else %}
            loss_sum += loss_value;
            {% if reduction == "mean" %}
            {% if weight %}
            weight_sum += sample_weight;
            {% elif use_ignore_index %}
            weight_sum += {{ acc_one_literal }};
            {% endif %}
            {% endif %}
            {% endif %}
            {% if use_ignore_index %}
            }
            {% endif %}
        }
    }
    {% if reduction == "mean" %}
    {% if weight or use_ignore_index %}
    if (weight_sum == {{ acc_zero_literal }}) {
        output_flat[0] = {{ zero_literal }};
    } else {
        output_flat[0] = loss_sum / weight_sum;
    }
    {% else %}
    output_flat[0] = loss_sum / ({{ n }} * {{ d }});
    {% endif %}
    {% elif reduction == "sum" %}
    output_flat[0] = loss_sum;
    {% endif %}
}
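Per sample the template computes the loss through the log-sum-exp trick: loss = -(x_target - max - log(sum_j exp(x_j - max))), i.e. the negative log-softmax of the target class, which avoids overflow in exp. A single-sample sketch without weights or ignore_index (float types assumed, not the generated code):

    #include <math.h>

    /* Softmax cross entropy for one sample via log-sum-exp. */
    static float sce_loss(const float *logits, int num_classes, int target) {
        float max_v = logits[0];
        for (int c = 1; c < num_classes; ++c)
            if (logits[c] > max_v) max_v = logits[c];
        float sum = 0.0f;
        for (int c = 0; c < num_classes; ++c)
            sum += expf(logits[c] - max_v);
        return -(logits[target] - max_v - logf(sum));   /* -log p(target) */
    }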
emx_onnx_cgen/templates/softmax_op.c.j2
@@ -0,0 +1,26 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const {{ c_type }} *input_flat = (const {{ c_type }} *){{ input0 }};
    {{ c_type }} *output_flat = ({{ c_type }} *){{ output }};
    const idx_t outer = {{ outer }};
    const idx_t axis_size = {{ axis_size }};
    const idx_t inner = {{ inner }};
    for (idx_t outer_idx = 0; outer_idx < outer; ++outer_idx) {
        for (idx_t inner_idx = 0; inner_idx < inner; ++inner_idx) {
            idx_t base = (outer_idx * axis_size * inner) + inner_idx;
            {{ c_type }} max_value = input_flat[base];
            for (idx_t axis_idx = 1; axis_idx < axis_size; ++axis_idx) {
                {{ c_type }} value = input_flat[base + axis_idx * inner];
                max_value = {{ max_fn }}(max_value, value);
            }
            {{ c_type }} sum = 0;
            for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
                {{ c_type }} value = {{ exp_fn }}(input_flat[base + axis_idx * inner] - max_value);
                output_flat[base + axis_idx * inner] = value;
                sum += value;
            }
            for (idx_t axis_idx = 0; axis_idx < axis_size; ++axis_idx) {
                output_flat[base + axis_idx * inner] /= sum;
            }
        }
    }
}
emx_onnx_cgen/templates/space_to_depth_op.c.j2
@@ -0,0 +1,22 @@
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
    {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
    idx_t output_index = 0;
    for (idx_t n = 0; n < {{ batch }}; ++n) {
        for (idx_t c_out = 0; c_out < {{ out_channels }}; ++c_out) {
            idx_t c_in = c_out % {{ in_channels }};
            idx_t temp = c_out / {{ in_channels }};
            idx_t offset_h = temp / {{ blocksize }};
            idx_t offset_w = temp % {{ blocksize }};
            for (idx_t h_out = 0; h_out < {{ out_h }}; ++h_out) {
                idx_t h_in = h_out * {{ blocksize }} + offset_h;
                for (idx_t w_out = 0; w_out < {{ out_w }}; ++w_out) {
                    idx_t w_in = w_out * {{ blocksize }} + offset_w;
                    idx_t input_index = ((n * {{ in_channels }} + c_in) * {{ in_h }} + h_in) * {{ in_w }} + w_in;
                    output_data[output_index] = input_data[input_index];
                    output_index++;
                }
            }
        }
    }
}
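The channel decomposition above inverts c_out = (offset_h * blocksize + offset_w) * in_channels + c_in. For example, with blocksize 2 and 3 input channels, c_out = 7 gives c_in = 1, offset_h = 1, offset_w = 0, so output pixel (h_out, w_out) reads input channel 1 at (2 * h_out + 1, 2 * w_out). Because the loops run in the output tensor's memory order, a single running output_index suffices.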