numkong 7.5.0 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/binding.gyp +18 -0
- package/c/dispatch_e5m2.c +23 -3
- package/include/numkong/capabilities.h +1 -1
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +434 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +23 -8
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots.h +12 -0
- package/include/numkong/each/serial.h +18 -1
- package/include/numkong/geospatial/serial.h +14 -3
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +204 -162
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +128 -0
- package/include/numkong/spatials/serial.h +18 -1
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials.h +17 -0
- package/include/numkong/tensor.hpp +107 -23
- package/javascript/numkong.c +3 -2
- package/package.json +7 -7
- package/wasm/numkong.wasm +0 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Point Cloud Alignment for Genoa (AVX-512-BF16).
|
|
3
|
+
* @file include/numkong/mesh/genoa.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 28, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/mesh.h
|
|
8
|
+
*
|
|
9
|
+
* @section genoa_mesh_instructions Key AVX-512 BF16 Mesh Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Genoa Sapphire
|
|
12
|
+
* _mm512_dpbf16_ps VDPBF16PS (ZMM, ZMM, ZMM) 6cy @ p01 6cy @ p05
|
|
13
|
+
* _mm512_permutexvar_epi16 VPERMW (ZMM, ZMM, ZMM) 3cy @ p5 6cy @ p5
|
|
14
|
+
* _mm512_maskz_loadu_epi16 VMOVDQU16 (ZMM{k}, M) 9cy @ L1 9cy @ L1
|
|
15
|
+
*
|
|
16
|
+
* The bf16 mesh kernels use a 15-lane channel-grouped layout: 10 xyz triplets per ZMM (30 bf16
|
|
17
|
+
* values laid out as [x0..x9, y0..y9, z0..z9, _, _] after a single VPERMW). That maps cleanly
|
|
18
|
+
* onto VDPBF16PS, which pairs adjacent bf16 values per fp32 lane; 5 channel-consecutive pairs
|
|
19
|
+
* give a single H-cell per lane-range. Three product accumulators (a*b, a*rot1(b), a*rot2(b))
|
|
20
|
+
* cover the 9 cross-covariance cells, matching the Skylake structure.
|
|
21
|
+
*/
|
|
22
|
+
#ifndef NK_MESH_GENOA_H
|
|
23
|
+
#define NK_MESH_GENOA_H
|
|
24
|
+
|
|
25
|
+
#if NK_TARGET_X8664_
|
|
26
|
+
#if NK_TARGET_GENOA
|
|
27
|
+
|
|
28
|
+
#include "numkong/types.h"
|
|
29
|
+
#include "numkong/mesh/serial.h"
|
|
30
|
+
#include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`
|
|
31
|
+
|
|
32
|
+
#if defined(__cplusplus)
|
|
33
|
+
extern "C" {
|
|
34
|
+
#endif
|
|
35
|
+
|
|
36
|
+
#if defined(__clang__)
|
|
37
|
+
#pragma clang attribute push( \
|
|
38
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512bf16,f16c,fma,bmi,bmi2"))), \
|
|
39
|
+
apply_to = function)
|
|
40
|
+
#elif defined(__GNUC__)
|
|
41
|
+
#pragma GCC push_options
|
|
42
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512bf16", "f16c", "fma", "bmi", "bmi2")
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
/**
 * @brief RMSD between two bf16 point clouds without centering or alignment, on Genoa.
 *
 * Reports the identity rotation, unit scale and zero centroids (for every non-NULL output pointer) and writes
 * `sqrt(Σᵢ ‖aᵢ − bᵢ‖² / n)` into `*result`.
 *
 * The squared differences are formed directly in fp32 instead of through the folded identity
 * `Σa² + Σb² − 2·Σab`. For near-identical clouds, which is the common RMSD use case, the folded form
 * subtracts nearly equal large fp32 sums and loses most of its significant digits.
 * - Widening bf16 to fp32 is exact, because it is just a 16-bit left shift.
 * - The fp32 difference of two bf16 values is exact for operands of comparable magnitude.
 *
 * @param a, b        Interleaved xyz coordinates, `3 * n` bf16 values each.
 * @param n           Number of points. With `n == 0` the final division is `0 / 0`.
 * @param a_centroid  Optional output (3 floats), set to zeros.
 * @param b_centroid  Optional output (3 floats), set to zeros.
 * @param rotation    Optional output (9 floats), set to the identity.
 * @param scale       Optional output, set to 1.
 * @param result      Output RMSD.
 */
