emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
static inline int {{ op_name }}_suppress_by_iou(
|
|
2
|
+
const {{ input_c_type }} boxes[][4],
|
|
3
|
+
idx_t box_index1,
|
|
4
|
+
idx_t box_index2,
|
|
5
|
+
int center_point_box,
|
|
6
|
+
{{ compute_type }} iou_threshold
|
|
7
|
+
) {
|
|
8
|
+
{{ compute_type }} box1_0 = ({{ compute_type }})boxes[box_index1][0];
|
|
9
|
+
{{ compute_type }} box1_1 = ({{ compute_type }})boxes[box_index1][1];
|
|
10
|
+
{{ compute_type }} box1_2 = ({{ compute_type }})boxes[box_index1][2];
|
|
11
|
+
{{ compute_type }} box1_3 = ({{ compute_type }})boxes[box_index1][3];
|
|
12
|
+
{{ compute_type }} box2_0 = ({{ compute_type }})boxes[box_index2][0];
|
|
13
|
+
{{ compute_type }} box2_1 = ({{ compute_type }})boxes[box_index2][1];
|
|
14
|
+
{{ compute_type }} box2_2 = ({{ compute_type }})boxes[box_index2][2];
|
|
15
|
+
{{ compute_type }} box2_3 = ({{ compute_type }})boxes[box_index2][3];
|
|
16
|
+
{{ compute_type }} x1_min;
|
|
17
|
+
{{ compute_type }} x1_max;
|
|
18
|
+
{{ compute_type }} x2_min;
|
|
19
|
+
{{ compute_type }} x2_max;
|
|
20
|
+
{{ compute_type }} y1_min;
|
|
21
|
+
{{ compute_type }} y1_max;
|
|
22
|
+
{{ compute_type }} y2_min;
|
|
23
|
+
{{ compute_type }} y2_max;
|
|
24
|
+
|
|
25
|
+
if (center_point_box == 0) {
|
|
26
|
+
x1_min = {{ min_fn }}(box1_1, box1_3);
|
|
27
|
+
x1_max = {{ max_fn }}(box1_1, box1_3);
|
|
28
|
+
x2_min = {{ min_fn }}(box2_1, box2_3);
|
|
29
|
+
x2_max = {{ max_fn }}(box2_1, box2_3);
|
|
30
|
+
|
|
31
|
+
y1_min = {{ min_fn }}(box1_0, box1_2);
|
|
32
|
+
y1_max = {{ max_fn }}(box1_0, box1_2);
|
|
33
|
+
y2_min = {{ min_fn }}(box2_0, box2_2);
|
|
34
|
+
y2_max = {{ max_fn }}(box2_0, box2_2);
|
|
35
|
+
} else {
|
|
36
|
+
{{ compute_type }} box1_width_half = box1_2 / ({{ compute_type }})2;
|
|
37
|
+
{{ compute_type }} box1_height_half = box1_3 / ({{ compute_type }})2;
|
|
38
|
+
{{ compute_type }} box2_width_half = box2_2 / ({{ compute_type }})2;
|
|
39
|
+
{{ compute_type }} box2_height_half = box2_3 / ({{ compute_type }})2;
|
|
40
|
+
|
|
41
|
+
x1_min = box1_0 - box1_width_half;
|
|
42
|
+
x1_max = box1_0 + box1_width_half;
|
|
43
|
+
x2_min = box2_0 - box2_width_half;
|
|
44
|
+
x2_max = box2_0 + box2_width_half;
|
|
45
|
+
|
|
46
|
+
y1_min = box1_1 - box1_height_half;
|
|
47
|
+
y1_max = box1_1 + box1_height_half;
|
|
48
|
+
y2_min = box2_1 - box2_height_half;
|
|
49
|
+
y2_max = box2_1 + box2_height_half;
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
{{ compute_type }} intersection_x_min = {{ max_fn }}(x1_min, x2_min);
|
|
53
|
+
{{ compute_type }} intersection_x_max = {{ min_fn }}(x1_max, x2_max);
|
|
54
|
+
if (intersection_x_max <= intersection_x_min) {
|
|
55
|
+
return 0;
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
{{ compute_type }} intersection_y_min = {{ max_fn }}(y1_min, y2_min);
|
|
59
|
+
{{ compute_type }} intersection_y_max = {{ min_fn }}(y1_max, y2_max);
|
|
60
|
+
if (intersection_y_max <= intersection_y_min) {
|
|
61
|
+
return 0;
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
{{ compute_type }} intersection_area =
|
|
65
|
+
(intersection_x_max - intersection_x_min)
|
|
66
|
+
* (intersection_y_max - intersection_y_min);
|
|
67
|
+
if (intersection_area <= ({{ compute_type }})0) {
|
|
68
|
+
return 0;
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
{{ compute_type }} area1 = (x1_max - x1_min) * (y1_max - y1_min);
|
|
72
|
+
{{ compute_type }} area2 = (x2_max - x2_min) * (y2_max - y2_min);
|
|
73
|
+
{{ compute_type }} union_area = area1 + area2 - intersection_area;
|
|
74
|
+
if (area1 <= ({{ compute_type }})0
|
|
75
|
+
|| area2 <= ({{ compute_type }})0
|
|
76
|
+
|| union_area <= ({{ compute_type }})0) {
|
|
77
|
+
return 0;
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
{{ compute_type }} intersection_over_union = intersection_area / union_area;
|
|
81
|
+
return intersection_over_union > iou_threshold;
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
85
|
+
idx_t output_index = 0;
|
|
86
|
+
const idx_t output_capacity = {{ output_capacity }};
|
|
87
|
+
int64_t max_output_boxes_per_class_value = 0;
|
|
88
|
+
{{ compute_type }} iou_threshold_value = {{ iou_threshold_default }};
|
|
89
|
+
{{ compute_type }} score_threshold_value = {{ score_threshold_default }};
|
|
90
|
+
{% if max_output_boxes_per_class %}
|
|
91
|
+
max_output_boxes_per_class_value =
|
|
92
|
+
(int64_t){{ max_output_boxes_per_class }}[0];
|
|
93
|
+
if (max_output_boxes_per_class_value < 0) {
|
|
94
|
+
max_output_boxes_per_class_value = 0;
|
|
95
|
+
}
|
|
96
|
+
{% endif %}
|
|
97
|
+
{% if iou_threshold %}
|
|
98
|
+
iou_threshold_value = ({{ compute_type }}){{ iou_threshold }}[0];
|
|
99
|
+
{% endif %}
|
|
100
|
+
{% if score_threshold %}
|
|
101
|
+
score_threshold_value = ({{ compute_type }}){{ score_threshold }}[0];
|
|
102
|
+
{% endif %}
|
|
103
|
+
if (output_capacity == 0 || max_output_boxes_per_class_value == 0) {
|
|
104
|
+
for (idx_t idx = 0; idx < output_capacity; ++idx) {
|
|
105
|
+
{{ output }}[idx][0] = 0;
|
|
106
|
+
{{ output }}[idx][1] = 0;
|
|
107
|
+
{{ output }}[idx][2] = 0;
|
|
108
|
+
}
|
|
109
|
+
return;
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
for (idx_t batch_index = 0; batch_index < {{ num_batches }}; ++batch_index) {
|
|
113
|
+
for (idx_t class_index = 0; class_index < {{ num_classes }}; ++class_index) {
|
|
114
|
+
{{ compute_type }} candidate_scores[{{ num_boxes }}];
|
|
115
|
+
idx_t candidate_indices[{{ num_boxes }}];
|
|
116
|
+
idx_t selected_indices[{{ num_boxes }}];
|
|
117
|
+
idx_t candidate_count = 0;
|
|
118
|
+
idx_t selected_count = 0;
|
|
119
|
+
for (idx_t box_index = 0; box_index < {{ num_boxes }}; ++box_index) {
|
|
120
|
+
{{ compute_type }} score = ({{ compute_type }})
|
|
121
|
+
{{ scores }}[batch_index][class_index][box_index];
|
|
122
|
+
{% if score_threshold_enabled %}
|
|
123
|
+
if (score <= score_threshold_value) {
|
|
124
|
+
continue;
|
|
125
|
+
}
|
|
126
|
+
{% endif %}
|
|
127
|
+
candidate_scores[candidate_count] = score;
|
|
128
|
+
candidate_indices[candidate_count] = box_index;
|
|
129
|
+
++candidate_count;
|
|
130
|
+
}
|
|
131
|
+
|
|
132
|
+
for (idx_t i = 1; i < candidate_count; ++i) {
|
|
133
|
+
{{ compute_type }} candidate_score = candidate_scores[i];
|
|
134
|
+
idx_t candidate_index = candidate_indices[i];
|
|
135
|
+
idx_t j = i;
|
|
136
|
+
while (j > 0
|
|
137
|
+
&& (candidate_scores[j - 1] > candidate_score
|
|
138
|
+
|| (candidate_scores[j - 1] == candidate_score
|
|
139
|
+
&& candidate_indices[j - 1] < candidate_index))) {
|
|
140
|
+
candidate_scores[j] = candidate_scores[j - 1];
|
|
141
|
+
candidate_indices[j] = candidate_indices[j - 1];
|
|
142
|
+
--j;
|
|
143
|
+
}
|
|
144
|
+
candidate_scores[j] = candidate_score;
|
|
145
|
+
candidate_indices[j] = candidate_index;
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
for (idx_t candidate_pos = candidate_count; candidate_pos > 0; --candidate_pos) {
|
|
149
|
+
if (selected_count >= (idx_t)max_output_boxes_per_class_value) {
|
|
150
|
+
break;
|
|
151
|
+
}
|
|
152
|
+
idx_t candidate_index = candidate_indices[candidate_pos - 1];
|
|
153
|
+
int selected = 1;
|
|
154
|
+
for (idx_t selected_idx = 0; selected_idx < selected_count; ++selected_idx) {
|
|
155
|
+
if ({{ op_name }}_suppress_by_iou(
|
|
156
|
+
{{ boxes }}[batch_index],
|
|
157
|
+
candidate_index,
|
|
158
|
+
selected_indices[selected_idx],
|
|
159
|
+
{{ center_point_box }},
|
|
160
|
+
iou_threshold_value)) {
|
|
161
|
+
selected = 0;
|
|
162
|
+
break;
|
|
163
|
+
}
|
|
164
|
+
}
|
|
165
|
+
if (selected) {
|
|
166
|
+
selected_indices[selected_count] = candidate_index;
|
|
167
|
+
if (output_index >= output_capacity) {
|
|
168
|
+
return;
|
|
169
|
+
}
|
|
170
|
+
{{ output }}[output_index][0] = ({{ output_c_type }})batch_index;
|
|
171
|
+
{{ output }}[output_index][1] = ({{ output_c_type }})class_index;
|
|
172
|
+
{{ output }}[output_index][2] = ({{ output_c_type }})candidate_index;
|
|
173
|
+
++output_index;
|
|
174
|
+
++selected_count;
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
}
|
|
178
|
+
}
|
|
179
|
+
}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
idx_t out_index = 0;
|
|
3
|
+
{% for dim in input_shape %}
|
|
4
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
5
|
+
{% endfor %}
|
|
6
|
+
if ({{ input_expr }} != {{ zero_literal }}) {
|
|
7
|
+
{% for var in loop_vars %}
|
|
8
|
+
{{ output }}[{{ loop.index0 }}][out_index] = ({{ output_c_type }}){{ var }};
|
|
9
|
+
{% endfor %}
|
|
10
|
+
++out_index;
|
|
11
|
+
}
|
|
12
|
+
{% for _ in input_shape %}
|
|
13
|
+
}
|
|
14
|
+
{% endfor %}
|
|
15
|
+
}
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
(void){{ depth }};
|
|
3
|
+
const {{ c_type }} off_value = {{ values }}[0];
|
|
4
|
+
const {{ c_type }} on_value = {{ values }}[1];
|
|
5
|
+
const int64_t depth_value = (int64_t){{ depth_dim }};
|
|
6
|
+
{% for dim in output_shape %}
|
|
7
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
8
|
+
{% endfor %}
|
|
9
|
+
int64_t index_value = (int64_t){{ indices }}{% for idx in indices_indices %}[{{ idx }}]{% endfor %};
|
|
10
|
+
int64_t adjusted = index_value;
|
|
11
|
+
if (index_value < 0) {
|
|
12
|
+
adjusted = index_value + depth_value;
|
|
13
|
+
}
|
|
14
|
+
{{ output }}{% for idx in loop_vars %}[{{ idx }}]{% endfor %} = off_value;
|
|
15
|
+
if (
|
|
16
|
+
index_value >= -depth_value
|
|
17
|
+
&& index_value < depth_value
|
|
18
|
+
&& (int64_t){{ axis_index }} == adjusted
|
|
19
|
+
) {
|
|
20
|
+
{{ output }}{% for idx in loop_vars %}[{{ idx }}]{% endfor %} = on_value;
|
|
21
|
+
}
|
|
22
|
+
{% for _ in output_shape %}
|
|
23
|
+
}
|
|
24
|
+
{% endfor %}
|
|
25
|
+
}
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}const {{ c_type }} {{ input0 }}{{ input_suffix }}{% if pads_input %}, const {{ pads_c_type }} {{ pads_input }}{{ pads_suffix }}{% endif %}{% if axes_input %}, const {{ axes_c_type }} {{ axes_input }}{{ axes_suffix }}{% endif %}{% if value_input %}, const {{ c_type }} {{ value_input }}{{ value_suffix }}{% endif %}, {{ c_type }} {{ output }}{{ output_suffix }}) {
|
|
2
|
+
const {{ c_type }} *{{ input0_flat }} = (const {{ c_type }} *){{ input0 }};
|
|
3
|
+
{% if axes_input %}
|
|
4
|
+
{% if pads_values %}
|
|
5
|
+
const {{ pads_c_type }} pad_values[] = { {% for value in pads_values %}{{ value }}{{ ", " if not loop.last else "" }}{% endfor %} };
|
|
6
|
+
{% endif %}
|
|
7
|
+
idx_t pad_begin[{{ output_shape|length }}];
|
|
8
|
+
for (idx_t pad_index = 0; pad_index < {{ output_shape|length }}; ++pad_index) {
|
|
9
|
+
pad_begin[pad_index] = 0;
|
|
10
|
+
}
|
|
11
|
+
for (idx_t axis_index = 0; axis_index < {{ axes_length }}; ++axis_index) {
|
|
12
|
+
idx_t axis = (idx_t){{ axes_input }}[axis_index];
|
|
13
|
+
if (axis < 0) {
|
|
14
|
+
axis += {{ output_shape|length }};
|
|
15
|
+
}
|
|
16
|
+
if (axis >= 0 && axis < {{ output_shape|length }}) {
|
|
17
|
+
pad_begin[axis] = {% if pads_input %}{{ pads_input }}[axis_index]{% else %}pad_values[axis_index]{% endif %};
|
|
18
|
+
}
|
|
19
|
+
}
|
|
20
|
+
{% endif %}
|
|
21
|
+
{% for dim in output_shape %}
|
|
22
|
+
for (idx_t {{ out_loop_vars[loop.index0] }} = 0; {{ out_loop_vars[loop.index0] }} < {{ dim }}; ++{{ out_loop_vars[loop.index0] }}) {
|
|
23
|
+
{% endfor %}
|
|
24
|
+
{{ output }}{% for var in out_loop_vars %}[{{ var }}]{% endfor %} = {{ pad_value_expr }};
|
|
25
|
+
{% for _ in output_shape %}
|
|
26
|
+
}
|
|
27
|
+
{% endfor %}
|
|
28
|
+
{% for dim in output_shape %}
|
|
29
|
+
for (idx_t {{ out_loop_vars[loop.index0] }} = 0; {{ out_loop_vars[loop.index0] }} < {{ dim }}; ++{{ out_loop_vars[loop.index0] }}) {
|
|
30
|
+
{% endfor %}
|
|
31
|
+
idx_t {{ base_index }} = 0;
|
|
32
|
+
int pad_in_bounds = 1;
|
|
33
|
+
{% for index in range(output_shape|length) %}
|
|
34
|
+
idx_t {{ idx_vars[index] }} = {{ out_loop_vars[index] }} - (idx_t)({{ pad_begin_exprs[index] }});
|
|
35
|
+
if ({{ input_shape[index] }} == 0) {
|
|
36
|
+
pad_in_bounds = 0;
|
|
37
|
+
}
|
|
38
|
+
{% if mode == "constant" %}
|
|
39
|
+
if (pad_in_bounds && ({{ idx_vars[index] }} < 0 || {{ idx_vars[index] }} >= {{ input_shape[index] }})) {
|
|
40
|
+
pad_in_bounds = 0;
|
|
41
|
+
}
|
|
42
|
+
{% elif mode == "edge" %}
|
|
43
|
+
if (pad_in_bounds && {{ idx_vars[index] }} < 0) {
|
|
44
|
+
{{ idx_vars[index] }} = 0;
|
|
45
|
+
} else if (pad_in_bounds && {{ idx_vars[index] }} >= {{ input_shape[index] }}) {
|
|
46
|
+
{{ idx_vars[index] }} = {{ input_shape[index] }} - 1;
|
|
47
|
+
}
|
|
48
|
+
{% elif mode == "wrap" %}
|
|
49
|
+
if (pad_in_bounds) {
|
|
50
|
+
{{ idx_vars[index] }} %= {{ input_shape[index] }};
|
|
51
|
+
if ({{ idx_vars[index] }} < 0) {
|
|
52
|
+
{{ idx_vars[index] }} += {{ input_shape[index] }};
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
{% elif mode == "reflect" %}
|
|
56
|
+
if (pad_in_bounds && {{ input_shape[index] }} == 1) {
|
|
57
|
+
{{ idx_vars[index] }} = 0;
|
|
58
|
+
} else if (pad_in_bounds) {
|
|
59
|
+
idx_t {{ reflect_vars[index] }} = {{ input_shape[index] }} - 1;
|
|
60
|
+
{{ idx_vars[index] }} %= (2 * {{ reflect_vars[index] }});
|
|
61
|
+
if ({{ idx_vars[index] }} < 0) {
|
|
62
|
+
{{ idx_vars[index] }} += 2 * {{ reflect_vars[index] }};
|
|
63
|
+
}
|
|
64
|
+
if ({{ idx_vars[index] }} > {{ reflect_vars[index] }}) {
|
|
65
|
+
{{ idx_vars[index] }} = 2 * {{ reflect_vars[index] }} - {{ idx_vars[index] }};
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
{% endif %}
|
|
69
|
+
if (pad_in_bounds) {
|
|
70
|
+
{{ base_index }} += {{ idx_vars[index] }} * {{ input_strides[index] }};
|
|
71
|
+
}
|
|
72
|
+
{% endfor %}
|
|
73
|
+
if (!pad_in_bounds) {
|
|
74
|
+
} else {
|
|
75
|
+
{{ output }}{% for var in out_loop_vars %}[{{ var }}]{% endfor %} = {{ input0_flat }}[{{ base_index }}];
|
|
76
|
+
}
|
|
77
|
+
{% for _ in output_shape %}
|
|
78
|
+
}
|
|
79
|
+
{% endfor %}
|
|
80
|
+
}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% if scale_is_float16 %}
|
|
3
|
+
const {{ scale_type }} scale_product = ({{ scale_type }})((( {{ scale_type }}){{ input0_scale_expr }}) * (({{ scale_type }}){{ input1_scale_expr }}));
|
|
4
|
+
const {{ scale_type }} scale = ({{ scale_type }})(scale_product / ({{ scale_type }}){{ output_scale_expr }});
|
|
5
|
+
{% else %}
|
|
6
|
+
const {{ scale_type }} scale = (({{ scale_type }}){{ input0_scale_expr }}) * (({{ scale_type }}){{ input1_scale_expr }}) / (({{ scale_type }}){{ output_scale_expr }});
|
|
7
|
+
{% endif %}
|
|
8
|
+
const int32_t input0_zero = (int32_t){{ input0_zero_expr }};
|
|
9
|
+
const int32_t input1_zero = (int32_t){{ input1_zero_expr }};
|
|
10
|
+
const {{ compute_type }} output_zero = ({{ compute_type }}){{ output_zero_expr }};
|
|
11
|
+
{% for idx in range(output_loop_vars | length) %}
|
|
12
|
+
{% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ output_loop_vars[idx] }} = 0; {{ output_loop_vars[idx] }} < {{ output_loop_bounds[idx] }}; ++{{ output_loop_vars[idx] }}) {
|
|
13
|
+
{% endfor %}
|
|
14
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}int32_t acc = 0;
|
|
15
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}for (idx_t k = 0; k < {{ k }}; ++k) {
|
|
16
|
+
{% for indent in range(output_loop_vars | length + 1) %} {% endfor %}acc += ((int32_t){{ input0_index_expr }} - input0_zero) * ((int32_t){{ input1_index_expr }} - input1_zero);
|
|
17
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}}
|
|
18
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} scaled = (({{ compute_type }})acc) * scale + output_zero;
|
|
19
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} rounded = {{ round_fn }}(scaled);
|
|
20
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} wrapped = {{ mod_fn }}(rounded, ({{ compute_type }})256.0);
|
|
21
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}if (wrapped < ({{ compute_type }})0.0) {
|
|
22
|
+
{% for indent in range(output_loop_vars | length + 1) %} {% endfor %}wrapped += ({{ compute_type }})256.0;
|
|
23
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}}
|
|
24
|
+
{% if output_is_signed %}
|
|
25
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}if (wrapped >= ({{ compute_type }})128.0) {
|
|
26
|
+
{% for indent in range(output_loop_vars | length + 1) %} {% endfor %}wrapped -= ({{ compute_type }})256.0;
|
|
27
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}}
|
|
28
|
+
{% endif %}
|
|
29
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ output_index_expr }} = ({{ output_c_type }})wrapped;
|
|
30
|
+
{% for idx in range(output_loop_vars | length) | reverse %}
|
|
31
|
+
{% for indent in range(loop.index0) %} {% endfor %}}
|
|
32
|
+
{% endfor %}
|
|
33
|
+
}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
const {{ compute_type }} scale = (({{ compute_type }}){{ input0_scale_expr }}) * (({{ compute_type }}){{ input1_scale_expr }}) / (({{ compute_type }}){{ output_scale_expr }});
|
|
3
|
+
const int32_t input0_zero = (int32_t){{ input0_zero_expr }};
|
|
4
|
+
const int32_t input1_zero = (int32_t){{ input1_zero_expr }};
|
|
5
|
+
const {{ compute_type }} output_zero = ({{ compute_type }}){{ output_zero_expr }};
|
|
6
|
+
{% for idx in range(output_loop_vars | length) %}
|
|
7
|
+
{% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ output_loop_vars[idx] }} = 0; {{ output_loop_vars[idx] }} < {{ output_loop_bounds[idx] }}; ++{{ output_loop_vars[idx] }}) {
|
|
8
|
+
{% endfor %}
|
|
9
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}int32_t acc = ((int32_t){{ input0_index_expr }} - input0_zero) * ((int32_t){{ input1_index_expr }} - input1_zero);
|
|
10
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} scaled = (({{ compute_type }})acc) * scale + output_zero;
|
|
11
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} rounded = {{ round_fn }}(scaled);
|
|
12
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}rounded = {{ max_fn }}(rounded, ({{ compute_type }}){{ min_literal }});
|
|
13
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}rounded = {{ min_fn }}(rounded, ({{ compute_type }}){{ max_literal }});
|
|
14
|
+
{% for indent in range(output_loop_vars | length) %} {% endfor %}{{ output_index_expr }} = ({{ output_c_type }})rounded;
|
|
15
|
+
{% for idx in range(output_loop_vars | length) | reverse %}
|
|
16
|
+
{% for indent in range(loop.index0) %} {% endfor %}}
|
|
17
|
+
{% endfor %}
|
|
18
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for dim in shape %}
|
|
3
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
{{ compute_type }} scaled = (({{ compute_type }}){{ input_expr }} / ({{ compute_type }}){{ scale_expr }});
|
|
6
|
+
{{ compute_type }} rounded = {{ round_fn }}(scaled) + ({{ compute_type }}){{ zero_expr }};
|
|
7
|
+
rounded = {{ max_fn }}(rounded, ({{ compute_type }}){{ min_literal }});
|
|
8
|
+
rounded = {{ min_fn }}(rounded, ({{ compute_type }}){{ max_literal }});
|
|
9
|
+
{{ output_expr }} = ({{ output_c_type }})rounded;
|
|
10
|
+
{% for _ in shape %}
|
|
11
|
+
}
|
|
12
|
+
{% endfor %}
|
|
13
|
+
}
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
(void){{ limit }};
|
|
3
|
+
const {{ c_type }} start_value = {{ start }}[0];
|
|
4
|
+
const {{ c_type }} delta_value = {{ delta }}[0];
|
|
5
|
+
for (idx_t idx = 0; idx < {{ length }}; ++idx) {
|
|
6
|
+
{{ output }}[idx] = start_value + (({{ c_type }})idx * delta_value);
|
|
7
|
+
}
|
|
8
|
+
}
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
{% for dim in output_shape %}
|
|
3
|
+
for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
|
|
4
|
+
{% endfor %}
|
|
5
|
+
{{ c_type }} acc = {{ init_literal }};
|
|
6
|
+
{% if use_kahan %}
|
|
7
|
+
{{ c_type }} acc_comp = {{ zero_literal }};
|
|
8
|
+
{% endif %}
|
|
9
|
+
{% for dim in reduce_dims %}
|
|
10
|
+
for (idx_t {{ reduce_loop_vars[loop.index0] }} = 0; {{ reduce_loop_vars[loop.index0] }} < {{ dim }}; ++{{ reduce_loop_vars[loop.index0] }}) {
|
|
11
|
+
{% endfor %}
|
|
12
|
+
{% if use_kahan %}
|
|
13
|
+
{{ c_type }} kahan_value = {{ kahan_value_expr }};
|
|
14
|
+
{{ c_type }} kahan_y = kahan_value - acc_comp;
|
|
15
|
+
{{ c_type }} kahan_t = acc + kahan_y;
|
|
16
|
+
acc_comp = (kahan_t - acc) - kahan_y;
|
|
17
|
+
acc = kahan_t;
|
|
18
|
+
{% else %}
|
|
19
|
+
{{ update_expr }}
|
|
20
|
+
{% endif %}
|
|
21
|
+
{% for _ in reduce_dims %}
|
|
22
|
+
}
|
|
23
|
+
{% endfor %}
|
|
24
|
+
{{ output }}{{ output_index_expr }} = {{ final_expr }};
|
|
25
|
+
{% for _ in output_shape %}
|
|
26
|
+
}
|
|
27
|
+
{% endfor %}
|
|
28
|
+
}
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
idx_t axis_count = {{ axes_count }};
|
|
3
|
+
bool reduce_mask[{{ input_shape | length }}];
|
|
4
|
+
for (idx_t i = 0; i < {{ input_shape | length }}; ++i) {
|
|
5
|
+
reduce_mask[i] = false;
|
|
6
|
+
}
|
|
7
|
+
if (axis_count == 0) {
|
|
8
|
+
{% if noop_with_empty_axes %}
|
|
9
|
+
{% for dim in input_shape %}
|
|
10
|
+
for (idx_t {{ input_loop_vars[loop.index0] }} = 0; {{ input_loop_vars[loop.index0] }} < {{ dim }}; ++{{ input_loop_vars[loop.index0] }}) {
|
|
11
|
+
{% endfor %}
|
|
12
|
+
{{ output }}{{ input_index_expr }} = {{ input0 }}{{ input_index_expr }};
|
|
13
|
+
{% for _ in input_shape %}
|
|
14
|
+
}
|
|
15
|
+
{% endfor %}
|
|
16
|
+
return;
|
|
17
|
+
{% else %}
|
|
18
|
+
for (idx_t i = 0; i < {{ input_shape | length }}; ++i) {
|
|
19
|
+
reduce_mask[i] = true;
|
|
20
|
+
}
|
|
21
|
+
{% endif %}
|
|
22
|
+
} else {
|
|
23
|
+
for (idx_t i = 0; i < axis_count; ++i) {
|
|
24
|
+
int axis = (int){{ axes_input }}[i];
|
|
25
|
+
if (axis < 0) {
|
|
26
|
+
axis += {{ input_shape | length }};
|
|
27
|
+
}
|
|
28
|
+
if (axis >= 0 && axis < {{ input_shape | length }}) {
|
|
29
|
+
reduce_mask[axis] = true;
|
|
30
|
+
}
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
idx_t reduce_count = 1;
|
|
34
|
+
{% for dim in input_shape %}
|
|
35
|
+
if (reduce_mask[{{ loop.index0 }}]) {
|
|
36
|
+
reduce_count *= {{ dim }};
|
|
37
|
+
}
|
|
38
|
+
{% endfor %}
|
|
39
|
+
{% for dim in output_shape %}
|
|
40
|
+
for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
|
|
41
|
+
{% endfor %}
|
|
42
|
+
{{ output }}{{ output_loop_index_expr }} = {{ init_literal }};
|
|
43
|
+
{% for _ in output_shape %}
|
|
44
|
+
}
|
|
45
|
+
{% endfor %}
|
|
46
|
+
{% for dim in input_shape %}
|
|
47
|
+
for (idx_t {{ input_loop_vars[loop.index0] }} = 0; {{ input_loop_vars[loop.index0] }} < {{ dim }}; ++{{ input_loop_vars[loop.index0] }}) {
|
|
48
|
+
{% endfor %}
|
|
49
|
+
idx_t out_indices[{{ output_shape | length }}];
|
|
50
|
+
{% if keepdims %}
|
|
51
|
+
{% for axis in range(input_shape | length) %}
|
|
52
|
+
out_indices[{{ axis }}] = reduce_mask[{{ axis }}] ? 0 : {{ input_loop_vars[axis] }};
|
|
53
|
+
{% endfor %}
|
|
54
|
+
{% else %}
|
|
55
|
+
idx_t out_pos = 0;
|
|
56
|
+
{% for axis in range(input_shape | length) %}
|
|
57
|
+
if (!reduce_mask[{{ axis }}]) {
|
|
58
|
+
out_indices[out_pos++] = {{ input_loop_vars[axis] }};
|
|
59
|
+
}
|
|
60
|
+
{% endfor %}
|
|
61
|
+
{% endif %}
|
|
62
|
+
{{ c_type }} *out_ptr = &{{ output }}{{ output_index_expr }};
|
|
63
|
+
{{ update_expr }}
|
|
64
|
+
{% for _ in input_shape %}
|
|
65
|
+
}
|
|
66
|
+
{% endfor %}
|
|
67
|
+
{% if post_expr %}
|
|
68
|
+
{% for dim in output_shape %}
|
|
69
|
+
for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
|
|
70
|
+
{% endfor %}
|
|
71
|
+
{{ c_type }} *out_ptr = &{{ output }}{{ output_loop_index_expr }};
|
|
72
|
+
{{ post_expr }}
|
|
73
|
+
{% for _ in output_shape %}
|
|
74
|
+
}
|
|
75
|
+
{% endfor %}
|
|
76
|
+
{% endif %}
|
|
77
|
+
}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
|
|
2
|
+
const {{ c_type }} *input0_data = (const {{ c_type }} *){{ input0 }};
|
|
3
|
+
{% for dim in output_shape %}
|
|
4
|
+
for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
|
|
5
|
+
{% endfor %}
|
|
6
|
+
{% if loop_vars %}
|
|
7
|
+
{% set ns = namespace(expr=loop_vars[0]) %}
|
|
8
|
+
{% for dim in output_shape[1:] %}
|
|
9
|
+
{% set ns.expr = "(" ~ ns.expr ~ " * " ~ dim ~ " + " ~ loop_vars[loop.index] ~ ")" %}
|
|
10
|
+
{% endfor %}
|
|
11
|
+
{{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = input0_data[{{ ns.expr }}];
|
|
12
|
+
{% else %}
|
|
13
|
+
{{ output }} = input0_data[0];
|
|
14
|
+
{% endif %}
|
|
15
|
+
{% for _ in output_shape %}
|
|
16
|
+
}
|
|
17
|
+
{% endfor %}
|
|
18
|
+
}
|