emx-onnx-cgen 0.3.8__py3-none-any.whl → 0.4.2.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (137)
  1. emx_onnx_cgen/_build_info.py +1 -1
  2. emx_onnx_cgen/_version.py +2 -2
  3. emx_onnx_cgen/cli.py +1025 -162
  4. emx_onnx_cgen/codegen/__init__.py +2 -0
  5. emx_onnx_cgen/codegen/c_emitter.py +2081 -458
  6. emx_onnx_cgen/compiler.py +157 -75
  7. emx_onnx_cgen/determinism.py +39 -0
  8. emx_onnx_cgen/ir/context.py +25 -15
  9. emx_onnx_cgen/ir/model.py +1 -0
  10. emx_onnx_cgen/ir/op_base.py +32 -7
  11. emx_onnx_cgen/ir/ops/__init__.py +20 -0
  12. emx_onnx_cgen/ir/ops/elementwise.py +138 -22
  13. emx_onnx_cgen/ir/ops/misc.py +95 -0
  14. emx_onnx_cgen/ir/ops/nn.py +361 -38
  15. emx_onnx_cgen/ir/ops/reduce.py +1 -16
  16. emx_onnx_cgen/lowering/__init__.py +9 -0
  17. emx_onnx_cgen/lowering/arg_reduce.py +0 -4
  18. emx_onnx_cgen/lowering/average_pool.py +157 -27
  19. emx_onnx_cgen/lowering/bernoulli.py +73 -0
  20. emx_onnx_cgen/lowering/common.py +48 -0
  21. emx_onnx_cgen/lowering/concat.py +41 -7
  22. emx_onnx_cgen/lowering/conv.py +19 -8
  23. emx_onnx_cgen/lowering/conv_integer.py +103 -0
  24. emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
  25. emx_onnx_cgen/lowering/elementwise.py +140 -43
  26. emx_onnx_cgen/lowering/gather.py +11 -2
  27. emx_onnx_cgen/lowering/gemm.py +7 -124
  28. emx_onnx_cgen/lowering/global_max_pool.py +0 -5
  29. emx_onnx_cgen/lowering/gru.py +323 -0
  30. emx_onnx_cgen/lowering/hamming_window.py +104 -0
  31. emx_onnx_cgen/lowering/hardmax.py +1 -37
  32. emx_onnx_cgen/lowering/identity.py +7 -6
  33. emx_onnx_cgen/lowering/logsoftmax.py +1 -35
  34. emx_onnx_cgen/lowering/lp_pool.py +15 -4
  35. emx_onnx_cgen/lowering/matmul.py +3 -105
  36. emx_onnx_cgen/lowering/optional_has_element.py +28 -0
  37. emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
  38. emx_onnx_cgen/lowering/reduce.py +0 -5
  39. emx_onnx_cgen/lowering/reshape.py +7 -16
  40. emx_onnx_cgen/lowering/shape.py +14 -8
  41. emx_onnx_cgen/lowering/slice.py +14 -4
  42. emx_onnx_cgen/lowering/softmax.py +1 -35
  43. emx_onnx_cgen/lowering/split.py +37 -3
  44. emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
  45. emx_onnx_cgen/lowering/tile.py +38 -1
  46. emx_onnx_cgen/lowering/topk.py +1 -5
  47. emx_onnx_cgen/lowering/transpose.py +9 -3
  48. emx_onnx_cgen/lowering/unsqueeze.py +11 -16
  49. emx_onnx_cgen/lowering/upsample.py +151 -0
  50. emx_onnx_cgen/lowering/variadic.py +1 -1
  51. emx_onnx_cgen/lowering/where.py +0 -5
  52. emx_onnx_cgen/onnx_import.py +578 -14
  53. emx_onnx_cgen/ops.py +3 -0
  54. emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
  55. emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
  56. emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
  57. emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
  58. emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
  59. emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
  60. emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
  61. emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
  62. emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
  63. emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
  64. emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
  65. emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
  66. emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
  67. emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
  68. emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
  69. emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
  70. emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
  71. emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
  72. emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
  73. emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
  74. emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
  75. emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
  76. emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
  77. emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
  78. emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
  79. emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
  80. emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
  81. emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
  82. emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
  83. emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
  84. emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
  85. emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
  86. emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
  87. emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
  88. emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
  89. emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
  90. emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
  91. emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
  92. emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
  93. emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
  94. emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
  95. emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
  96. emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
  97. emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
  98. emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
  99. emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
  100. emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
  101. emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
  102. emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
  103. emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
  104. emx_onnx_cgen/templates/range_op.c.j2 +8 -0
  105. emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
  106. emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
  107. emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
  108. emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
  109. emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
  110. emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
  111. emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
  112. emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
  113. emx_onnx_cgen/templates/size_op.c.j2 +4 -0
  114. emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
  115. emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
  116. emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
  117. emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
  118. emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
  119. emx_onnx_cgen/templates/split_op.c.j2 +18 -0
  120. emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
  121. emx_onnx_cgen/templates/testbench.c.j2 +161 -0
  122. emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
  123. emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
  124. emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
  125. emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
  126. emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
  127. emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
  128. emx_onnx_cgen/templates/where_op.c.j2 +9 -0
  129. emx_onnx_cgen/verification.py +45 -5
  130. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/METADATA +33 -15
  131. emx_onnx_cgen-0.4.2.dev0.dist-info/RECORD +190 -0
  132. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/WHEEL +1 -1
  133. emx_onnx_cgen/runtime/__init__.py +0 -1
  134. emx_onnx_cgen/runtime/evaluator.py +0 -2955
  135. emx_onnx_cgen-0.3.8.dist-info/RECORD +0 -107
  136. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/entry_points.txt +0 -0
  137. {emx_onnx_cgen-0.3.8.dist-info → emx_onnx_cgen-0.4.2.dev0.dist-info}/top_level.txt +0 -0
emx_onnx_cgen/templates/nonmax_suppression_op.c.j2
@@ -0,0 +1,179 @@
+ static inline int {{ op_name }}_suppress_by_iou(
+ const {{ input_c_type }} boxes[][4],
+ idx_t box_index1,
+ idx_t box_index2,
+ int center_point_box,
+ {{ compute_type }} iou_threshold
+ ) {
+ {{ compute_type }} box1_0 = ({{ compute_type }})boxes[box_index1][0];
+ {{ compute_type }} box1_1 = ({{ compute_type }})boxes[box_index1][1];
+ {{ compute_type }} box1_2 = ({{ compute_type }})boxes[box_index1][2];
+ {{ compute_type }} box1_3 = ({{ compute_type }})boxes[box_index1][3];
+ {{ compute_type }} box2_0 = ({{ compute_type }})boxes[box_index2][0];
+ {{ compute_type }} box2_1 = ({{ compute_type }})boxes[box_index2][1];
+ {{ compute_type }} box2_2 = ({{ compute_type }})boxes[box_index2][2];
+ {{ compute_type }} box2_3 = ({{ compute_type }})boxes[box_index2][3];
+ {{ compute_type }} x1_min;
+ {{ compute_type }} x1_max;
+ {{ compute_type }} x2_min;
+ {{ compute_type }} x2_max;
+ {{ compute_type }} y1_min;
+ {{ compute_type }} y1_max;
+ {{ compute_type }} y2_min;
+ {{ compute_type }} y2_max;
+
+ if (center_point_box == 0) {
+ x1_min = {{ min_fn }}(box1_1, box1_3);
+ x1_max = {{ max_fn }}(box1_1, box1_3);
+ x2_min = {{ min_fn }}(box2_1, box2_3);
+ x2_max = {{ max_fn }}(box2_1, box2_3);
+
+ y1_min = {{ min_fn }}(box1_0, box1_2);
+ y1_max = {{ max_fn }}(box1_0, box1_2);
+ y2_min = {{ min_fn }}(box2_0, box2_2);
+ y2_max = {{ max_fn }}(box2_0, box2_2);
+ } else {
+ {{ compute_type }} box1_width_half = box1_2 / ({{ compute_type }})2;
+ {{ compute_type }} box1_height_half = box1_3 / ({{ compute_type }})2;
+ {{ compute_type }} box2_width_half = box2_2 / ({{ compute_type }})2;
+ {{ compute_type }} box2_height_half = box2_3 / ({{ compute_type }})2;
+
+ x1_min = box1_0 - box1_width_half;
+ x1_max = box1_0 + box1_width_half;
+ x2_min = box2_0 - box2_width_half;
+ x2_max = box2_0 + box2_width_half;
+
+ y1_min = box1_1 - box1_height_half;
+ y1_max = box1_1 + box1_height_half;
+ y2_min = box2_1 - box2_height_half;
+ y2_max = box2_1 + box2_height_half;
+ }
+
+ {{ compute_type }} intersection_x_min = {{ max_fn }}(x1_min, x2_min);
+ {{ compute_type }} intersection_x_max = {{ min_fn }}(x1_max, x2_max);
+ if (intersection_x_max <= intersection_x_min) {
+ return 0;
+ }
+
+ {{ compute_type }} intersection_y_min = {{ max_fn }}(y1_min, y2_min);
+ {{ compute_type }} intersection_y_max = {{ min_fn }}(y1_max, y2_max);
+ if (intersection_y_max <= intersection_y_min) {
+ return 0;
+ }
+
+ {{ compute_type }} intersection_area =
+ (intersection_x_max - intersection_x_min)
+ * (intersection_y_max - intersection_y_min);
+ if (intersection_area <= ({{ compute_type }})0) {
+ return 0;
+ }
+
+ {{ compute_type }} area1 = (x1_max - x1_min) * (y1_max - y1_min);
+ {{ compute_type }} area2 = (x2_max - x2_min) * (y2_max - y2_min);
+ {{ compute_type }} union_area = area1 + area2 - intersection_area;
+ if (area1 <= ({{ compute_type }})0
+ || area2 <= ({{ compute_type }})0
+ || union_area <= ({{ compute_type }})0) {
+ return 0;
+ }
+
+ {{ compute_type }} intersection_over_union = intersection_area / union_area;
+ return intersection_over_union > iou_threshold;
+ }
+
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ idx_t output_index = 0;
+ const idx_t output_capacity = {{ output_capacity }};
+ int64_t max_output_boxes_per_class_value = 0;
+ {{ compute_type }} iou_threshold_value = {{ iou_threshold_default }};
+ {{ compute_type }} score_threshold_value = {{ score_threshold_default }};
+ {% if max_output_boxes_per_class %}
+ max_output_boxes_per_class_value =
+ (int64_t){{ max_output_boxes_per_class }}[0];
+ if (max_output_boxes_per_class_value < 0) {
+ max_output_boxes_per_class_value = 0;
+ }
+ {% endif %}
+ {% if iou_threshold %}
+ iou_threshold_value = ({{ compute_type }}){{ iou_threshold }}[0];
+ {% endif %}
+ {% if score_threshold %}
+ score_threshold_value = ({{ compute_type }}){{ score_threshold }}[0];
+ {% endif %}
+ if (output_capacity == 0 || max_output_boxes_per_class_value == 0) {
+ for (idx_t idx = 0; idx < output_capacity; ++idx) {
+ {{ output }}[idx][0] = 0;
+ {{ output }}[idx][1] = 0;
+ {{ output }}[idx][2] = 0;
+ }
+ return;
+ }
+
+ for (idx_t batch_index = 0; batch_index < {{ num_batches }}; ++batch_index) {
+ for (idx_t class_index = 0; class_index < {{ num_classes }}; ++class_index) {
+ {{ compute_type }} candidate_scores[{{ num_boxes }}];
+ idx_t candidate_indices[{{ num_boxes }}];
+ idx_t selected_indices[{{ num_boxes }}];
+ idx_t candidate_count = 0;
+ idx_t selected_count = 0;
+ for (idx_t box_index = 0; box_index < {{ num_boxes }}; ++box_index) {
+ {{ compute_type }} score = ({{ compute_type }})
+ {{ scores }}[batch_index][class_index][box_index];
+ {% if score_threshold_enabled %}
+ if (score <= score_threshold_value) {
+ continue;
+ }
+ {% endif %}
+ candidate_scores[candidate_count] = score;
+ candidate_indices[candidate_count] = box_index;
+ ++candidate_count;
+ }
+
+ for (idx_t i = 1; i < candidate_count; ++i) {
+ {{ compute_type }} candidate_score = candidate_scores[i];
+ idx_t candidate_index = candidate_indices[i];
+ idx_t j = i;
+ while (j > 0
+ && (candidate_scores[j - 1] > candidate_score
+ || (candidate_scores[j - 1] == candidate_score
+ && candidate_indices[j - 1] < candidate_index))) {
+ candidate_scores[j] = candidate_scores[j - 1];
+ candidate_indices[j] = candidate_indices[j - 1];
+ --j;
+ }
+ candidate_scores[j] = candidate_score;
+ candidate_indices[j] = candidate_index;
+ }
+
+ for (idx_t candidate_pos = candidate_count; candidate_pos > 0; --candidate_pos) {
+ if (selected_count >= (idx_t)max_output_boxes_per_class_value) {
+ break;
+ }
+ idx_t candidate_index = candidate_indices[candidate_pos - 1];
+ int selected = 1;
+ for (idx_t selected_idx = 0; selected_idx < selected_count; ++selected_idx) {
+ if ({{ op_name }}_suppress_by_iou(
+ {{ boxes }}[batch_index],
+ candidate_index,
+ selected_indices[selected_idx],
+ {{ center_point_box }},
+ iou_threshold_value)) {
+ selected = 0;
+ break;
+ }
+ }
+ if (selected) {
+ selected_indices[selected_count] = candidate_index;
+ if (output_index >= output_capacity) {
+ return;
+ }
+ {{ output }}[output_index][0] = ({{ output_c_type }})batch_index;
+ {{ output }}[output_index][1] = ({{ output_c_type }})class_index;
+ {{ output }}[output_index][2] = ({{ output_c_type }})candidate_index;
+ ++output_index;
+ ++selected_count;
+ }
+ }
+ }
+ }
+ }
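
For orientation, the suppress_by_iou helper above is a standard intersection-over-union test: a candidate box is dropped once its IoU with an already selected box exceeds the threshold. The following standalone C sketch works the same arithmetic for two corner-format boxes with concrete numbers; it is illustrative only, not part of the generated code, and assumes float as the compute type.

```c
#include <stdio.h>

/* Illustrative IoU of two corner-format boxes [y1, x1, y2, x2], mirroring the
 * arithmetic of the suppress_by_iou template above with float as the assumed
 * compute type. */
static float iou_corner(const float a[4], const float b[4]) {
    float ax_min = a[1] < a[3] ? a[1] : a[3], ax_max = a[1] < a[3] ? a[3] : a[1];
    float ay_min = a[0] < a[2] ? a[0] : a[2], ay_max = a[0] < a[2] ? a[2] : a[0];
    float bx_min = b[1] < b[3] ? b[1] : b[3], bx_max = b[1] < b[3] ? b[3] : b[1];
    float by_min = b[0] < b[2] ? b[0] : b[2], by_max = b[0] < b[2] ? b[2] : b[0];
    float ix = (ax_max < bx_max ? ax_max : bx_max) - (ax_min > bx_min ? ax_min : bx_min);
    float iy = (ay_max < by_max ? ay_max : by_max) - (ay_min > by_min ? ay_min : by_min);
    if (ix <= 0.0f || iy <= 0.0f) {
        return 0.0f;                        /* no overlap */
    }
    float inter = ix * iy;
    float uni = (ax_max - ax_min) * (ay_max - ay_min)
              + (bx_max - bx_min) * (by_max - by_min) - inter;
    return uni > 0.0f ? inter / uni : 0.0f;
}

int main(void) {
    const float box1[4] = {0.0f, 0.0f, 1.0f, 1.0f};
    const float box2[4] = {0.0f, 0.5f, 1.0f, 1.5f};
    /* Overlap area 0.5, union 1.5, so IoU = 1/3; with iou_threshold = 0.5 the
     * second box would not be suppressed by the first. */
    printf("IoU = %f\n", iou_corner(box1, box2));
    return 0;
}
```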
emx_onnx_cgen/templates/nonzero_op.c.j2
@@ -0,0 +1,15 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ idx_t out_index = 0;
+ {% for dim in input_shape %}
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
+ {% endfor %}
+ if ({{ input_expr }} != {{ zero_literal }}) {
+ {% for var in loop_vars %}
+ {{ output }}[{{ loop.index0 }}][out_index] = ({{ output_c_type }}){{ var }};
+ {% endfor %}
+ ++out_index;
+ }
+ {% for _ in input_shape %}
+ }
+ {% endfor %}
+ }
emx_onnx_cgen/templates/one_hot_op.c.j2
@@ -0,0 +1,25 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ (void){{ depth }};
+ const {{ c_type }} off_value = {{ values }}[0];
+ const {{ c_type }} on_value = {{ values }}[1];
+ const int64_t depth_value = (int64_t){{ depth_dim }};
+ {% for dim in output_shape %}
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
+ {% endfor %}
+ int64_t index_value = (int64_t){{ indices }}{% for idx in indices_indices %}[{{ idx }}]{% endfor %};
+ int64_t adjusted = index_value;
+ if (index_value < 0) {
+ adjusted = index_value + depth_value;
+ }
+ {{ output }}{% for idx in loop_vars %}[{{ idx }}]{% endfor %} = off_value;
+ if (
+ index_value >= -depth_value
+ && index_value < depth_value
+ && (int64_t){{ axis_index }} == adjusted
+ ) {
+ {{ output }}{% for idx in loop_vars %}[{{ idx }}]{% endfor %} = on_value;
+ }
+ {% for _ in output_shape %}
+ }
+ {% endfor %}
+ }
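
The index normalization above accepts negative indices in [-depth, depth): a negative index is shifted by depth before it is compared with the position along the one-hot axis. A small standalone illustration of that rule, with hypothetical values:

```c
#include <stdint.h>
#include <stdio.h>

/* Illustrative only: normalize a OneHot index the way the template above
 * does, returning -1 when the index falls outside [-depth, depth). */
static int64_t normalize_one_hot_index(int64_t index_value, int64_t depth_value) {
    if (index_value < -depth_value || index_value >= depth_value) {
        return -1;                          /* out of range: the row keeps off_value everywhere */
    }
    return index_value < 0 ? index_value + depth_value : index_value;
}

int main(void) {
    /* With depth = 4: index -1 selects position 3, index 2 selects position 2,
     * and index 5 is out of range and selects nothing. */
    printf("%lld %lld %lld\n",
           (long long)normalize_one_hot_index(-1, 4),
           (long long)normalize_one_hot_index(2, 4),
           (long long)normalize_one_hot_index(5, 4));
    return 0;
}
```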
emx_onnx_cgen/templates/optional_has_element_op.c.j2
@@ -0,0 +1,4 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ (void){{ input0 }};
+ {{ output }}[0] = {{ input_present }};
+ }
emx_onnx_cgen/templates/pad_op.c.j2
@@ -0,0 +1,80 @@
+ static inline void {{ op_name }}({{ dim_args }}const {{ c_type }} {{ input0 }}{{ input_suffix }}{% if pads_input %}, const {{ pads_c_type }} {{ pads_input }}{{ pads_suffix }}{% endif %}{% if axes_input %}, const {{ axes_c_type }} {{ axes_input }}{{ axes_suffix }}{% endif %}{% if value_input %}, const {{ c_type }} {{ value_input }}{{ value_suffix }}{% endif %}, {{ c_type }} {{ output }}{{ output_suffix }}) {
+ const {{ c_type }} *{{ input0_flat }} = (const {{ c_type }} *){{ input0 }};
+ {% if axes_input %}
+ {% if pads_values %}
+ const {{ pads_c_type }} pad_values[] = { {% for value in pads_values %}{{ value }}{{ ", " if not loop.last else "" }}{% endfor %} };
+ {% endif %}
+ idx_t pad_begin[{{ output_shape|length }}];
+ for (idx_t pad_index = 0; pad_index < {{ output_shape|length }}; ++pad_index) {
+ pad_begin[pad_index] = 0;
+ }
+ for (idx_t axis_index = 0; axis_index < {{ axes_length }}; ++axis_index) {
+ idx_t axis = (idx_t){{ axes_input }}[axis_index];
+ if (axis < 0) {
+ axis += {{ output_shape|length }};
+ }
+ if (axis >= 0 && axis < {{ output_shape|length }}) {
+ pad_begin[axis] = {% if pads_input %}{{ pads_input }}[axis_index]{% else %}pad_values[axis_index]{% endif %};
+ }
+ }
+ {% endif %}
+ {% for dim in output_shape %}
+ for (idx_t {{ out_loop_vars[loop.index0] }} = 0; {{ out_loop_vars[loop.index0] }} < {{ dim }}; ++{{ out_loop_vars[loop.index0] }}) {
+ {% endfor %}
+ {{ output }}{% for var in out_loop_vars %}[{{ var }}]{% endfor %} = {{ pad_value_expr }};
+ {% for _ in output_shape %}
+ }
+ {% endfor %}
+ {% for dim in output_shape %}
+ for (idx_t {{ out_loop_vars[loop.index0] }} = 0; {{ out_loop_vars[loop.index0] }} < {{ dim }}; ++{{ out_loop_vars[loop.index0] }}) {
+ {% endfor %}
+ idx_t {{ base_index }} = 0;
+ int pad_in_bounds = 1;
+ {% for index in range(output_shape|length) %}
+ idx_t {{ idx_vars[index] }} = {{ out_loop_vars[index] }} - (idx_t)({{ pad_begin_exprs[index] }});
+ if ({{ input_shape[index] }} == 0) {
+ pad_in_bounds = 0;
+ }
+ {% if mode == "constant" %}
+ if (pad_in_bounds && ({{ idx_vars[index] }} < 0 || {{ idx_vars[index] }} >= {{ input_shape[index] }})) {
+ pad_in_bounds = 0;
+ }
+ {% elif mode == "edge" %}
+ if (pad_in_bounds && {{ idx_vars[index] }} < 0) {
+ {{ idx_vars[index] }} = 0;
+ } else if (pad_in_bounds && {{ idx_vars[index] }} >= {{ input_shape[index] }}) {
+ {{ idx_vars[index] }} = {{ input_shape[index] }} - 1;
+ }
+ {% elif mode == "wrap" %}
+ if (pad_in_bounds) {
+ {{ idx_vars[index] }} %= {{ input_shape[index] }};
+ if ({{ idx_vars[index] }} < 0) {
+ {{ idx_vars[index] }} += {{ input_shape[index] }};
+ }
+ }
+ {% elif mode == "reflect" %}
+ if (pad_in_bounds && {{ input_shape[index] }} == 1) {
+ {{ idx_vars[index] }} = 0;
+ } else if (pad_in_bounds) {
+ idx_t {{ reflect_vars[index] }} = {{ input_shape[index] }} - 1;
+ {{ idx_vars[index] }} %= (2 * {{ reflect_vars[index] }});
+ if ({{ idx_vars[index] }} < 0) {
+ {{ idx_vars[index] }} += 2 * {{ reflect_vars[index] }};
+ }
+ if ({{ idx_vars[index] }} > {{ reflect_vars[index] }}) {
+ {{ idx_vars[index] }} = 2 * {{ reflect_vars[index] }} - {{ idx_vars[index] }};
+ }
+ }
+ {% endif %}
+ if (pad_in_bounds) {
+ {{ base_index }} += {{ idx_vars[index] }} * {{ input_strides[index] }};
+ }
+ {% endfor %}
+ if (!pad_in_bounds) {
+ } else {
+ {{ output }}{% for var in out_loop_vars %}[{{ var }}]{% endfor %} = {{ input0_flat }}[{{ base_index }}];
+ }
+ {% for _ in output_shape %}
+ }
+ {% endfor %}
+ }
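
Most of the branching above is per-mode index remapping. The "reflect" branch folds an out-of-range index back into [0, n-1] with period 2*(n-1). A standalone sketch of that folding with concrete numbers (it assumes n > 1, since the template handles n == 1 as a special case):

```c
#include <stdio.h>

/* Illustrative only: the "reflect" index folding used in the pad template
 * above, for a dimension of size n > 1. */
static long reflect_index(long idx, long n) {
    long period = 2 * (n - 1);              /* reflect period, as in the template */
    idx %= period;
    if (idx < 0) {
        idx += period;
    }
    if (idx > n - 1) {
        idx = period - idx;
    }
    return idx;
}

int main(void) {
    /* For n = 4 the indices -3..6 map to 3 2 1 0 1 2 3 2 1 0: the sequence
     * mirrors at the edges without repeating the edge sample, matching ONNX
     * Pad "reflect" semantics. */
    for (long i = -3; i <= 6; ++i) {
        printf("%ld -> %ld\n", i, reflect_index(i, 4));
    }
    return 0;
}
```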
emx_onnx_cgen/templates/qlinear_matmul_op.c.j2
@@ -0,0 +1,33 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ {% if scale_is_float16 %}
+ const {{ scale_type }} scale_product = ({{ scale_type }})((( {{ scale_type }}){{ input0_scale_expr }}) * (({{ scale_type }}){{ input1_scale_expr }}));
+ const {{ scale_type }} scale = ({{ scale_type }})(scale_product / ({{ scale_type }}){{ output_scale_expr }});
+ {% else %}
+ const {{ scale_type }} scale = (({{ scale_type }}){{ input0_scale_expr }}) * (({{ scale_type }}){{ input1_scale_expr }}) / (({{ scale_type }}){{ output_scale_expr }});
+ {% endif %}
+ const int32_t input0_zero = (int32_t){{ input0_zero_expr }};
+ const int32_t input1_zero = (int32_t){{ input1_zero_expr }};
+ const {{ compute_type }} output_zero = ({{ compute_type }}){{ output_zero_expr }};
+ {% for idx in range(output_loop_vars | length) %}
+ {% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ output_loop_vars[idx] }} = 0; {{ output_loop_vars[idx] }} < {{ output_loop_bounds[idx] }}; ++{{ output_loop_vars[idx] }}) {
+ {% endfor %}
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}int32_t acc = 0;
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}for (idx_t k = 0; k < {{ k }}; ++k) {
+ {% for indent in range(output_loop_vars | length + 1) %} {% endfor %}acc += ((int32_t){{ input0_index_expr }} - input0_zero) * ((int32_t){{ input1_index_expr }} - input1_zero);
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}}
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} scaled = (({{ compute_type }})acc) * scale + output_zero;
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} rounded = {{ round_fn }}(scaled);
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} wrapped = {{ mod_fn }}(rounded, ({{ compute_type }})256.0);
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}if (wrapped < ({{ compute_type }})0.0) {
+ {% for indent in range(output_loop_vars | length + 1) %} {% endfor %}wrapped += ({{ compute_type }})256.0;
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}}
+ {% if output_is_signed %}
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}if (wrapped >= ({{ compute_type }})128.0) {
+ {% for indent in range(output_loop_vars | length + 1) %} {% endfor %}wrapped -= ({{ compute_type }})256.0;
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}}
+ {% endif %}
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ output_index_expr }} = ({{ output_c_type }})wrapped;
+ {% for idx in range(output_loop_vars | length) | reverse %}
+ {% for indent in range(loop.index0) %} {% endfor %}}
+ {% endfor %}
+ }
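
After the integer accumulation, the template rescales the int32 accumulator into the output's quantized domain and then reduces it modulo 256 (shifting into [-128, 128) when the output is signed) rather than saturating. A standalone sketch of that requantization step, assuming float as the compute type and rintf/fmodf standing in for round_fn/mod_fn:

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Illustrative only: the rescale-round-wrap step from the template above for
 * a signed int8 output. */
static int8_t wrap_requantize(int32_t acc, float scale, float output_zero) {
    float scaled = (float)acc * scale + output_zero;
    float wrapped = fmodf(rintf(scaled), 256.0f);
    if (wrapped < 0.0f) {
        wrapped += 256.0f;                  /* first map into [0, 256) */
    }
    if (wrapped >= 128.0f) {
        wrapped -= 256.0f;                  /* then into [-128, 128) for int8 */
    }
    return (int8_t)wrapped;
}

int main(void) {
    /* With scale = 1 and zero point 0: 300 wraps to 44 and -200 wraps to 56,
     * illustrating wrap-around rather than saturation. */
    printf("%d %d\n", wrap_requantize(300, 1.0f, 0.0f), wrap_requantize(-200, 1.0f, 0.0f));
    return 0;
}
```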
emx_onnx_cgen/templates/qlinear_mul_op.c.j2
@@ -0,0 +1,18 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ const {{ compute_type }} scale = (({{ compute_type }}){{ input0_scale_expr }}) * (({{ compute_type }}){{ input1_scale_expr }}) / (({{ compute_type }}){{ output_scale_expr }});
+ const int32_t input0_zero = (int32_t){{ input0_zero_expr }};
+ const int32_t input1_zero = (int32_t){{ input1_zero_expr }};
+ const {{ compute_type }} output_zero = ({{ compute_type }}){{ output_zero_expr }};
+ {% for idx in range(output_loop_vars | length) %}
+ {% for indent in range(loop.index0) %} {% endfor %}for (idx_t {{ output_loop_vars[idx] }} = 0; {{ output_loop_vars[idx] }} < {{ output_loop_bounds[idx] }}; ++{{ output_loop_vars[idx] }}) {
+ {% endfor %}
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}int32_t acc = ((int32_t){{ input0_index_expr }} - input0_zero) * ((int32_t){{ input1_index_expr }} - input1_zero);
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} scaled = (({{ compute_type }})acc) * scale + output_zero;
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ compute_type }} rounded = {{ round_fn }}(scaled);
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}rounded = {{ max_fn }}(rounded, ({{ compute_type }}){{ min_literal }});
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}rounded = {{ min_fn }}(rounded, ({{ compute_type }}){{ max_literal }});
+ {% for indent in range(output_loop_vars | length) %} {% endfor %}{{ output_index_expr }} = ({{ output_c_type }})rounded;
+ {% for idx in range(output_loop_vars | length) | reverse %}
+ {% for indent in range(loop.index0) %} {% endfor %}}
+ {% endfor %}
+ }
emx_onnx_cgen/templates/quantize_linear_op.c.j2
@@ -0,0 +1,13 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ {% for dim in shape %}
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
+ {% endfor %}
+ {{ compute_type }} scaled = (({{ compute_type }}){{ input_expr }} / ({{ compute_type }}){{ scale_expr }});
+ {{ compute_type }} rounded = {{ round_fn }}(scaled) + ({{ compute_type }}){{ zero_expr }};
+ rounded = {{ max_fn }}(rounded, ({{ compute_type }}){{ min_literal }});
+ rounded = {{ min_fn }}(rounded, ({{ compute_type }}){{ max_literal }});
+ {{ output_expr }} = ({{ output_c_type }})rounded;
+ {% for _ in shape %}
+ }
+ {% endfor %}
+ }
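
This is the usual affine quantization y = clamp(round(x / scale) + zero_point, min, max). A worked example for a uint8 output, with hypothetical scale and zero point and rintf standing in for round_fn:

```c
#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Illustrative only: affine quantization to uint8 as in the template above. */
static uint8_t quantize_u8(float x, float scale, float zero_point) {
    float q = rintf(x / scale) + zero_point;
    if (q < 0.0f) {
        q = 0.0f;                           /* min_literal for uint8 */
    }
    if (q > 255.0f) {
        q = 255.0f;                         /* max_literal for uint8 */
    }
    return (uint8_t)q;
}

int main(void) {
    /* With scale = 0.1 and zero_point = 128: 1.0 -> 138, while -20.0 clamps to 0. */
    printf("%u %u\n",
           (unsigned)quantize_u8(1.0f, 0.1f, 128.0f),
           (unsigned)quantize_u8(-20.0f, 0.1f, 128.0f));
    return 0;
}
```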
emx_onnx_cgen/templates/range_op.c.j2
@@ -0,0 +1,8 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ (void){{ limit }};
+ const {{ c_type }} start_value = {{ start }}[0];
+ const {{ c_type }} delta_value = {{ delta }}[0];
+ for (idx_t idx = 0; idx < {{ length }}; ++idx) {
+ {{ output }}[idx] = start_value + (({{ c_type }})idx * delta_value);
+ }
+ }
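
For a sense of what these templates expand to, a hypothetical rendering of the Range template for a float op with a statically known length of 4 might look roughly like this; the function and parameter names below are placeholders chosen for illustration, not output captured from the generator:

```c
typedef unsigned long idx_t;  /* assumed stand-in; the generated code defines its own index type */

/* Hypothetical rendering of the Range template above: start/limit/delta are
 * 1-element inputs, and the output length (4) has been resolved at compile time. */
static inline void node_range_0(const float start[1], const float limit[1],
                                const float delta[1], float output[4]) {
    (void)limit;                            /* limit is unused once the length is static */
    const float start_value = start[0];
    const float delta_value = delta[0];
    for (idx_t idx = 0; idx < 4; ++idx) {
        output[idx] = start_value + ((float)idx * delta_value);
    }
}
```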
emx_onnx_cgen/templates/reduce_op.c.j2
@@ -0,0 +1,28 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ {% for dim in output_shape %}
+ for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
+ {% endfor %}
+ {{ c_type }} acc = {{ init_literal }};
+ {% if use_kahan %}
+ {{ c_type }} acc_comp = {{ zero_literal }};
+ {% endif %}
+ {% for dim in reduce_dims %}
+ for (idx_t {{ reduce_loop_vars[loop.index0] }} = 0; {{ reduce_loop_vars[loop.index0] }} < {{ dim }}; ++{{ reduce_loop_vars[loop.index0] }}) {
+ {% endfor %}
+ {% if use_kahan %}
+ {{ c_type }} kahan_value = {{ kahan_value_expr }};
+ {{ c_type }} kahan_y = kahan_value - acc_comp;
+ {{ c_type }} kahan_t = acc + kahan_y;
+ acc_comp = (kahan_t - acc) - kahan_y;
+ acc = kahan_t;
+ {% else %}
+ {{ update_expr }}
+ {% endif %}
+ {% for _ in reduce_dims %}
+ }
+ {% endfor %}
+ {{ output }}{{ output_index_expr }} = {{ final_expr }};
+ {% for _ in output_shape %}
+ }
+ {% endfor %}
+ }
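
When use_kahan is enabled, the reduction body above accumulates with Kahan compensated summation, carrying a correction term that recovers the low-order bits a plain running sum would drop. The same scheme in isolation:

```c
#include <stdio.h>

/* Illustrative only: Kahan compensated summation, matching the use_kahan
 * branch of the reduction template above. */
static float kahan_sum(const float *values, int count) {
    float acc = 0.0f;
    float acc_comp = 0.0f;                  /* running compensation */
    for (int i = 0; i < count; ++i) {
        float kahan_y = values[i] - acc_comp;
        float kahan_t = acc + kahan_y;
        acc_comp = (kahan_t - acc) - kahan_y;   /* low-order part lost in the add */
        acc = kahan_t;
    }
    return acc;
}

int main(void) {
    enum { N = 10000 };
    float values[N];
    float naive = 0.0f;
    for (int i = 0; i < N; ++i) {
        values[i] = 0.1f;
        naive += 0.1f;
    }
    /* The plain running sum drifts visibly away from 1000, while the
     * compensated sum stays at 1000 to within float precision. */
    printf("naive=%f kahan=%f\n", naive, kahan_sum(values, N));
    return 0;
}
```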
emx_onnx_cgen/templates/reduce_op_dynamic.c.j2
@@ -0,0 +1,77 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ idx_t axis_count = {{ axes_count }};
+ bool reduce_mask[{{ input_shape | length }}];
+ for (idx_t i = 0; i < {{ input_shape | length }}; ++i) {
+ reduce_mask[i] = false;
+ }
+ if (axis_count == 0) {
+ {% if noop_with_empty_axes %}
+ {% for dim in input_shape %}
+ for (idx_t {{ input_loop_vars[loop.index0] }} = 0; {{ input_loop_vars[loop.index0] }} < {{ dim }}; ++{{ input_loop_vars[loop.index0] }}) {
+ {% endfor %}
+ {{ output }}{{ input_index_expr }} = {{ input0 }}{{ input_index_expr }};
+ {% for _ in input_shape %}
+ }
+ {% endfor %}
+ return;
+ {% else %}
+ for (idx_t i = 0; i < {{ input_shape | length }}; ++i) {
+ reduce_mask[i] = true;
+ }
+ {% endif %}
+ } else {
+ for (idx_t i = 0; i < axis_count; ++i) {
+ int axis = (int){{ axes_input }}[i];
+ if (axis < 0) {
+ axis += {{ input_shape | length }};
+ }
+ if (axis >= 0 && axis < {{ input_shape | length }}) {
+ reduce_mask[axis] = true;
+ }
+ }
+ }
+ idx_t reduce_count = 1;
+ {% for dim in input_shape %}
+ if (reduce_mask[{{ loop.index0 }}]) {
+ reduce_count *= {{ dim }};
+ }
+ {% endfor %}
+ {% for dim in output_shape %}
+ for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
+ {% endfor %}
+ {{ output }}{{ output_loop_index_expr }} = {{ init_literal }};
+ {% for _ in output_shape %}
+ }
+ {% endfor %}
+ {% for dim in input_shape %}
+ for (idx_t {{ input_loop_vars[loop.index0] }} = 0; {{ input_loop_vars[loop.index0] }} < {{ dim }}; ++{{ input_loop_vars[loop.index0] }}) {
+ {% endfor %}
+ idx_t out_indices[{{ output_shape | length }}];
+ {% if keepdims %}
+ {% for axis in range(input_shape | length) %}
+ out_indices[{{ axis }}] = reduce_mask[{{ axis }}] ? 0 : {{ input_loop_vars[axis] }};
+ {% endfor %}
+ {% else %}
+ idx_t out_pos = 0;
+ {% for axis in range(input_shape | length) %}
+ if (!reduce_mask[{{ axis }}]) {
+ out_indices[out_pos++] = {{ input_loop_vars[axis] }};
+ }
+ {% endfor %}
+ {% endif %}
+ {{ c_type }} *out_ptr = &{{ output }}{{ output_index_expr }};
+ {{ update_expr }}
+ {% for _ in input_shape %}
+ }
+ {% endfor %}
+ {% if post_expr %}
+ {% for dim in output_shape %}
+ for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
+ {% endfor %}
+ {{ c_type }} *out_ptr = &{{ output }}{{ output_loop_index_expr }};
+ {{ post_expr }}
+ {% for _ in output_shape %}
+ }
+ {% endfor %}
+ {% endif %}
+ }
emx_onnx_cgen/templates/reshape_op.c.j2
@@ -0,0 +1,18 @@
+ static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
+ const {{ c_type }} *input0_data = (const {{ c_type }} *){{ input0 }};
+ {% for dim in output_shape %}
+ for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
+ {% endfor %}
+ {% if loop_vars %}
+ {% set ns = namespace(expr=loop_vars[0]) %}
+ {% for dim in output_shape[1:] %}
+ {% set ns.expr = "(" ~ ns.expr ~ " * " ~ dim ~ " + " ~ loop_vars[loop.index] ~ ")" %}
+ {% endfor %}
+ {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = input0_data[{{ ns.expr }}];
+ {% else %}
+ {{ output }} = input0_data[0];
+ {% endif %}
+ {% for _ in output_shape %}
+ }
+ {% endfor %}
+ }