NK_PUBLIC void nk_rmsd_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    if (rotation)
        rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
        rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
    if (scale) *scale = 1.0f;
    if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
    if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;

    // Unaligned RMSD does not care which channel a value belongs to, so both clouds are walked as flat
    // arrays of `3 * n` bf16 values, 16 values per step.
    nk_size_t const count_values = n * 3;
    __m512 sum_squared_f32x16 = _mm512_setzero_ps();
    for (nk_size_t index = 0; index < count_values; index += 16) {
        nk_size_t const remaining = count_values - index;
        // Masked-off lanes load as zero in both clouds, so they contribute a zero difference.
        __mmask16 const load_mask = remaining >= 16 ? (__mmask16)0xFFFF
                                                    : (__mmask16)_bzhi_u32(0xFFFFu, (nk_u32_t)remaining);
        __m256i const a_bf16x16 = _mm256_maskz_loadu_epi16(load_mask, a + index);
        __m256i const b_bf16x16 = _mm256_maskz_loadu_epi16(load_mask, b + index);
        // bf16 -> fp32 is exact: zero-extend each 16-bit value and shift it into the upper half of the lane.
        __m512 const a_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a_bf16x16), 16));
        __m512 const b_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(b_bf16x16), 16));
        __m512 const difference_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
        sum_squared_f32x16 = _mm512_fmadd_ps(difference_f32x16, difference_f32x16, sum_squared_f32x16);
    }

    // A sum of squares cannot go negative, so no clamp is needed before the square root.
    nk_f32_t const sum_squared = _mm512_reduce_add_ps(sum_squared_f32x16);
    *result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
}
|
|
95
|
+
|
|
96
|
+
// Channel-grouping permute: 10 xyz triplets + 2 padding bf16 → [x0..x9, y0..y9, z0..z9, _, _].
// After VPERMW, fp32 lanes 0..4 carry the x-channel (2 bf16 per fp32 lane), 5..9 carry y, and 10..14 carry z.
//
// How to read these tables:
// - `_mm512_set_epi16` takes its arguments from the highest element (e31) down to the lowest (e0), so each
//   list reads right-to-left.
// - `_mm512_permutexvar_epi16` sets output element i to input element idx[i].
// - In the interleaved input, element 3k is x_k, 3k+1 is y_k and 3k+2 is z_k.
// - Elements 30 and 31 map to themselves, i.e. to the padding that the masked loads in this file leave zeroed.
// Here: output 0..9 <- inputs 0, 3, ..., 27; 10..19 <- 1, 4, ..., 28; 20..29 <- 2, 5, ..., 29.
#define NK_MESH_GENOA_CHANNEL_GROUP_INDICES_ \
    _mm512_set_epi16(31, 30, 29, 26, 23, 20, 17, 14, 11, 8, 5, 2, 28, 25, 22, 19, 16, 13, 10, 7, 4, 1, 27, 24, 21, 18, \
                     15, 12, 9, 6, 3, 0)

// Rotation-1 applied during channel-grouping: each channel slot carries the *next* channel of b.
// x-slot gets b.y, y-slot gets b.z, z-slot gets b.x. Pairs covariance cells (xy, yz, zx).
// Here: output 0..9 <- inputs 1, 4, ..., 28; 10..19 <- 2, 5, ..., 29; 20..29 <- 0, 3, ..., 27.
#define NK_MESH_GENOA_ROTATION_1_INDICES_ \
    _mm512_set_epi16(31, 30, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0, 29, 26, 23, 20, 17, 14, 11, 8, 5, 2, 28, 25, 22, 19, \
                     16, 13, 10, 7, 4, 1)

