emx-onnx-cgen 0.3.7__py3-none-any.whl → 0.4.1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- emx_onnx_cgen/_build_info.py +1 -1
- emx_onnx_cgen/_version.py +2 -2
- emx_onnx_cgen/cli.py +1025 -162
- emx_onnx_cgen/codegen/__init__.py +2 -0
- emx_onnx_cgen/codegen/c_emitter.py +2081 -458
- emx_onnx_cgen/compiler.py +157 -75
- emx_onnx_cgen/determinism.py +39 -0
- emx_onnx_cgen/ir/context.py +25 -15
- emx_onnx_cgen/ir/model.py +1 -0
- emx_onnx_cgen/ir/op_base.py +32 -7
- emx_onnx_cgen/ir/ops/__init__.py +20 -0
- emx_onnx_cgen/ir/ops/elementwise.py +138 -22
- emx_onnx_cgen/ir/ops/misc.py +95 -0
- emx_onnx_cgen/ir/ops/nn.py +361 -38
- emx_onnx_cgen/ir/ops/reduce.py +1 -16
- emx_onnx_cgen/lowering/__init__.py +9 -0
- emx_onnx_cgen/lowering/arg_reduce.py +0 -4
- emx_onnx_cgen/lowering/average_pool.py +157 -27
- emx_onnx_cgen/lowering/bernoulli.py +73 -0
- emx_onnx_cgen/lowering/common.py +48 -0
- emx_onnx_cgen/lowering/concat.py +41 -7
- emx_onnx_cgen/lowering/conv.py +19 -8
- emx_onnx_cgen/lowering/conv_integer.py +103 -0
- emx_onnx_cgen/lowering/dequantize_linear.py +128 -0
- emx_onnx_cgen/lowering/elementwise.py +140 -43
- emx_onnx_cgen/lowering/gather.py +11 -2
- emx_onnx_cgen/lowering/gemm.py +7 -124
- emx_onnx_cgen/lowering/global_max_pool.py +0 -5
- emx_onnx_cgen/lowering/gru.py +323 -0
- emx_onnx_cgen/lowering/hamming_window.py +104 -0
- emx_onnx_cgen/lowering/hardmax.py +1 -37
- emx_onnx_cgen/lowering/identity.py +7 -6
- emx_onnx_cgen/lowering/logsoftmax.py +1 -35
- emx_onnx_cgen/lowering/lp_pool.py +15 -4
- emx_onnx_cgen/lowering/matmul.py +3 -105
- emx_onnx_cgen/lowering/optional_has_element.py +28 -0
- emx_onnx_cgen/lowering/qlinear_mul.py +116 -0
- emx_onnx_cgen/lowering/reduce.py +0 -5
- emx_onnx_cgen/lowering/reshape.py +7 -16
- emx_onnx_cgen/lowering/shape.py +14 -8
- emx_onnx_cgen/lowering/slice.py +14 -4
- emx_onnx_cgen/lowering/softmax.py +1 -35
- emx_onnx_cgen/lowering/split.py +37 -3
- emx_onnx_cgen/lowering/tfidf_vectorizer.py +199 -0
- emx_onnx_cgen/lowering/tile.py +38 -1
- emx_onnx_cgen/lowering/topk.py +1 -5
- emx_onnx_cgen/lowering/transpose.py +9 -3
- emx_onnx_cgen/lowering/unsqueeze.py +11 -16
- emx_onnx_cgen/lowering/upsample.py +151 -0
- emx_onnx_cgen/lowering/variadic.py +1 -1
- emx_onnx_cgen/lowering/where.py +0 -5
- emx_onnx_cgen/onnx_import.py +578 -14
- emx_onnx_cgen/ops.py +3 -0
- emx_onnx_cgen/templates/adagrad_op.c.j2 +16 -0
- emx_onnx_cgen/templates/arg_reduce_op.c.j2 +18 -0
- emx_onnx_cgen/templates/attention_op.c.j2 +189 -0
- emx_onnx_cgen/templates/average_pool_op.c.j2 +126 -0
- emx_onnx_cgen/templates/batch_norm_op.c.j2 +11 -0
- emx_onnx_cgen/templates/bernoulli_op.c.j2 +34 -0
- emx_onnx_cgen/templates/binary_op.c.j2 +9 -0
- emx_onnx_cgen/templates/cast_op.c.j2 +9 -0
- emx_onnx_cgen/templates/clip_op.c.j2 +14 -0
- emx_onnx_cgen/templates/concat_op.c.j2 +28 -0
- emx_onnx_cgen/templates/constant_of_shape_op.c.j2 +10 -0
- emx_onnx_cgen/templates/conv_integer_op.c.j2 +34 -0
- emx_onnx_cgen/templates/conv_op.c.j2 +32 -0
- emx_onnx_cgen/templates/conv_transpose_op.c.j2 +43 -0
- emx_onnx_cgen/templates/cumsum_op.c.j2 +51 -0
- emx_onnx_cgen/templates/depth_to_space_op.c.j2 +26 -0
- emx_onnx_cgen/templates/dequantize_linear_op.c.j2 +10 -0
- emx_onnx_cgen/templates/einsum_op.c.j2 +55 -0
- emx_onnx_cgen/templates/expand_op.c.j2 +14 -0
- emx_onnx_cgen/templates/eye_like_op.c.j2 +27 -0
- emx_onnx_cgen/templates/gather_elements_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gather_nd_op.c.j2 +29 -0
- emx_onnx_cgen/templates/gather_op.c.j2 +13 -0
- emx_onnx_cgen/templates/gemm_op.c.j2 +35 -0
- emx_onnx_cgen/templates/grid_sample_op.c.j2 +184 -0
- emx_onnx_cgen/templates/group_normalization_op.c.j2 +46 -0
- emx_onnx_cgen/templates/gru_op.c.j2 +152 -0
- emx_onnx_cgen/templates/hamming_window_op.c.j2 +12 -0
- emx_onnx_cgen/templates/hardmax_op.c.j2 +24 -0
- emx_onnx_cgen/templates/identity_op.c.j2 +9 -0
- emx_onnx_cgen/templates/instance_normalization_op.c.j2 +35 -0
- emx_onnx_cgen/templates/layer_normalization_op.c.j2 +65 -0
- emx_onnx_cgen/templates/logsoftmax_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_normalization_op.c.j2 +27 -0
- emx_onnx_cgen/templates/lp_pool_op.c.j2 +24 -0
- emx_onnx_cgen/templates/lrn_op.c.j2 +20 -0
- emx_onnx_cgen/templates/lstm_op.c.j2 +175 -0
- emx_onnx_cgen/templates/matmul_op.c.j2 +13 -0
- emx_onnx_cgen/templates/maxpool_op.c.j2 +118 -0
- emx_onnx_cgen/templates/mean_variance_normalization_op.c.j2 +34 -0
- emx_onnx_cgen/templates/multi_input_op.c.j2 +15 -0
- emx_onnx_cgen/templates/negative_log_likelihood_loss_op.c.j2 +54 -0
- emx_onnx_cgen/templates/nonmax_suppression_op.c.j2 +179 -0
- emx_onnx_cgen/templates/nonzero_op.c.j2 +15 -0
- emx_onnx_cgen/templates/one_hot_op.c.j2 +25 -0
- emx_onnx_cgen/templates/optional_has_element_op.c.j2 +4 -0
- emx_onnx_cgen/templates/pad_op.c.j2 +80 -0
- emx_onnx_cgen/templates/qlinear_matmul_op.c.j2 +33 -0
- emx_onnx_cgen/templates/qlinear_mul_op.c.j2 +18 -0
- emx_onnx_cgen/templates/quantize_linear_op.c.j2 +13 -0
- emx_onnx_cgen/templates/range_op.c.j2 +8 -0
- emx_onnx_cgen/templates/reduce_op.c.j2 +28 -0
- emx_onnx_cgen/templates/reduce_op_dynamic.c.j2 +77 -0
- emx_onnx_cgen/templates/reshape_op.c.j2 +18 -0
- emx_onnx_cgen/templates/resize_op.c.j2 +277 -0
- emx_onnx_cgen/templates/rms_normalization_op.c.j2 +28 -0
- emx_onnx_cgen/templates/rotary_embedding_op.c.j2 +66 -0
- emx_onnx_cgen/templates/scatter_nd_op.c.j2 +52 -0
- emx_onnx_cgen/templates/shape_op.c.j2 +6 -0
- emx_onnx_cgen/templates/size_op.c.j2 +4 -0
- emx_onnx_cgen/templates/slice_op.c.j2 +9 -0
- emx_onnx_cgen/templates/slice_op_dynamic.c.j2 +70 -0
- emx_onnx_cgen/templates/softmax_cross_entropy_loss_op.c.j2 +105 -0
- emx_onnx_cgen/templates/softmax_op.c.j2 +26 -0
- emx_onnx_cgen/templates/space_to_depth_op.c.j2 +22 -0
- emx_onnx_cgen/templates/split_op.c.j2 +18 -0
- emx_onnx_cgen/templates/tensor_scatter_op.c.j2 +44 -0
- emx_onnx_cgen/templates/testbench.c.j2 +161 -0
- emx_onnx_cgen/templates/tfidf_vectorizer_op.c.j2 +144 -0
- emx_onnx_cgen/templates/tile_op.c.j2 +14 -0
- emx_onnx_cgen/templates/topk_op.c.j2 +50 -0
- emx_onnx_cgen/templates/transpose_op.c.j2 +9 -0
- emx_onnx_cgen/templates/trilu_op.c.j2 +33 -0
- emx_onnx_cgen/templates/unary_op.c.j2 +23 -0
- emx_onnx_cgen/templates/where_op.c.j2 +9 -0
- emx_onnx_cgen/verification.py +45 -5
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/METADATA +33 -15
- emx_onnx_cgen-0.4.1.dev0.dist-info/RECORD +190 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/WHEEL +1 -1
- emx_onnx_cgen/runtime/__init__.py +0 -1
- emx_onnx_cgen/runtime/evaluator.py +0 -2955
- emx_onnx_cgen-0.3.7.dist-info/RECORD +0 -107
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/entry_points.txt +0 -0
- {emx_onnx_cgen-0.3.7.dist-info → emx_onnx_cgen-0.4.1.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
{#
  Split: copy one input buffer into several outputs along the split axis.
  The input is addressed as [outer][axis_total][inner] in row-major order
  (input_base = outer_idx * axis_total * inner). Output k receives a
  contiguous chunk of axis_sizes[k] * inner elements from every outer slice.
  NOTE(review): assumes axis_total == sum(axis_sizes) and that the generated
  file includes <string.h> for memcpy; neither is checked in this template.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
    {# Flat views of every output so they can be selected by output_idx at runtime. #}
    {{ c_type }} *output_ptrs[] = { {% for output in outputs %}({{ c_type }} *){{ output }}{% if not loop.last %}, {% endif %}{% endfor %} };
    const idx_t axis_sizes[] = { {% for axis in axis_sizes %}{{ axis }}{% if not loop.last %}, {% endif %}{% endfor %} };
    for (idx_t outer_idx = 0; outer_idx < {{ outer }}; ++outer_idx) {
        idx_t input_base = outer_idx * {{ axis_total }} * {{ inner }};
        {# Element offset of the current output's chunk inside this outer slice. #}
        idx_t axis_offset = 0;
        for (idx_t output_idx = 0; output_idx < {{ output_count }}; ++output_idx) {
            idx_t copy_elems = axis_sizes[output_idx] * {{ inner }};
            {# Output k is addressed as [outer][axis_sizes[k] * inner], hence outer_idx * copy_elems. #}
            memcpy(
                output_ptrs[output_idx] + outer_idx * copy_elems,
                input_data + input_base + axis_offset,
                copy_elems * sizeof({{ c_type }})
            );
            axis_offset += copy_elems;
        }
    }
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
{#
  TensorScatter: the output starts as a full copy of past_cache. Then, for
  every prefix position, sequence_dim consecutive cache positions starting at
  the write index are overwritten from update_index_expr.
  - prefix loops: presumably the leading dims before the sequence axis;
    batch_index_var is expected to be one of prefix_loop_vars (TODO confirm
    in the emitter).
  - write index: 0, or write_indices[batch] when write_indices is present.
  - circular: the cache position wraps modulo max_sequence_length.
  - linear (not circular): write_index + seq is used directly. There is no
    bounds check in this template; in-range indices are presumably guaranteed
    upstream.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {# Step 1: output = past_cache, element by element over the full output shape. #}
    {% for dim in output_shape %}
    for (idx_t {{ output_loop_vars[loop.index0] }} = 0; {{ output_loop_vars[loop.index0] }} < {{ dim }}; ++{{ output_loop_vars[loop.index0] }}) {
    {% endfor %}
    {{ output }}{% for var in output_loop_vars %}[{{ var }}]{% endfor %} = {{ past_cache }}{% for var in output_loop_vars %}[{{ var }}]{% endfor %};
    {% for _ in output_shape %}
    }
    {% endfor %}

    {# Step 2: scatter the update window into the copied cache. #}
    {% if prefix_shape %}
    {% for dim in prefix_shape %}
    for (idx_t {{ prefix_loop_vars[loop.index0] }} = 0; {{ prefix_loop_vars[loop.index0] }} < {{ dim }}; ++{{ prefix_loop_vars[loop.index0] }}) {
    {% endfor %}
    {% endif %}
    idx_t {{ write_index_var }} = 0;
    {% if write_indices_present %}
    {{ write_index_var }} = (idx_t){{ write_indices }}[{{ batch_index_var }}];
    {% endif %}
    for (idx_t {{ sequence_loop_var }} = 0; {{ sequence_loop_var }} < {{ sequence_dim }}; ++{{ sequence_loop_var }}) {
        idx_t {{ cache_index_var }} = {{ write_index_var }} + {{ sequence_loop_var }};
        {% if circular %}
        {{ cache_index_var }} = {{ cache_index_var }} % {{ max_sequence_length }};
        {# NOTE(review): this fix-up for C's negative remainder only has an effect if idx_t is signed; confirm idx_t's signedness. #}
        if ({{ cache_index_var }} < 0) {
            {{ cache_index_var }} += {{ max_sequence_length }};
        }
        {% endif %}
        {% if tail_shape %}
        {% for dim in tail_shape %}
        for (idx_t {{ tail_loop_vars[loop.index0] }} = 0; {{ tail_loop_vars[loop.index0] }} < {{ dim }}; ++{{ tail_loop_vars[loop.index0] }}) {
        {% endfor %}
        {% endif %}
        {{ output_index_expr }} = {{ update_index_expr }};
        {% if tail_shape %}
        {% for _ in tail_shape %}
        }
        {% endfor %}
        {% endif %}
    }
    {% if prefix_shape %}
    {% for _ in prefix_shape %}
    }
    {% endfor %}
    {% endif %}
}
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
{#
  Testbench body.
  - Optional RNG helpers, emitted only when requested: a xorshift64* generator
    with a fixed seed, so generated inputs are reproducible across runs.
  - Constant inputs are embedded as static arrays and copied into the locals.
  - Other inputs are read one element at a time (host binary layout, loop
    order) from the file named by argv[1] when given. Otherwise they are filled
    from input.random_expr.
  - The weights are loaded via <model_name>_load and the model is run once.
    Inputs and outputs are then printed to stdout as one JSON object of the
    form inputs / outputs -> name -> shape + nested data arrays.
  - Optional-input flags are initialised with the literals 1/0. These are
    valid _Bool initialisers without <stdbool.h>; true/false would need that
    header before C23.
  - On a short read, the input file is closed before returning an error.
  NOTE(review): rng_next_float/double/i64 call rng_next_u64, so the emitter is
  expected to set rng_requires_u64 whenever any of them is requested.
#}
{% if rng_requires_u64 %}
static uint64_t rng_state = 0x243f6a8885a308d3ull;

static uint64_t rng_next_u64(void) {
    uint64_t x = rng_state;
    x ^= x >> 12;
    x ^= x << 25;
    x ^= x >> 27;
    rng_state = x;
    return x * 0x2545f4914f6cdd1dull;
}
{% endif %}

{% if rng_requires_float %}
{# Scales a 64-bit draw by 2^-64; the double/float conversion may round up to exactly 1.0. #}
static float rng_next_float(void) {
    return (float)((double)rng_next_u64() * (1.0 / 18446744073709551616.0));
}
{% endif %}

{% if rng_requires_double %}
static double rng_next_double(void) {
    return (double)rng_next_u64() * (1.0 / 18446744073709551616.0);
}
{% endif %}

{% if rng_requires_i64 %}
static int64_t rng_next_i64(void) {
    return (int64_t)rng_next_u64();
}
{% endif %}

{% for input in inputs %}
{% if input.constant_lines %}
static const {{ input.c_type }} {{ input.constant_name }}[] = {
{% for line in input.constant_lines %}
    {{ line }}{% if not loop.last %},{% endif %}
{% endfor %}
};
{% endif %}
{% endfor %}

int main(int argc, char **argv) {
    FILE *input_file = NULL;
    if (argc > 1) {
        input_file = fopen(argv[1], "rb");
        if (!input_file) {
            fprintf(stderr, "Failed to open input file: %s\n", argv[1]);
            return 1;
        }
    }
{% for dim in dim_args %}
    int {{ dim.name }} = {{ dim.value }};
{% endfor %}

{% for input in inputs %}
{% if input.optional_flag_name %}
    _Bool {{ input.optional_flag_name }} = {{ 1 if input.optional_present else 0 }};
{% endif %}
    {{ input.c_type }} {{ input.name }}{{ input.array_suffix }};
{% if input.constant_name %}
{% if input.rank == 0 %}
    {{ input.name }} = {{ input.constant_name }}[0];
{% else %}
{% for depth in range(input.rank) %}
    for (idx_t {{ input.loop_vars[depth] }} = 0; {{ input.loop_vars[depth] }} < {{ input.shape[depth] }}; ++{{ input.loop_vars[depth] }}) {
{% endfor %}
    {{ input.name }}{{ input.array_index_expr }} = {{ input.constant_name }}[{{ input.index_expr }}];
{% for depth in range(input.rank - 1, -1, -1) %}
    }
{% endfor %}
{% endif %}
{% else %}
    if (input_file) {
{% for depth in range(input.rank) %}
        for (idx_t {{ input.loop_vars[depth] }} = 0; {{ input.loop_vars[depth] }} < {{ input.shape[depth] }}; ++{{ input.loop_vars[depth] }}) {
{% endfor %}
        if (fread(&{{ input.name }}{{ input.array_index_expr }}, sizeof({{ input.c_type }}), 1, input_file) != 1) {
            fprintf(stderr, "Failed to read input {{ input.json_name }}\n");
            fclose(input_file);
            return 1;
        }
{% for depth in range(input.rank - 1, -1, -1) %}
        }
{% endfor %}
    } else {
{% for depth in range(input.rank) %}
        for (idx_t {{ input.loop_vars[depth] }} = 0; {{ input.loop_vars[depth] }} < {{ input.shape[depth] }}; ++{{ input.loop_vars[depth] }}) {
{% endfor %}
        {{ input.name }}{{ input.array_index_expr }} = {{ input.random_expr }};
{% for depth in range(input.rank - 1, -1, -1) %}
        }
{% endfor %}
    }
{% endif %}
{% endfor %}
    if (input_file) {
        fclose(input_file);
    }

{% for output in outputs %}
    {{ output.c_type }} {{ output.name }}{{ output.array_suffix }};
{% endfor %}

    if (!{{ model_name }}_load("{{ weight_data_filename }}")) {
        return 1;
    }
    {{ model_name }}({% for dim in dim_args %}{{ dim.name }}, {% endfor %}{% for input in inputs %}{{ input.name }}{% if input.optional_flag_name %}, {{ input.optional_flag_name }}{% endif %}, {% endfor %}{% for output in outputs %}{{ output.name }}{% if not loop.last %}, {% endif %}{% endfor %});

    printf("{\"inputs\":{");
{% for input in inputs %}
{% if not loop.first %}
    printf(",");
{% endif %}
    printf("\"{{ input.json_name }}\":{\"shape\":[{{ input.shape_literal }}],\"data\":");
    printf("[");
{# Nested JSON arrays: every level but the innermost opens "[" per element and closes "]" after its inner loop. #}
{% for depth in range(input.rank) %}
    for (idx_t {{ input.loop_vars[depth] }} = 0; {{ input.loop_vars[depth] }} < {{ input.shape[depth] }}; ++{{ input.loop_vars[depth] }}) {
        if ({{ input.loop_vars[depth] }}) {
            printf(",");
        }
{% if depth < input.rank - 1 %}
        printf("[");
{% endif %}
{% endfor %}
    printf("{{ input.print_format }}", {{ input.print_cast }}{{ input.name }}{{ input.array_index_expr }});
{% for depth in range(input.rank - 1, -1, -1) %}
{% if depth < input.rank - 1 %}
        printf("]");
{% endif %}
    }
{% endfor %}
    printf("]}");
{% endfor %}

    printf("},\"outputs\":{");
{% for output in outputs %}
{% if not loop.first %}
    printf(",");
{% endif %}
    printf("\"{{ output.json_name }}\":{\"shape\":[{{ output.shape_literal }}],\"data\":");
    printf("[");
{% for depth in range(output.rank) %}
    for (idx_t {{ output.loop_vars[depth] }} = 0; {{ output.loop_vars[depth] }} < {{ output.shape[depth] }}; ++{{ output.loop_vars[depth] }}) {
        if ({{ output.loop_vars[depth] }}) {
            printf(",");
        }
{% if depth < output.rank - 1 %}
        printf("[");
{% endif %}
{% endfor %}
    printf("{{ output.print_format }}", {{ output.print_cast }}{{ output.name }}{{ output.array_index_expr }});
{% for depth in range(output.rank - 1, -1, -1) %}
{% if depth < output.rank - 1 %}
        printf("]");
{% endif %}
    }
{% endfor %}
    printf("]}");
{% endfor %}
    printf("}}\n");
    return 0;
}
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
{#
  TfIdfVectorizer. For each input row (the whole input when input_rank == 1,
  otherwise each of the input_shape[0] rows):
    1. zero the output row;
    2. for every n-gram length in [min_gram, max_gram] and every skip distance
       (only 0 for 1-grams, 0..max_skip otherwise), slide over the row. Each
       pool n-gram that matches increments output[ngram_indexes[...]];
    3. when mode_id != 0, post-process each count: mode_id == 1 maps counts to
       0/1, and weights (when provided) scale the result.
  Pool layout, as used by the indexing below: n-grams of length L occupy
  pool[ngram_counts[L - 1] .. ngram_counts[L]) (or up to pool_size for the
  last length), L tokens each. ngram_indexes is indexed by the running n-gram
  count across lengths.
  The 1-D and 2-D cases share one body. `row` is the extra subscript that
  selects the current row ("" for 1-D, "[b]" for 2-D), and the 1-D case runs
  the row loop exactly once.
  Lengths beyond ngram_counts_len have no pool entries and are skipped, which
  also keeps ngram_counts[gram_len - 1] in bounds.
  NOTE(review): assumes min_gram >= 1 and that weights has output_dim entries;
  presumably validated in the lowering.
#}
{% set row = "" if input_rank == 1 else "[b]" %}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const int64_t pool[{{ pool_size if pool_size > 0 else 1 }}] = { {% if pool_values %}{{ pool_values | join(', ') }}{% else %}0{% endif %} };
    const int64_t ngram_counts[{{ ngram_counts_len if ngram_counts_len > 0 else 1 }}] = { {% if ngram_counts_values %}{{ ngram_counts_values | join(', ') }}{% else %}0{% endif %} };
    const int64_t ngram_indexes[{{ ngram_index_len if ngram_index_len > 0 else 1 }}] = { {% if ngram_indexes_values %}{{ ngram_indexes_values | join(', ') }}{% else %}0{% endif %} };
{% if weights_values %}
    const {{ c_type }} weights[{{ weights_values | length }}] = { {{ weights_values | join(', ') }} };
{% endif %}
    const idx_t output_dim = {{ output_dim }};
    const idx_t pool_size = {{ pool_size }};
    const idx_t ngram_counts_len = {{ ngram_counts_len }};
    const idx_t max_skip = {{ max_skip_count }};
    const idx_t min_gram = {{ min_gram_length }};
    const idx_t max_gram = {{ max_gram_length }};
{% if input_rank == 1 %}
    const idx_t batch = 1;
    const idx_t seq_len = {{ input_shape[0] }};
{% else %}
    const idx_t batch = {{ input_shape[0] }};
    const idx_t seq_len = {{ input_shape[1] }};
{% endif %}
    for (idx_t b = 0; b < batch; ++b) {
        for (idx_t o = 0; o < output_dim; ++o) {
            {{ output }}{{ row }}[o] = {{ zero_literal }};
        }
        {# Advance past the ngram_indexes entries that belong to lengths below min_gram. #}
        idx_t ngram_index_offset = 0;
        for (idx_t gram_len = 1; gram_len < min_gram && gram_len <= ngram_counts_len; ++gram_len) {
            const idx_t count_start = (idx_t)ngram_counts[gram_len - 1];
            const idx_t count_end =
                gram_len < ngram_counts_len ? (idx_t)ngram_counts[gram_len] : pool_size;
            const idx_t num_ngrams = (count_end - count_start) / gram_len;
            ngram_index_offset += num_ngrams;
        }
        for (idx_t gram_len = min_gram; gram_len <= max_gram && gram_len <= ngram_counts_len; ++gram_len) {
            const idx_t count_start = (idx_t)ngram_counts[gram_len - 1];
            const idx_t count_end =
                gram_len < ngram_counts_len ? (idx_t)ngram_counts[gram_len] : pool_size;
            const idx_t num_ngrams = (count_end - count_start) / gram_len;
            if (num_ngrams == 0) {
                continue;
            }
            {# Skips only apply to n-grams of length >= 2. #}
            const idx_t skip_limit = gram_len == 1 ? 0 : max_skip;
            for (idx_t skip = 0; skip <= skip_limit; ++skip) {
                const idx_t stride = skip + 1;
                {# The n-gram spans (gram_len - 1) * stride + 1 positions; skip if the row is shorter. #}
                if (seq_len < (gram_len - 1) * stride + 1) {
                    continue;
                }
                const idx_t max_start = seq_len - (gram_len - 1) * stride;
                for (idx_t start = 0; start < max_start; ++start) {
                    for (idx_t ngram_idx = 0; ngram_idx < num_ngrams; ++ngram_idx) {
                        const idx_t pool_offset = count_start + ngram_idx * gram_len;
                        int match = 1;
                        for (idx_t pos = 0; pos < gram_len; ++pos) {
                            const int64_t token = (int64_t){{ input0 }}{{ row }}[start + pos * stride];
                            if (token != pool[pool_offset + pos]) {
                                match = 0;
                                break;
                            }
                        }
                        if (match) {
                            const idx_t out_index =
                                (idx_t)ngram_indexes[ngram_index_offset + ngram_idx];
                            {{ output }}{{ row }}[out_index] += ({{ c_type }})1;
                        }
                    }
                }
            }
            ngram_index_offset += num_ngrams;
        }
{% if mode_id != 0 %}
        for (idx_t o = 0; o < output_dim; ++o) {
            {{ c_type }} value = {{ output }}{{ row }}[o];
{% if mode_id == 1 %}
            value = value > {{ zero_literal }} ? {{ one_literal }} : {{ zero_literal }};
{% endif %}
{% if weights_values %}
            value *= weights[o];
{% endif %}
            {{ output }}{{ row }}[o] = value;
        }
{% endif %}
    }
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
{#
  Tile: walk the output in nested-loop (row-major) order over output_shape and
  copy the input element selected by input_index_expr. output_index is a flat
  counter that advances in the same order as the loops, so the output is
  written sequentially.
  NOTE(review): input_index_expr is built by the emitter; presumably it maps
  each output coordinate to (coordinate % input dim) per axis — confirm there.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
    {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
    idx_t output_index = 0;
    {% for dim in output_shape %}
    for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
    {% endfor %}
    idx_t input_index = {{ input_index_expr }};
    output_data[output_index] = input_data[input_index];
    output_index++;
    {% for _ in output_shape %}
    }
    {% endfor %}
}
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
{#
  TopK: for every outer position, keep the k best elements along the reduced
  axis in an insertion-sorted buffer of size k, then write them out in rank
  order. <op_name>_better(a, ai, b, bi) returns non-zero when (a, ai) should
  rank ahead of (b, bi). Direction and index tie-breaking presumably live in
  compare_expr, which the emitter builds.
  When k == 0 both outputs are empty. The body is then omitted entirely:
  otherwise it would declare zero-length arrays (not valid ISO C) and touch
  best_values[k - 1], i.e. index -1.
#}
static inline int {{ op_name }}_better({{ input_c_type }} a, {{ output_indices_c_type }} ai, {{ input_c_type }} b, {{ output_indices_c_type }} bi) {
    return {{ compare_expr }};
}

static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
{% if k > 0 %}
    {% for dim in outer_shape %}
    for (idx_t {{ outer_loop_vars[loop.index0] }} = 0; {{ outer_loop_vars[loop.index0] }} < {{ dim }}; ++{{ outer_loop_vars[loop.index0] }}) {
    {% endfor %}
    {{ input_c_type }} best_values[{{ k }}];
    {{ output_indices_c_type }} best_indices[{{ k }}];
    {# Seed the buffer with the first k elements along the axis... #}
    for (idx_t {{ reduce_var }} = 0; {{ reduce_var }} < {{ k }}; ++{{ reduce_var }}) {
        best_values[{{ reduce_var }}] = {{ input0 }}{{ input_index_expr }};
        best_indices[{{ reduce_var }}] = ({{ output_indices_c_type }}){{ reduce_var }};
    }
    {# ...and insertion-sort them so best_values[0] is the best. #}
    for (idx_t i = 1; i < {{ k }}; ++i) {
        idx_t j = i;
        while (j > 0
               && {{ op_name }}_better(best_values[j], best_indices[j], best_values[j - 1], best_indices[j - 1])) {
            {{ input_c_type }} temp_value = best_values[j - 1];
            {{ output_indices_c_type }} temp_index = best_indices[j - 1];
            best_values[j - 1] = best_values[j];
            best_indices[j - 1] = best_indices[j];
            best_values[j] = temp_value;
            best_indices[j] = temp_index;
            --j;
        }
    }
    {# Remaining elements replace the current worst (slot k - 1) when better, then bubble into place. #}
    for (idx_t {{ reduce_var }} = {{ k }}; {{ reduce_var }} < {{ axis_dim }}; ++{{ reduce_var }}) {
        {{ input_c_type }} candidate = {{ input0 }}{{ input_index_expr }};
        {{ output_indices_c_type }} candidate_index = ({{ output_indices_c_type }}){{ reduce_var }};
        if ({{ op_name }}_better(candidate, candidate_index, best_values[{{ k - 1 }}], best_indices[{{ k - 1 }}])) {
            idx_t pos = {{ k - 1 }};
            while (pos > 0
                   && {{ op_name }}_better(candidate, candidate_index, best_values[pos - 1], best_indices[pos - 1])) {
                best_values[pos] = best_values[pos - 1];
                best_indices[pos] = best_indices[pos - 1];
                --pos;
            }
            best_values[pos] = candidate;
            best_indices[pos] = candidate_index;
        }
    }
    for (idx_t {{ k_var }} = 0; {{ k_var }} < {{ k }}; ++{{ k_var }}) {
        {{ output_values }}{{ output_index_expr }} = best_values[{{ k_var }}];
        {{ output_indices }}{{ output_index_expr }} = best_indices[{{ k_var }}];
    }
    {% for _ in outer_shape %}
    }
    {% endfor %}
{% endif %}
}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
{#
  Transpose: loop_vars iterate the output shape, and each output element is
  read from input0 subscripted by input_indices. Those subscripts are loop
  variable names in the order the input's axes need them — presumably the
  permuted loop_vars built by the emitter; confirm there.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {% for dim in output_shape %}
    for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
    {% endfor %}
    {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for var in input_indices %}[{{ var }}]{% endfor %};
    {% for _ in output_shape %}
    }
    {% endfor %}
}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
{#
  Trilu: the data is addressed as [batch_size][rows][cols] in row-major order
  (offset = batch * rows * cols + row * cols + col). Kept elements are copied
  from the input; all others are set to zero_literal.
  - upper: keep when col - row >= k.
  - lower: keep when row - col >= -k, i.e. col - row <= k.
  k starts as k_value and is replaced by k_input[0] when the k input is wired.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    const {{ c_type }} *input_data = (const {{ c_type }} *){{ input0 }};
    {{ c_type }} *output_data = ({{ c_type }} *){{ output }};
    int64_t k = {{ k_value }};
    {% if k_input %}
    const {{ k_c_type }} *k_data = (const {{ k_c_type }} *){{ k_input }};
    k = (int64_t)k_data[0];
    {% endif %}
    idx_t rows = {{ rows }};
    idx_t cols = {{ cols }};
    idx_t batch_size = {{ batch_size }};
    for (idx_t batch = 0; batch < batch_size; ++batch) {
        idx_t base = batch * rows * cols;
        for (idx_t row = 0; row < rows; ++row) {
            for (idx_t col = 0; col < cols; ++col) {
                idx_t offset = base + row * cols + col;
                {# Signed casts so the diagonal distance can be negative. #}
                {% if upper %}
                if ((int64_t)col - (int64_t)row >= k) {
                    output_data[offset] = input_data[offset];
                } else {
                    output_data[offset] = {{ zero_literal }};
                }
                {% else %}
                if ((int64_t)row - (int64_t)col >= -k) {
                    output_data[offset] = input_data[offset];
                } else {
                    output_data[offset] = {{ zero_literal }};
                }
                {% endif %}
            }
        }
    }
}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
{#
  Elementwise unary op over `shape`, dispatched on `operator` when the
  template is rendered:
  - relu:     x < 0 ? 0 : x. Written this way so NaN propagates (NaN < 0 is
              false), matching np.maximum(x, 0). The previous form
              x > 0 ? x : 0 silently turned NaN into 0. A -0.0 input now
              yields -0.0 instead of +0.0; the two compare equal.
  - neg:      -x
  - identity: x
  - zero:     zero_literal
  - isneginf / isposinf: isinf(x) combined with signbit(x); needs <math.h>.
  - otherwise `operator` is emitted verbatim as a C function/macro name and
    applied to x.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {% for dim in shape %}
    for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
    {% endfor %}
    {% if operator == "relu" %}
    {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %} < {{ zero_literal }} ? {{ zero_literal }} : {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
    {% elif operator == "neg" %}
    {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = -{{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
    {% elif operator == "identity" %}
    {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %};
    {% elif operator == "zero" %}
    {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ zero_literal }};
    {% elif operator == "isneginf" %}
    {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = isinf({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %}) && signbit({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %});
    {% elif operator == "isposinf" %}
    {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = isinf({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %}) && !signbit({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %});
    {% else %}
    {{ output }}{% for var in loop_vars %}[{{ var }}]{% endfor %} = {{ operator }}({{ input0 }}{% for var in loop_vars %}[{{ var }}]{% endfor %});
    {% endif %}
    {% for _ in shape %}
    }
    {% endfor %}
}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
{#
  Where: elementwise select over output_shape, output = condition ? x : y.
  output_expr, condition_expr, x_expr and y_expr are complete subscripted
  expressions produced by the emitter. Any broadcasting is presumably already
  folded into their indices — confirm in the emitter.
#}
static inline void {{ op_name }}({{ dim_args }}{{ params | join(', ') }}) {
    {% for dim in output_shape %}
    for (idx_t {{ loop_vars[loop.index0] }} = 0; {{ loop_vars[loop.index0] }} < {{ dim }}; ++{{ loop_vars[loop.index0] }}) {
    {% endfor %}
    {{ output_expr }} = {{ condition_expr }} ? {{ x_expr }} : {{ y_expr }};
    {% for _ in output_shape %}
    }
    {% endfor %}
}
|
emx_onnx_cgen/verification.py
CHANGED
|
@@ -1,29 +1,69 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from typing import TypeAlias
|
|
4
|
+
|
|
3
5
|
import numpy as np
|
|
4
6
|
|
|
5
7
|
from shared.ulp import ulp_intdiff_float
|
|
6
8
|
|
|
9
|
+
WorstUlpDiff: TypeAlias = tuple[tuple[int, ...], float, float]
|
|
10
|
+
|
|
7
11
|
|
|
8
|
-
def _validate_ulp_inputs(
    actual: np.ndarray, expected: np.ndarray
) -> np.dtype | None:
    """Check that two arrays are comparable in ULPs and pick the dtype to use.

    The comparison dtype always comes from ``expected``.

    Args:
        actual: Array produced by the code under test.
        expected: Reference array; its dtype drives the decision.

    Returns:
        ``None`` when ``expected`` is not a floating dtype (there is nothing to
        measure in ULPs); otherwise ``expected.dtype``, which is one of
        float16, float32 or float64.

    Raises:
        ValueError: If the shapes differ, or if ``expected`` has a floating
            dtype other than float16/float32/float64.
    """
    same_shape = actual.shape == expected.shape
    if not same_shape:
        raise ValueError(
            f"Shape mismatch for ULP calculation: {actual.shape} vs {expected.shape}"
        )
    reference_dtype = expected.dtype
    if not np.issubdtype(reference_dtype, np.floating):
        return None
    if reference_dtype in (np.float16, np.float32, np.float64):
        return reference_dtype
    raise ValueError(
        f"Unsupported floating dtype for ULP calculation: {reference_dtype}"
    )
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def worst_ulp_diff(
    actual: np.ndarray,
    expected: np.ndarray,
    *,
    atol_eps: float = 1.0,
) -> tuple[int, WorstUlpDiff | None]:
    """Return the largest ULP distance between two arrays and where it occurs.

    Both arrays are cast to ``expected``'s floating dtype. Elements whose
    absolute difference is at most ``eps(dtype) * atol_eps`` are treated as
    equal and skipped. All other elements are compared with
    ``ulp_intdiff_float``.

    Args:
        actual: Values produced by the code under test.
        expected: Reference values; their dtype selects the ULP precision.
        atol_eps: Absolute tolerance, in multiples of the machine epsilon of
            ``expected.dtype``.

    Returns:
        ``(max_diff, worst)``. ``worst`` is ``(index, actual, expected)`` for
        the first element in row-major order that reaches ``max_diff``. It is
        ``None`` when nothing exceeds the tolerance with a non-zero ULP
        distance, or when the inputs are non-floating or empty.

    Raises:
        ValueError: On shape mismatch or an unsupported floating dtype
            (propagated from ``_validate_ulp_inputs``).
    """
    dtype = _validate_ulp_inputs(actual, expected)
    if dtype is None or actual.size == 0:
        return 0, None
    actual_cast = actual.astype(dtype, copy=False)
    expected_cast = expected.astype(dtype, copy=False)
    # Plain Python float, so the tolerance test below is done in float64.
    abs_tol = float(np.finfo(dtype).eps) * atol_eps

    # Vectorized pre-filter, computed in float64 like the scalar
    # ``float(a) - float(e)`` it replaces. inf - inf and float64 overflow
    # produce NaN/inf, which are expected here and must not emit
    # RuntimeWarnings. NaN never compares <= abs_tol, so it stays "outside".
    with np.errstate(invalid="ignore", over="ignore"):
        abs_diff = np.abs(
            actual_cast.astype(np.float64) - expected_cast.astype(np.float64)
        )
    outside = ~(abs_diff <= abs_tol)

    max_diff = 0
    worst: WorstUlpDiff | None = None
    # np.argwhere yields indices in row-major order independent of memory
    # layout (np.nditer's default order='K' follows memory order), so ties
    # resolve deterministically. It also handles 0-d arrays (index == ()).
    for raw_index in np.argwhere(outside):
        index = tuple(int(i) for i in raw_index)
        actual_value = actual_cast[index]
        expected_value = expected_cast[index]
        diff = ulp_intdiff_float(actual_value, expected_value)
        if diff > max_diff:
            max_diff = diff
            worst = (index, float(actual_value), float(expected_value))
    return max_diff, worst
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def max_ulp_diff(
    actual: np.ndarray, expected: np.ndarray, *, atol_eps: float = 1.0
) -> int:
    """Return only the maximum ULP distance between ``actual`` and ``expected``.

    Convenience wrapper around ``worst_ulp_diff`` that drops the location and
    values of the worst element. ``atol_eps`` is forwarded unchanged.
    """
    result = worst_ulp_diff(actual, expected, atol_eps=atol_eps)
    return result[0]
|
|
28
68
|
|
|
29
69
|
|