// Rotation-2: x-slot gets b.z, y-slot gets b.x, z-slot gets b.y. Pairs covariance cells (xz, yx, zy).
// Here: output 0..9 <- inputs 2, 5, ..., 29; 10..19 <- 0, 3, ..., 27; 20..29 <- 1, 4, ..., 28.
#define NK_MESH_GENOA_ROTATION_2_INDICES_ \
    _mm512_set_epi16(31, 30, 28, 25, 22, 19, 16, 13, 10, 7, 4, 1, 27, 24, 21, 18, 15, 12, 9, 6, 3, 0, 29, 26, 23, 20, \
                     17, 14, 11, 8, 5, 2)
|
|
112
|
+
|
|
113
|
+
/**
 * @brief Kabsch alignment of two bf16 point clouds on Genoa (AVX-512 BF16).
 *
 * Accumulation: VDPBF16PS accumulates, in the 15-lane channel-grouped layout:
 * - per-channel sums of both clouds;
 * - their squared norms;
 * - all nine raw cross-covariance cells.
 *
 * The cells are then centered with the parallel-axis identity. The rotation is chosen as follows:
 * - Identity, when the off-diagonal energy is below 1e-12 of the diagonal energy and all three diagonal cells
 *   are positive.
 * - Otherwise, a 3x3 SVD, re-run with the sign of V's entries 2/5/8 flipped when the determinant comes out
 *   negative.
 *
 * The post-alignment RMSD uses the folded identity ‖a−ā‖² + ‖b−b̄‖² − 2·trace(R·H).
 *
 * @param a, b        Interleaved xyz coordinates, `3 * n` bf16 values each.
 * @param n           Number of points.
 * @param a_centroid  Optional output (3 floats): mean of `a`.
 * @param b_centroid  Optional output (3 floats): mean of `b`.
 * @param rotation    Optional output (9 floats): the chosen rotation.
 * @param scale       Optional output: always 1 for Kabsch.
 * @param result      Output RMSD after alignment.
 */
NK_PUBLIC void nk_kabsch_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                    nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    __m512i const group_perm = NK_MESH_GENOA_CHANNEL_GROUP_INDICES_;
    __m512i const shift_1_perm = NK_MESH_GENOA_ROTATION_1_INDICES_;
    __m512i const shift_2_perm = NK_MESH_GENOA_ROTATION_2_INDICES_;
    __m512bh const one_bf16x32 = nk_m512bh_from_m512i_(_mm512_set1_epi16(0x3F80)); // bf16 bit pattern of 1.0

    __m512 moment_a = _mm512_setzero_ps(), moment_b = _mm512_setzero_ps();
    __m512 energy_a = _mm512_setzero_ps(), energy_b = _mm512_setzero_ps();
    __m512 cross_same = _mm512_setzero_ps(), cross_next = _mm512_setzero_ps(), cross_prev = _mm512_setzero_ps();

    // One chunk = 10 xyz triplets = 30 bf16. The final partial chunk loads only `3 * remaining` values; the
    // masked-off bf16 elements read as zero and therefore add nothing to any accumulator.
    for (nk_size_t first = 0; first < n; first += 10) {
        nk_size_t const remaining = n - first;
        __mmask32 const load_mask = remaining >= 10 ? (__mmask32)0x3FFFFFFF
                                                    : (__mmask32)_bzhi_u32(0x3FFFFFFF, (nk_u32_t)(remaining * 3));
        __m512i const a_chunk = _mm512_maskz_loadu_epi16(load_mask, (__m512i const *)(a + first * 3));
        __m512i const b_chunk = _mm512_maskz_loadu_epi16(load_mask, (__m512i const *)(b + first * 3));
        __m512bh const a_xyz = nk_m512bh_from_m512i_(_mm512_permutexvar_epi16(group_perm, a_chunk));
        __m512bh const b_xyz = nk_m512bh_from_m512i_(_mm512_permutexvar_epi16(group_perm, b_chunk));
        __m512bh const b_yzx = nk_m512bh_from_m512i_(_mm512_permutexvar_epi16(shift_1_perm, b_chunk));
        __m512bh const b_zxy = nk_m512bh_from_m512i_(_mm512_permutexvar_epi16(shift_2_perm, b_chunk));

        moment_a = _mm512_dpbf16_ps(moment_a, a_xyz, one_bf16x32);
        moment_b = _mm512_dpbf16_ps(moment_b, b_xyz, one_bf16x32);
        energy_a = _mm512_dpbf16_ps(energy_a, a_xyz, a_xyz);
        energy_b = _mm512_dpbf16_ps(energy_b, b_xyz, b_xyz);
        cross_same = _mm512_dpbf16_ps(cross_same, a_xyz, b_xyz);
        cross_next = _mm512_dpbf16_ps(cross_next, a_xyz, b_yzx);
        cross_prev = _mm512_dpbf16_ps(cross_prev, a_xyz, b_zxy);
    }

    // fp32 lanes 0..4 hold the x channel, 5..9 y, 10..14 z; lane 15 only ever accumulates zero padding.
    __mmask16 const channel_lanes[3] = {0x001F, 0x03E0, 0x7C00};
    nk_f32_t const count = (nk_f32_t)n;
    nk_f32_t const inv_n = 1.0f / count;

    nk_f32_t mean_a[3], mean_b[3], h[9];
    for (int row = 0; row < 3; ++row) {
        mean_a[row] = _mm512_mask_reduce_add_ps(channel_lanes[row], moment_a) * inv_n;
        mean_b[row] = _mm512_mask_reduce_add_ps(channel_lanes[row], moment_b) * inv_n;
        // `cross_next` pairs channel `row` of a with channel `row + 1` of b, `cross_prev` with `row + 2`.
        h[row * 3 + row] = _mm512_mask_reduce_add_ps(channel_lanes[row], cross_same);
        h[row * 3 + (row + 1) % 3] = _mm512_mask_reduce_add_ps(channel_lanes[row], cross_next);
        h[row * 3 + (row + 2) % 3] = _mm512_mask_reduce_add_ps(channel_lanes[row], cross_prev);
    }
    if (a_centroid) a_centroid[0] = mean_a[0], a_centroid[1] = mean_a[1], a_centroid[2] = mean_a[2];
    if (b_centroid) b_centroid[0] = mean_b[0], b_centroid[1] = mean_b[1], b_centroid[2] = mean_b[2];

    // Parallel-axis correction: H[i][j] = Σ a_i·b_j − n·ā_i·b̄_j.
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col) h[row * 3 + col] = h[row * 3 + col] - count * mean_a[row] * mean_b[col];

    nk_f32_t const energy_a_total = _mm512_reduce_add_ps(energy_a);
    nk_f32_t const energy_b_total = _mm512_reduce_add_ps(energy_b);
    nk_f32_t spread_a = energy_a_total -
                        count * (mean_a[0] * mean_a[0] + mean_a[1] * mean_a[1] + mean_a[2] * mean_a[2]);
    nk_f32_t spread_b = energy_b_total -
                        count * (mean_b[0] * mean_b[0] + mean_b[1] * mean_b[1] + mean_b[2] * mean_b[2]);
    spread_a = spread_a < 0.0f ? 0.0f : spread_a;
    spread_b = spread_b < 0.0f ? 0.0f : spread_b;

    // Near-diagonal covariance with a positive diagonal: keep the identity and skip the SVD.
    nk_f32_t const diagonal_energy = h[0] * h[0] + h[4] * h[4] + h[8] * h[8];
    nk_f32_t const off_diagonal_energy =
        h[1] * h[1] + h[2] * h[2] + h[3] * h[3] + h[5] * h[5] + h[6] * h[6] + h[7] * h[7];
    int const already_aligned = off_diagonal_energy < 1e-12f * diagonal_energy && h[0] > 0.0f && h[4] > 0.0f &&
                                h[8] > 0.0f;

    nk_f32_t r[9] = {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f};
    nk_f32_t trace_rh = h[0] + h[4] + h[8];
    if (!already_aligned) {
        nk_f32_t u[9], s[9], v[9];
        nk_svd3x3_f32_(h, u, s, v);
        nk_rotation_from_svd_f32_serial_(u, v, r);
        if (nk_det3x3_f32_(r) < 0) {
            // Reflection: negate entries 2, 5, 8 of V (its third column in row-major storage) and rebuild R.
            v[2] = -v[2], v[5] = -v[5], v[8] = -v[8];
            nk_rotation_from_svd_f32_serial_(u, v, r);
        }
        trace_rh = r[0] * h[0] + r[1] * h[3] + r[2] * h[6] + r[3] * h[1] + r[4] * h[4] + r[5] * h[7] +
                   r[6] * h[2] + r[7] * h[5] + r[8] * h[8];
    }

    if (rotation)
        for (int k = 0; k < 9; ++k) rotation[k] = r[k];
    if (scale) *scale = 1.0f;

    // Folded SSD: ‖a − ā‖² + ‖b − b̄‖² − 2·trace(R·H_centered), clamped against rounding below zero.
    nk_f32_t const residual = spread_a + spread_b - 2.0f * trace_rh;
    *result = nk_f32_sqrt_haswell((residual < 0.0f ? 0.0f : residual) * inv_n);
}
|
|
276
|
+
|
|
277
|
+
/**
 * @brief Umeyama (similarity) alignment of two bf16 point clouds on Genoa (AVX-512 BF16).
 *
 * Accumulation and centering are the same as in the Kabsch kernel: channel-grouped VDPBF16PS sums, norms and
 * cross-covariance cells, followed by the parallel-axis correction.
 *
 * Rotation and scale numerator:
 * - Near-diagonal case (off-diagonal energy below 1e-12 of the diagonal energy, positive diagonal): R = I and
 *   the numerator is trace(H).
 * - Otherwise: R comes from a 3x3 SVD, and the numerator is tr(D·S), with D = diag(1, 1, sign(det R)) and the
 *   singular values read from entries 0/4/8 of the SVD's diagonal buffer. The sign of V's entries 2/5/8 is
 *   flipped when det R < 0.
 *
 * Scale is c = numerator / ‖a−ā‖², or 0 when ‖a−ā‖² is 0. The RMSD uses c²·‖a−ā‖² + ‖b−b̄‖² − 2c·trace(R·H).
 *
 * @param a, b        Interleaved xyz coordinates, `3 * n` bf16 values each.
 * @param n           Number of points.
 * @param a_centroid  Optional output (3 floats): mean of `a`.
 * @param b_centroid  Optional output (3 floats): mean of `b`.
 * @param rotation    Optional output (9 floats): the chosen rotation.
 * @param scale       Optional output: the scale factor c.
 * @param result      Output RMSD after the similarity alignment.
 */
NK_PUBLIC void nk_umeyama_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                     nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    __m512i const group_perm = NK_MESH_GENOA_CHANNEL_GROUP_INDICES_;
    __m512i const shift_1_perm = NK_MESH_GENOA_ROTATION_1_INDICES_;
    __m512i const shift_2_perm = NK_MESH_GENOA_ROTATION_2_INDICES_;
    __m512bh const one_bf16x32 = nk_m512bh_from_m512i_(_mm512_set1_epi16(0x3F80)); // bf16 bit pattern of 1.0

    __m512 moment_a = _mm512_setzero_ps(), moment_b = _mm512_setzero_ps();
    __m512 energy_a = _mm512_setzero_ps(), energy_b = _mm512_setzero_ps();
    __m512 cross_same = _mm512_setzero_ps(), cross_next = _mm512_setzero_ps(), cross_prev = _mm512_setzero_ps();

    // One chunk = 10 xyz triplets = 30 bf16. The final partial chunk loads only `3 * remaining` values; the
    // masked-off bf16 elements read as zero and therefore add nothing to any accumulator.
    for (nk_size_t first = 0; first < n; first += 10) {
        nk_size_t const remaining = n - first;
        __mmask32 const load_mask = remaining >= 10 ? (__mmask32)0x3FFFFFFF
                                                    : (__mmask32)_bzhi_u32(0x3FFFFFFF, (nk_u32_t)(remaining * 3));
        __m512i const a_chunk = _mm512_maskz_loadu_epi16(load_mask, (__m512i const *)(a + first * 3));
        __m512i const b_chunk = _mm512_maskz_loadu_epi16(load_mask, (__m512i const *)(b + first * 3));
        __m512bh const a_xyz = nk_m512bh_from_m512i_(_mm512_permutexvar_epi16(group_perm, a_chunk));
        __m512bh const b_xyz = nk_m512bh_from_m512i_(_mm512_permutexvar_epi16(group_perm, b_chunk));
        __m512bh const b_yzx = nk_m512bh_from_m512i_(_mm512_permutexvar_epi16(shift_1_perm, b_chunk));
        __m512bh const b_zxy = nk_m512bh_from_m512i_(_mm512_permutexvar_epi16(shift_2_perm, b_chunk));

        moment_a = _mm512_dpbf16_ps(moment_a, a_xyz, one_bf16x32);
        moment_b = _mm512_dpbf16_ps(moment_b, b_xyz, one_bf16x32);
        energy_a = _mm512_dpbf16_ps(energy_a, a_xyz, a_xyz);
        energy_b = _mm512_dpbf16_ps(energy_b, b_xyz, b_xyz);
        cross_same = _mm512_dpbf16_ps(cross_same, a_xyz, b_xyz);
        cross_next = _mm512_dpbf16_ps(cross_next, a_xyz, b_yzx);
        cross_prev = _mm512_dpbf16_ps(cross_prev, a_xyz, b_zxy);
    }

    // fp32 lanes 0..4 hold the x channel, 5..9 y, 10..14 z; lane 15 only ever accumulates zero padding.
    __mmask16 const channel_lanes[3] = {0x001F, 0x03E0, 0x7C00};
    nk_f32_t const count = (nk_f32_t)n;
    nk_f32_t const inv_n = 1.0f / count;

    nk_f32_t mean_a[3], mean_b[3], h[9];
    for (int row = 0; row < 3; ++row) {
        mean_a[row] = _mm512_mask_reduce_add_ps(channel_lanes[row], moment_a) * inv_n;
        mean_b[row] = _mm512_mask_reduce_add_ps(channel_lanes[row], moment_b) * inv_n;
        // `cross_next` pairs channel `row` of a with channel `row + 1` of b, `cross_prev` with `row + 2`.
        h[row * 3 + row] = _mm512_mask_reduce_add_ps(channel_lanes[row], cross_same);
        h[row * 3 + (row + 1) % 3] = _mm512_mask_reduce_add_ps(channel_lanes[row], cross_next);
        h[row * 3 + (row + 2) % 3] = _mm512_mask_reduce_add_ps(channel_lanes[row], cross_prev);
    }
    if (a_centroid) a_centroid[0] = mean_a[0], a_centroid[1] = mean_a[1], a_centroid[2] = mean_a[2];
    if (b_centroid) b_centroid[0] = mean_b[0], b_centroid[1] = mean_b[1], b_centroid[2] = mean_b[2];

    // Parallel-axis correction: H[i][j] = Σ a_i·b_j − n·ā_i·b̄_j.
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col) h[row * 3 + col] = h[row * 3 + col] - count * mean_a[row] * mean_b[col];

    nk_f32_t const energy_a_total = _mm512_reduce_add_ps(energy_a);
    nk_f32_t const energy_b_total = _mm512_reduce_add_ps(energy_b);
    nk_f32_t spread_a = energy_a_total -
                        count * (mean_a[0] * mean_a[0] + mean_a[1] * mean_a[1] + mean_a[2] * mean_a[2]);
    nk_f32_t spread_b = energy_b_total -
                        count * (mean_b[0] * mean_b[0] + mean_b[1] * mean_b[1] + mean_b[2] * mean_b[2]);
    spread_a = spread_a < 0.0f ? 0.0f : spread_a;
    spread_b = spread_b < 0.0f ? 0.0f : spread_b;

    // Near-diagonal covariance with a positive diagonal: keep the identity and skip the SVD.
    nk_f32_t const diagonal_energy = h[0] * h[0] + h[4] * h[4] + h[8] * h[8];
    nk_f32_t const off_diagonal_energy =
        h[1] * h[1] + h[2] * h[2] + h[3] * h[3] + h[5] * h[5] + h[6] * h[6] + h[7] * h[7];
    int const already_aligned = off_diagonal_energy < 1e-12f * diagonal_energy && h[0] > 0.0f && h[4] > 0.0f &&
                                h[8] > 0.0f;

    nk_f32_t r[9] = {1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f};
    nk_f32_t trace_rh = h[0] + h[4] + h[8];
    nk_f32_t scale_numerator = trace_rh; // near-diagonal case: trace(H) is used directly
    if (!already_aligned) {
        nk_f32_t u[9], s[9], v[9];
        nk_svd3x3_f32_(h, u, s, v);
        nk_rotation_from_svd_f32_serial_(u, v, r);

        // tr(D·S) with D = diag(1, 1, d3): the reflection sign only weights the last singular value.
        nk_f32_t const det = nk_det3x3_f32_(r);
        nk_f32_t const d3 = det < 0.0f ? -1.0f : 1.0f;
        scale_numerator = nk_sum_three_products_f32_(s[0], 1.0f, s[4], 1.0f, s[8], d3);

        if (det < 0.0f) {
            // Reflection: negate entries 2, 5, 8 of V (its third column in row-major storage) and rebuild R.
            v[2] = -v[2], v[5] = -v[5], v[8] = -v[8];
            nk_rotation_from_svd_f32_serial_(u, v, r);
        }
        trace_rh = r[0] * h[0] + r[1] * h[3] + r[2] * h[6] + r[3] * h[1] + r[4] * h[4] + r[5] * h[7] +
                   r[6] * h[2] + r[7] * h[5] + r[8] * h[8];
    }
    nk_f32_t const scale_factor = spread_a > 0.0f ? scale_numerator / spread_a : 0.0f;

    if (scale) *scale = scale_factor;
    if (rotation)
        for (int k = 0; k < 9; ++k) rotation[k] = r[k];

    // Folded SSD with scale: c²·‖a − ā‖² + ‖b − b̄‖² − 2c·trace(R·H_centered), clamped against rounding.
    nk_f32_t const residual = scale_factor * scale_factor * spread_a + spread_b -
                              2.0f * scale_factor * trace_rh;
    *result = nk_f32_sqrt_haswell((residual < 0.0f ? 0.0f : residual) * inv_n);
}
|
|
449
|
+
|
|
450
|
+
#if defined(__clang__)
|
|
451
|
+
#pragma clang attribute pop
|
|
452
|
+
#elif defined(__GNUC__)
|
|
453
|
+
#pragma GCC pop_options
|
|
454
|
+
#endif
|
|
455
|
+
|
|
456
|
+
#if defined(__cplusplus)
|
|
457
|
+
} // extern "C"
|
|
458
|
+
#endif
|
|
459
|
+
|
|
460
|
+
#endif // NK_TARGET_GENOA
|
|
461
|
+
#endif // NK_TARGET_X8664_
|
|
462
|
+
#endif // NK_MESH_GENOA_H
|