numkong 7.5.0 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/binding.gyp +18 -0
- package/c/dispatch_e5m2.c +23 -3
- package/include/numkong/capabilities.h +1 -1
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +434 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +23 -8
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots.h +12 -0
- package/include/numkong/each/serial.h +18 -1
- package/include/numkong/geospatial/serial.h +14 -3
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +204 -162
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +128 -0
- package/include/numkong/spatials/serial.h +18 -1
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials.h +17 -0
- package/include/numkong/tensor.hpp +107 -23
- package/javascript/numkong.c +3 -2
- package/package.json +7 -7
- package/wasm/numkong.wasm +0 -0
|
@@ -8,24 +8,28 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section mesh_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic Instruction
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vld3q_u16 LD3 (V.8H x 3) 4cy @ 1p 4cy @ 1p
|
|
13
|
+
* vld3_u16 LD3 (V.4H x 3) 4cy @ 1p 4cy @ 1p
|
|
14
|
+
* vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy @ 2p 2cy @ 1p
|
|
15
|
+
* vshll_n_u16 USHLL (V.4S, V.4H, #16) 2cy @ 2p 2cy @ 4p
|
|
16
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
17
|
+
* vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
18
|
+
* vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
19
|
+
* vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
20
|
+
* vdupq_n_f32 DUP (V.4S, scalar) 2cy @ 2p 2cy @ 4p
|
|
21
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 5cy @ 1p 8cy @ 1p
|
|
20
22
|
*
|
|
21
23
|
* The ARMv8.6-BF16 extension enables BF16 storage with F32 computation for 3D mesh alignment
|
|
22
24
|
* operations. BF16's wider exponent range (matching F32) prevents overflow in geometric calculations
|
|
23
25
|
* while halving memory bandwidth compared to F32.
|
|
24
26
|
*
|
|
25
|
-
* For point cloud registration (
|
|
26
|
-
* operations
|
|
27
|
-
*
|
|
28
|
-
*
|
|
27
|
+
* For point cloud registration (Kabsch, Umeyama), BF16 data is loaded using VLD3 de-interleave
|
|
28
|
+
* operations and processed directly with BFDOT (`vbfdotq_f32`), which computes two BF16 products
|
|
29
|
+
* per 32-bit lane with FP32 accumulation. This skips the explicit bf16→f32 widening that the
|
|
30
|
+
* prior vshll+fmaq approach required, halving the front-end pressure on the covariance/centroid
|
|
31
|
+
* stats pass. RMSD keeps the widen+subtract+fmaq pipeline because it needs the (a - b) difference
|
|
32
|
+
* before squaring, which BFDOT can't express directly.
|
|
29
33
|
*/
|
|
30
34
|
#ifndef NK_MESH_NEONBFDOT_H
|
|
31
35
|
#define NK_MESH_NEONBFDOT_H
|
|
@@ -34,6 +38,7 @@
|
|
|
34
38
|
#if NK_TARGET_NEONBFDOT
|
|
35
39
|
|
|
36
40
|
#include "numkong/types.h"
|
|
41
|
+
#include "numkong/cast/neon.h" // `nk_u16x8_splat_`
|
|
37
42
|
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
38
43
|
|
|
39
44
|
#if defined(__cplusplus)
|
|
@@ -75,191 +80,6 @@ NK_INTERNAL void nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(nk_bf16_t cons
|
|
|
75
80
|
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_((nk_bf16_t const *)buf, x_out, y_out, z_out);
|
|
76
81
|
}
|
|
77
82
|
|
|
78
|
-
/* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
|
|
79
|
-
* Loads bf16 data, converts to f32 during processing.
|
|
80
|
-
* Note: rotation matrix r is f32 (from SVD), scale and computation done in f32.
|
|
81
|
-
*/
|
|
82
|
-
NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_neonbfdot_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
|
|
83
|
-
nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
|
|
84
|
-
nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
|
|
85
|
-
nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
|
|
86
|
-
nk_f32_t centroid_b_z) {
|
|
87
|
-
// Broadcast scaled rotation matrix elements
|
|
88
|
-
float32x4_t scaled_rotation_x_x_f32x4 = vdupq_n_f32(scale * r[0]);
|
|
89
|
-
float32x4_t scaled_rotation_x_y_f32x4 = vdupq_n_f32(scale * r[1]);
|
|
90
|
-
float32x4_t scaled_rotation_x_z_f32x4 = vdupq_n_f32(scale * r[2]);
|
|
91
|
-
float32x4_t scaled_rotation_y_x_f32x4 = vdupq_n_f32(scale * r[3]);
|
|
92
|
-
float32x4_t scaled_rotation_y_y_f32x4 = vdupq_n_f32(scale * r[4]);
|
|
93
|
-
float32x4_t scaled_rotation_y_z_f32x4 = vdupq_n_f32(scale * r[5]);
|
|
94
|
-
float32x4_t scaled_rotation_z_x_f32x4 = vdupq_n_f32(scale * r[6]);
|
|
95
|
-
float32x4_t scaled_rotation_z_y_f32x4 = vdupq_n_f32(scale * r[7]);
|
|
96
|
-
float32x4_t scaled_rotation_z_z_f32x4 = vdupq_n_f32(scale * r[8]);
|
|
97
|
-
|
|
98
|
-
// Broadcast centroids
|
|
99
|
-
float32x4_t centroid_a_x_f32x4 = vdupq_n_f32(centroid_a_x);
|
|
100
|
-
float32x4_t centroid_a_y_f32x4 = vdupq_n_f32(centroid_a_y);
|
|
101
|
-
float32x4_t centroid_a_z_f32x4 = vdupq_n_f32(centroid_a_z);
|
|
102
|
-
float32x4_t centroid_b_x_f32x4 = vdupq_n_f32(centroid_b_x);
|
|
103
|
-
float32x4_t centroid_b_y_f32x4 = vdupq_n_f32(centroid_b_y);
|
|
104
|
-
float32x4_t centroid_b_z_f32x4 = vdupq_n_f32(centroid_b_z);
|
|
105
|
-
|
|
106
|
-
// Two independent accumulators to hide FMA latency
|
|
107
|
-
float32x4_t sum_squared_a_f32x4 = vdupq_n_f32(0);
|
|
108
|
-
float32x4_t sum_squared_b_f32x4 = vdupq_n_f32(0);
|
|
109
|
-
nk_size_t j = 0;
|
|
110
|
-
|
|
111
|
-
// Main loop: process 8 points per iteration (2x unrolled)
|
|
112
|
-
for (; j + 8 <= n; j += 8) {
|
|
113
|
-
// First batch of 4 points
|
|
114
|
-
float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
|
|
115
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + j * 3, &a1_x_f32x4, &a1_y_f32x4, &a1_z_f32x4);
|
|
116
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + j * 3, &b1_x_f32x4, &b1_y_f32x4, &b1_z_f32x4);
|
|
117
|
-
|
|
118
|
-
// Second batch of 4 points
|
|
119
|
-
float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;
|
|
120
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + (j + 4) * 3, &a2_x_f32x4, &a2_y_f32x4, &a2_z_f32x4);
|
|
121
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + (j + 4) * 3, &b2_x_f32x4, &b2_y_f32x4, &b2_z_f32x4);
|
|
122
|
-
|
|
123
|
-
// Center first batch
|
|
124
|
-
float32x4_t pa1_x_f32x4 = vsubq_f32(a1_x_f32x4, centroid_a_x_f32x4);
|
|
125
|
-
float32x4_t pa1_y_f32x4 = vsubq_f32(a1_y_f32x4, centroid_a_y_f32x4);
|
|
126
|
-
float32x4_t pa1_z_f32x4 = vsubq_f32(a1_z_f32x4, centroid_a_z_f32x4);
|
|
127
|
-
float32x4_t pb1_x_f32x4 = vsubq_f32(b1_x_f32x4, centroid_b_x_f32x4);
|
|
128
|
-
float32x4_t pb1_y_f32x4 = vsubq_f32(b1_y_f32x4, centroid_b_y_f32x4);
|
|
129
|
-
float32x4_t pb1_z_f32x4 = vsubq_f32(b1_z_f32x4, centroid_b_z_f32x4);
|
|
130
|
-
|
|
131
|
-
// Center second batch
|
|
132
|
-
float32x4_t pa2_x_f32x4 = vsubq_f32(a2_x_f32x4, centroid_a_x_f32x4);
|
|
133
|
-
float32x4_t pa2_y_f32x4 = vsubq_f32(a2_y_f32x4, centroid_a_y_f32x4);
|
|
134
|
-
float32x4_t pa2_z_f32x4 = vsubq_f32(a2_z_f32x4, centroid_a_z_f32x4);
|
|
135
|
-
float32x4_t pb2_x_f32x4 = vsubq_f32(b2_x_f32x4, centroid_b_x_f32x4);
|
|
136
|
-
float32x4_t pb2_y_f32x4 = vsubq_f32(b2_y_f32x4, centroid_b_y_f32x4);
|
|
137
|
-
float32x4_t pb2_z_f32x4 = vsubq_f32(b2_z_f32x4, centroid_b_z_f32x4);
|
|
138
|
-
|
|
139
|
-
// Rotate and scale first batch: ra1 = scale * R * pa1
|
|
140
|
-
float32x4_t ra1_x_f32x4 = vfmaq_f32(
|
|
141
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa1_x_f32x4), scaled_rotation_x_y_f32x4, pa1_y_f32x4),
|
|
142
|
-
scaled_rotation_x_z_f32x4, pa1_z_f32x4);
|
|
143
|
-
float32x4_t ra1_y_f32x4 = vfmaq_f32(
|
|
144
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa1_x_f32x4), scaled_rotation_y_y_f32x4, pa1_y_f32x4),
|
|
145
|
-
scaled_rotation_y_z_f32x4, pa1_z_f32x4);
|
|
146
|
-
float32x4_t ra1_z_f32x4 = vfmaq_f32(
|
|
147
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa1_x_f32x4), scaled_rotation_z_y_f32x4, pa1_y_f32x4),
|
|
148
|
-
scaled_rotation_z_z_f32x4, pa1_z_f32x4);
|
|
149
|
-
|
|
150
|
-
// Rotate and scale second batch: ra2 = scale * R * pa2
|
|
151
|
-
float32x4_t ra2_x_f32x4 = vfmaq_f32(
|
|
152
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa2_x_f32x4), scaled_rotation_x_y_f32x4, pa2_y_f32x4),
|
|
153
|
-
scaled_rotation_x_z_f32x4, pa2_z_f32x4);
|
|
154
|
-
float32x4_t ra2_y_f32x4 = vfmaq_f32(
|
|
155
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa2_x_f32x4), scaled_rotation_y_y_f32x4, pa2_y_f32x4),
|
|
156
|
-
scaled_rotation_y_z_f32x4, pa2_z_f32x4);
|
|
157
|
-
float32x4_t ra2_z_f32x4 = vfmaq_f32(
|
|
158
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa2_x_f32x4), scaled_rotation_z_y_f32x4, pa2_y_f32x4),
|
|
159
|
-
scaled_rotation_z_z_f32x4, pa2_z_f32x4);
|
|
160
|
-
|
|
161
|
-
// Deltas
|
|
162
|
-
float32x4_t delta1_x_f32x4 = vsubq_f32(ra1_x_f32x4, pb1_x_f32x4);
|
|
163
|
-
float32x4_t delta1_y_f32x4 = vsubq_f32(ra1_y_f32x4, pb1_y_f32x4);
|
|
164
|
-
float32x4_t delta1_z_f32x4 = vsubq_f32(ra1_z_f32x4, pb1_z_f32x4);
|
|
165
|
-
float32x4_t delta2_x_f32x4 = vsubq_f32(ra2_x_f32x4, pb2_x_f32x4);
|
|
166
|
-
float32x4_t delta2_y_f32x4 = vsubq_f32(ra2_y_f32x4, pb2_y_f32x4);
|
|
167
|
-
float32x4_t delta2_z_f32x4 = vsubq_f32(ra2_z_f32x4, pb2_z_f32x4);
|
|
168
|
-
|
|
169
|
-
// Accumulate to independent accumulators
|
|
170
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_x_f32x4, delta1_x_f32x4);
|
|
171
|
-
sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_x_f32x4, delta2_x_f32x4);
|
|
172
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_y_f32x4, delta1_y_f32x4);
|
|
173
|
-
sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_y_f32x4, delta2_y_f32x4);
|
|
174
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta1_z_f32x4, delta1_z_f32x4);
|
|
175
|
-
sum_squared_b_f32x4 = vfmaq_f32(sum_squared_b_f32x4, delta2_z_f32x4, delta2_z_f32x4);
|
|
176
|
-
}
|
|
177
|
-
|
|
178
|
-
// Handle remaining 4 points
|
|
179
|
-
if (j + 4 <= n) {
|
|
180
|
-
float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
|
|
181
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + j * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
|
|
182
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + j * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
|
|
183
|
-
|
|
184
|
-
float32x4_t pa_x_f32x4 = vsubq_f32(a_x_f32x4, centroid_a_x_f32x4);
|
|
185
|
-
float32x4_t pa_y_f32x4 = vsubq_f32(a_y_f32x4, centroid_a_y_f32x4);
|
|
186
|
-
float32x4_t pa_z_f32x4 = vsubq_f32(a_z_f32x4, centroid_a_z_f32x4);
|
|
187
|
-
float32x4_t pb_x_f32x4 = vsubq_f32(b_x_f32x4, centroid_b_x_f32x4);
|
|
188
|
-
float32x4_t pb_y_f32x4 = vsubq_f32(b_y_f32x4, centroid_b_y_f32x4);
|
|
189
|
-
float32x4_t pb_z_f32x4 = vsubq_f32(b_z_f32x4, centroid_b_z_f32x4);
|
|
190
|
-
|
|
191
|
-
float32x4_t ra_x_f32x4 = vfmaq_f32(
|
|
192
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa_x_f32x4), scaled_rotation_x_y_f32x4, pa_y_f32x4),
|
|
193
|
-
scaled_rotation_x_z_f32x4, pa_z_f32x4);
|
|
194
|
-
float32x4_t ra_y_f32x4 = vfmaq_f32(
|
|
195
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa_x_f32x4), scaled_rotation_y_y_f32x4, pa_y_f32x4),
|
|
196
|
-
scaled_rotation_y_z_f32x4, pa_z_f32x4);
|
|
197
|
-
float32x4_t ra_z_f32x4 = vfmaq_f32(
|
|
198
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa_x_f32x4), scaled_rotation_z_y_f32x4, pa_y_f32x4),
|
|
199
|
-
scaled_rotation_z_z_f32x4, pa_z_f32x4);
|
|
200
|
-
|
|
201
|
-
float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
|
|
202
|
-
float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
|
|
203
|
-
float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
|
|
204
|
-
|
|
205
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_x_f32x4, delta_x_f32x4);
|
|
206
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_y_f32x4, delta_y_f32x4);
|
|
207
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_z_f32x4, delta_z_f32x4);
|
|
208
|
-
j += 4;
|
|
209
|
-
}
|
|
210
|
-
|
|
211
|
-
// Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
|
|
212
|
-
if (j < n) {
|
|
213
|
-
float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
|
|
214
|
-
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + j * 3, n - j, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
|
|
215
|
-
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + j * 3, n - j, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
|
|
216
|
-
|
|
217
|
-
// Mask invalid lanes to zero BEFORE centering
|
|
218
|
-
uint32x4_t lane_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
|
|
219
|
-
vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
|
|
220
|
-
uint32x4_t valid_u32x4 = vcltq_u32(lane_u32x4, vdupq_n_u32((nk_u32_t)(n - j)));
|
|
221
|
-
float32x4_t zero_f32x4 = vdupq_n_f32(0);
|
|
222
|
-
a_x_f32x4 = vbslq_f32(valid_u32x4, a_x_f32x4, zero_f32x4);
|
|
223
|
-
a_y_f32x4 = vbslq_f32(valid_u32x4, a_y_f32x4, zero_f32x4);
|
|
224
|
-
a_z_f32x4 = vbslq_f32(valid_u32x4, a_z_f32x4, zero_f32x4);
|
|
225
|
-
b_x_f32x4 = vbslq_f32(valid_u32x4, b_x_f32x4, zero_f32x4);
|
|
226
|
-
b_y_f32x4 = vbslq_f32(valid_u32x4, b_y_f32x4, zero_f32x4);
|
|
227
|
-
b_z_f32x4 = vbslq_f32(valid_u32x4, b_z_f32x4, zero_f32x4);
|
|
228
|
-
|
|
229
|
-
// Same centering + rotation + delta + FMA as body
|
|
230
|
-
float32x4_t pa_x_f32x4 = vsubq_f32(a_x_f32x4, centroid_a_x_f32x4);
|
|
231
|
-
float32x4_t pa_y_f32x4 = vsubq_f32(a_y_f32x4, centroid_a_y_f32x4);
|
|
232
|
-
float32x4_t pa_z_f32x4 = vsubq_f32(a_z_f32x4, centroid_a_z_f32x4);
|
|
233
|
-
float32x4_t pb_x_f32x4 = vsubq_f32(b_x_f32x4, centroid_b_x_f32x4);
|
|
234
|
-
float32x4_t pb_y_f32x4 = vsubq_f32(b_y_f32x4, centroid_b_y_f32x4);
|
|
235
|
-
float32x4_t pb_z_f32x4 = vsubq_f32(b_z_f32x4, centroid_b_z_f32x4);
|
|
236
|
-
|
|
237
|
-
float32x4_t ra_x_f32x4 = vfmaq_f32(
|
|
238
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_x_x_f32x4, pa_x_f32x4), scaled_rotation_x_y_f32x4, pa_y_f32x4),
|
|
239
|
-
scaled_rotation_x_z_f32x4, pa_z_f32x4);
|
|
240
|
-
float32x4_t ra_y_f32x4 = vfmaq_f32(
|
|
241
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_y_x_f32x4, pa_x_f32x4), scaled_rotation_y_y_f32x4, pa_y_f32x4),
|
|
242
|
-
scaled_rotation_y_z_f32x4, pa_z_f32x4);
|
|
243
|
-
float32x4_t ra_z_f32x4 = vfmaq_f32(
|
|
244
|
-
vfmaq_f32(vmulq_f32(scaled_rotation_z_x_f32x4, pa_x_f32x4), scaled_rotation_z_y_f32x4, pa_y_f32x4),
|
|
245
|
-
scaled_rotation_z_z_f32x4, pa_z_f32x4);
|
|
246
|
-
|
|
247
|
-
float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
|
|
248
|
-
float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
|
|
249
|
-
float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
|
|
250
|
-
|
|
251
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_x_f32x4, delta_x_f32x4);
|
|
252
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_y_f32x4, delta_y_f32x4);
|
|
253
|
-
sum_squared_a_f32x4 = vfmaq_f32(sum_squared_a_f32x4, delta_z_f32x4, delta_z_f32x4);
|
|
254
|
-
}
|
|
255
|
-
|
|
256
|
-
// Combine accumulators and reduce
|
|
257
|
-
float32x4_t sum_squared_f32x4 = vaddq_f32(sum_squared_a_f32x4, sum_squared_b_f32x4);
|
|
258
|
-
nk_f32_t sum_squared = vaddvq_f32(sum_squared_f32x4);
|
|
259
|
-
|
|
260
|
-
return sum_squared;
|
|
261
|
-
}
|
|
262
|
-
|
|
263
83
|
NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
264
84
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
265
85
|
// RMSD uses identity rotation and scale=1.0
|
|
@@ -318,124 +138,117 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
318
138
|
NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
319
139
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
320
140
|
float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
|
|
321
|
-
|
|
322
|
-
//
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
float32x4_t
|
|
329
|
-
float32x4_t
|
|
330
|
-
float32x4_t
|
|
331
|
-
float32x4_t
|
|
332
|
-
float32x4_t
|
|
333
|
-
float32x4_t cov_zx_b_f32x4 = zeros_f32x4, cov_zy_b_f32x4 = zeros_f32x4, cov_zz_b_f32x4 = zeros_f32x4;
|
|
141
|
+
// bf16 representation of 1.0 is 0x3F80; splatted across 8 lanes for BFDOT-based horizontal sums.
|
|
142
|
+
// The `nk_u16x8_splat_` wrapper prevents GCC from lowering this to `fmov v.8h, #1.0` (a FEAT_FP16
|
|
143
|
+
// encoding) that fails to assemble under a `+bf16`-only pragma.
|
|
144
|
+
bfloat16x8_t const ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
|
|
145
|
+
|
|
146
|
+
// Centroid numerators, norm-squared, and 3x3 cross-covariance accumulators.
|
|
147
|
+
float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
|
|
148
|
+
float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
|
|
149
|
+
float32x4_t covariance_xx_f32x4 = zeros_f32x4, covariance_xy_f32x4 = zeros_f32x4, covariance_xz_f32x4 = zeros_f32x4;
|
|
150
|
+
float32x4_t covariance_yx_f32x4 = zeros_f32x4, covariance_yy_f32x4 = zeros_f32x4, covariance_yz_f32x4 = zeros_f32x4;
|
|
151
|
+
float32x4_t covariance_zx_f32x4 = zeros_f32x4, covariance_zy_f32x4 = zeros_f32x4, covariance_zz_f32x4 = zeros_f32x4;
|
|
152
|
+
float32x4_t norm_squared_a_f32x4 = zeros_f32x4, norm_squared_b_f32x4 = zeros_f32x4;
|
|
334
153
|
|
|
335
154
|
nk_size_t i = 0;
|
|
336
|
-
float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
|
|
337
|
-
float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;
|
|
338
155
|
|
|
339
|
-
// Main loop: 8
|
|
156
|
+
// Main loop: 8 triplets per iteration via vld3q_u16 + vbfdotq_f32.
|
|
157
|
+
// Each vld3q_u16 de-interleaves 24 bf16 values into three 8-lane channel vectors.
|
|
158
|
+
// vbfdotq_f32(acc, p, q) computes per-32bit-lane: acc[l] += p[2l]*q[2l] + p[2l+1]*q[2l+1]
|
|
159
|
+
// with bf16 inputs and f32 accumulation. Summing the 4 lanes at the end yields the scalar.
|
|
340
160
|
for (; i + 8 <= n; i += 8) {
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
|
|
377
|
-
cov_zz_b_f32x4 = vfmaq_f32(cov_zz_b_f32x4, a2_z_f32x4, b2_z_f32x4);
|
|
161
|
+
uint16x8x3_t a_xyz_u16x8x3 = vld3q_u16((nk_u16_t const *)(a + i * 3));
|
|
162
|
+
uint16x8x3_t b_xyz_u16x8x3 = vld3q_u16((nk_u16_t const *)(b + i * 3));
|
|
163
|
+
bfloat16x8_t a_x_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[0]);
|
|
164
|
+
bfloat16x8_t a_y_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[1]);
|
|
165
|
+
bfloat16x8_t a_z_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[2]);
|
|
166
|
+
bfloat16x8_t b_x_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[0]);
|
|
167
|
+
bfloat16x8_t b_y_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[1]);
|
|
168
|
+
bfloat16x8_t b_z_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[2]);
|
|
169
|
+
|
|
170
|
+
// Centroid numerators: Σ channel values (pairwise via BFDOT against bf16 1.0).
|
|
171
|
+
sum_a_x_f32x4 = vbfdotq_f32(sum_a_x_f32x4, a_x_bf16x8, ones_bf16x8);
|
|
172
|
+
sum_a_y_f32x4 = vbfdotq_f32(sum_a_y_f32x4, a_y_bf16x8, ones_bf16x8);
|
|
173
|
+
sum_a_z_f32x4 = vbfdotq_f32(sum_a_z_f32x4, a_z_bf16x8, ones_bf16x8);
|
|
174
|
+
sum_b_x_f32x4 = vbfdotq_f32(sum_b_x_f32x4, b_x_bf16x8, ones_bf16x8);
|
|
175
|
+
sum_b_y_f32x4 = vbfdotq_f32(sum_b_y_f32x4, b_y_bf16x8, ones_bf16x8);
|
|
176
|
+
sum_b_z_f32x4 = vbfdotq_f32(sum_b_z_f32x4, b_z_bf16x8, ones_bf16x8);
|
|
177
|
+
|
|
178
|
+
// 3x3 cross-covariance H cells: Σ a_j · b_k.
|
|
179
|
+
covariance_xx_f32x4 = vbfdotq_f32(covariance_xx_f32x4, a_x_bf16x8, b_x_bf16x8);
|
|
180
|
+
covariance_xy_f32x4 = vbfdotq_f32(covariance_xy_f32x4, a_x_bf16x8, b_y_bf16x8);
|
|
181
|
+
covariance_xz_f32x4 = vbfdotq_f32(covariance_xz_f32x4, a_x_bf16x8, b_z_bf16x8);
|
|
182
|
+
covariance_yx_f32x4 = vbfdotq_f32(covariance_yx_f32x4, a_y_bf16x8, b_x_bf16x8);
|
|
183
|
+
covariance_yy_f32x4 = vbfdotq_f32(covariance_yy_f32x4, a_y_bf16x8, b_y_bf16x8);
|
|
184
|
+
covariance_yz_f32x4 = vbfdotq_f32(covariance_yz_f32x4, a_y_bf16x8, b_z_bf16x8);
|
|
185
|
+
covariance_zx_f32x4 = vbfdotq_f32(covariance_zx_f32x4, a_z_bf16x8, b_x_bf16x8);
|
|
186
|
+
covariance_zy_f32x4 = vbfdotq_f32(covariance_zy_f32x4, a_z_bf16x8, b_y_bf16x8);
|
|
187
|
+
covariance_zz_f32x4 = vbfdotq_f32(covariance_zz_f32x4, a_z_bf16x8, b_z_bf16x8);
|
|
188
|
+
|
|
189
|
+
// Norm-squared per point set: Σ (x² + y² + z²).
|
|
190
|
+
norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_x_bf16x8, a_x_bf16x8);
|
|
191
|
+
norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_y_bf16x8, a_y_bf16x8);
|
|
192
|
+
norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_z_bf16x8, a_z_bf16x8);
|
|
193
|
+
norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_x_bf16x8, b_x_bf16x8);
|
|
194
|
+
norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_y_bf16x8, b_y_bf16x8);
|
|
195
|
+
norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_z_bf16x8, b_z_bf16x8);
|
|
378
196
|
}
|
|
379
197
|
|
|
380
|
-
// 4-point
|
|
198
|
+
// 4-point and partial (1-3) tails: keep the widen+fmaq path on f32x4 channel vectors.
|
|
199
|
+
// These branches run at most once each, so we skip another vbfdotq variant for them.
|
|
200
|
+
float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
|
|
381
201
|
for (; i + 4 <= n; i += 4) {
|
|
382
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &
|
|
383
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
202
|
+
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
|
|
203
|
+
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
|
|
204
|
+
sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
|
|
205
|
+
sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
|
|
206
|
+
sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
|
|
207
|
+
sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
|
|
208
|
+
sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
|
|
209
|
+
sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
|
|
210
|
+
covariance_xx_f32x4 = vfmaq_f32(covariance_xx_f32x4, a_x_f32x4, b_x_f32x4);
|
|
211
|
+
covariance_xy_f32x4 = vfmaq_f32(covariance_xy_f32x4, a_x_f32x4, b_y_f32x4);
|
|
212
|
+
covariance_xz_f32x4 = vfmaq_f32(covariance_xz_f32x4, a_x_f32x4, b_z_f32x4);
|
|
213
|
+
covariance_yx_f32x4 = vfmaq_f32(covariance_yx_f32x4, a_y_f32x4, b_x_f32x4);
|
|
214
|
+
covariance_yy_f32x4 = vfmaq_f32(covariance_yy_f32x4, a_y_f32x4, b_y_f32x4);
|
|
215
|
+
covariance_yz_f32x4 = vfmaq_f32(covariance_yz_f32x4, a_y_f32x4, b_z_f32x4);
|
|
216
|
+
covariance_zx_f32x4 = vfmaq_f32(covariance_zx_f32x4, a_z_f32x4, b_x_f32x4);
|
|
217
|
+
covariance_zy_f32x4 = vfmaq_f32(covariance_zy_f32x4, a_z_f32x4, b_y_f32x4);
|
|
218
|
+
covariance_zz_f32x4 = vfmaq_f32(covariance_zz_f32x4, a_z_f32x4, b_z_f32x4);
|
|
219
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_f32x4, a_x_f32x4);
|
|
220
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_f32x4, a_y_f32x4);
|
|
221
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_f32x4, a_z_f32x4);
|
|
222
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_f32x4, b_x_f32x4);
|
|
223
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_f32x4, b_y_f32x4);
|
|
224
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_f32x4, b_z_f32x4);
|
|
399
225
|
}
|
|
400
|
-
|
|
401
|
-
// Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
|
|
402
226
|
if (i < n) {
|
|
403
|
-
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &
|
|
404
|
-
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
227
|
+
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
|
|
228
|
+
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
|
|
229
|
+
sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
|
|
230
|
+
sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
|
|
231
|
+
sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
|
|
232
|
+
sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
|
|
233
|
+
sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
|
|
234
|
+
sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
|
|
235
|
+
covariance_xx_f32x4 = vfmaq_f32(covariance_xx_f32x4, a_x_f32x4, b_x_f32x4);
|
|
236
|
+
covariance_xy_f32x4 = vfmaq_f32(covariance_xy_f32x4, a_x_f32x4, b_y_f32x4);
|
|
237
|
+
covariance_xz_f32x4 = vfmaq_f32(covariance_xz_f32x4, a_x_f32x4, b_z_f32x4);
|
|
238
|
+
covariance_yx_f32x4 = vfmaq_f32(covariance_yx_f32x4, a_y_f32x4, b_x_f32x4);
|
|
239
|
+
covariance_yy_f32x4 = vfmaq_f32(covariance_yy_f32x4, a_y_f32x4, b_y_f32x4);
|
|
240
|
+
covariance_yz_f32x4 = vfmaq_f32(covariance_yz_f32x4, a_y_f32x4, b_z_f32x4);
|
|
241
|
+
covariance_zx_f32x4 = vfmaq_f32(covariance_zx_f32x4, a_z_f32x4, b_x_f32x4);
|
|
242
|
+
covariance_zy_f32x4 = vfmaq_f32(covariance_zy_f32x4, a_z_f32x4, b_y_f32x4);
|
|
243
|
+
covariance_zz_f32x4 = vfmaq_f32(covariance_zz_f32x4, a_z_f32x4, b_z_f32x4);
|
|
244
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_f32x4, a_x_f32x4);
|
|
245
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_f32x4, a_y_f32x4);
|
|
246
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_f32x4, a_z_f32x4);
|
|
247
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_f32x4, b_x_f32x4);
|
|
248
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_f32x4, b_y_f32x4);
|
|
249
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_f32x4, b_z_f32x4);
|
|
420
250
|
}
|
|
421
251
|
|
|
422
|
-
// Combine dual accumulators
|
|
423
|
-
float32x4_t sum_a_x_f32x4 = vaddq_f32(sum_a_x_a_f32x4, sum_a_x_b_f32x4);
|
|
424
|
-
float32x4_t sum_a_y_f32x4 = vaddq_f32(sum_a_y_a_f32x4, sum_a_y_b_f32x4);
|
|
425
|
-
float32x4_t sum_a_z_f32x4 = vaddq_f32(sum_a_z_a_f32x4, sum_a_z_b_f32x4);
|
|
426
|
-
float32x4_t sum_b_x_f32x4 = vaddq_f32(sum_b_x_a_f32x4, sum_b_x_b_f32x4);
|
|
427
|
-
float32x4_t sum_b_y_f32x4 = vaddq_f32(sum_b_y_a_f32x4, sum_b_y_b_f32x4);
|
|
428
|
-
float32x4_t sum_b_z_f32x4 = vaddq_f32(sum_b_z_a_f32x4, sum_b_z_b_f32x4);
|
|
429
|
-
float32x4_t cov_xx_f32x4 = vaddq_f32(cov_xx_a_f32x4, cov_xx_b_f32x4);
|
|
430
|
-
float32x4_t cov_xy_f32x4 = vaddq_f32(cov_xy_a_f32x4, cov_xy_b_f32x4);
|
|
431
|
-
float32x4_t cov_xz_f32x4 = vaddq_f32(cov_xz_a_f32x4, cov_xz_b_f32x4);
|
|
432
|
-
float32x4_t cov_yx_f32x4 = vaddq_f32(cov_yx_a_f32x4, cov_yx_b_f32x4);
|
|
433
|
-
float32x4_t cov_yy_f32x4 = vaddq_f32(cov_yy_a_f32x4, cov_yy_b_f32x4);
|
|
434
|
-
float32x4_t cov_yz_f32x4 = vaddq_f32(cov_yz_a_f32x4, cov_yz_b_f32x4);
|
|
435
|
-
float32x4_t cov_zx_f32x4 = vaddq_f32(cov_zx_a_f32x4, cov_zx_b_f32x4);
|
|
436
|
-
float32x4_t cov_zy_f32x4 = vaddq_f32(cov_zy_a_f32x4, cov_zy_b_f32x4);
|
|
437
|
-
float32x4_t cov_zz_f32x4 = vaddq_f32(cov_zz_a_f32x4, cov_zz_b_f32x4);
|
|
438
|
-
|
|
439
252
|
// Reduce vector accumulators
|
|
440
253
|
nk_f32_t sum_a_x = vaddvq_f32(sum_a_x_f32x4);
|
|
441
254
|
nk_f32_t sum_a_y = vaddvq_f32(sum_a_y_f32x4);
|
|
@@ -444,15 +257,17 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
444
257
|
nk_f32_t sum_b_y = vaddvq_f32(sum_b_y_f32x4);
|
|
445
258
|
nk_f32_t sum_b_z = vaddvq_f32(sum_b_z_f32x4);
|
|
446
259
|
|
|
447
|
-
nk_f32_t covariance_x_x = vaddvq_f32(
|
|
448
|
-
nk_f32_t covariance_x_y = vaddvq_f32(
|
|
449
|
-
nk_f32_t covariance_x_z = vaddvq_f32(
|
|
450
|
-
nk_f32_t covariance_y_x = vaddvq_f32(
|
|
451
|
-
nk_f32_t covariance_y_y = vaddvq_f32(
|
|
452
|
-
nk_f32_t covariance_y_z = vaddvq_f32(
|
|
453
|
-
nk_f32_t covariance_z_x = vaddvq_f32(
|
|
454
|
-
nk_f32_t covariance_z_y = vaddvq_f32(
|
|
455
|
-
nk_f32_t covariance_z_z = vaddvq_f32(
|
|
260
|
+
nk_f32_t covariance_x_x = vaddvq_f32(covariance_xx_f32x4);
|
|
261
|
+
nk_f32_t covariance_x_y = vaddvq_f32(covariance_xy_f32x4);
|
|
262
|
+
nk_f32_t covariance_x_z = vaddvq_f32(covariance_xz_f32x4);
|
|
263
|
+
nk_f32_t covariance_y_x = vaddvq_f32(covariance_yx_f32x4);
|
|
264
|
+
nk_f32_t covariance_y_y = vaddvq_f32(covariance_yy_f32x4);
|
|
265
|
+
nk_f32_t covariance_y_z = vaddvq_f32(covariance_yz_f32x4);
|
|
266
|
+
nk_f32_t covariance_z_x = vaddvq_f32(covariance_zx_f32x4);
|
|
267
|
+
nk_f32_t covariance_z_y = vaddvq_f32(covariance_zy_f32x4);
|
|
268
|
+
nk_f32_t covariance_z_z = vaddvq_f32(covariance_zz_f32x4);
|
|
269
|
+
nk_f32_t norm_squared_a = vaddvq_f32(norm_squared_a_f32x4);
|
|
270
|
+
nk_f32_t norm_squared_b = vaddvq_f32(norm_squared_b_f32x4);
|
|
456
271
|
|
|
457
272
|
// Compute centroids
|
|
458
273
|
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
@@ -466,6 +281,16 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
466
281
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
467
282
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
468
283
|
|
|
284
|
+
// Centered norm-squared via parallel-axis identity; clamp at zero for numeric safety.
|
|
285
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a -
|
|
286
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
287
|
+
centroid_a_z * centroid_a_z);
|
|
288
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b -
|
|
289
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
290
|
+
centroid_b_z * centroid_b_z);
|
|
291
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
292
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
293
|
+
|
|
469
294
|
// Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
|
|
470
295
|
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
471
296
|
covariance_x_y -= n * centroid_a_x * centroid_b_y;
|
|
@@ -480,187 +305,186 @@ NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
480
305
|
// Compute SVD and optimal rotation
|
|
481
306
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
482
307
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
308
|
+
|
|
309
|
+
// Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
|
|
310
|
+
nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
311
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
312
|
+
cross_covariance[8] * cross_covariance[8];
|
|
313
|
+
nk_f32_t covariance_offdiagonal_norm_squared =
|
|
314
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
315
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
316
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
317
|
+
nk_f32_t optimal_rotation[9];
|
|
318
|
+
nk_f32_t trace_rotation_covariance;
|
|
319
|
+
if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
|
|
320
|
+
cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
|
|
321
|
+
optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
|
|
322
|
+
optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
|
|
323
|
+
optimal_rotation[8] = 1;
|
|
324
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
325
|
+
}
|
|
326
|
+
else {
|
|
327
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
328
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
329
|
+
|
|
330
|
+
// R = V * Uᵀ
|
|
331
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
332
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
333
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
334
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
335
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
336
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
337
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
338
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
339
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
340
|
+
|
|
341
|
+
// Handle reflection: if det(R) < 0, negate third column of V and recompute R
|
|
342
|
+
if (nk_det3x3_f32_(optimal_rotation) < 0) {
|
|
343
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
344
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
345
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
346
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
347
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
348
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
349
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
350
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
351
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
352
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
trace_rotation_covariance =
|
|
356
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
357
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
358
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
359
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
360
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
510
361
|
}
|
|
511
362
|
|
|
512
363
|
// Output rotation matrix and scale=1.0
|
|
513
364
|
if (rotation)
|
|
514
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
365
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
515
366
|
if (scale) *scale = 1.0f;
|
|
516
367
|
|
|
517
|
-
//
|
|
518
|
-
nk_f32_t sum_squared =
|
|
519
|
-
|
|
368
|
+
// Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
|
|
369
|
+
nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
|
|
370
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
520
371
|
*result = nk_f32_sqrt_neon(sum_squared * inv_n);
|
|
521
372
|
}
|
|
522
373
|
|
|
523
374
|
NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
524
375
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
525
376
|
float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
|
|
526
|
-
|
|
527
|
-
//
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
float32x4_t
|
|
534
|
-
float32x4_t
|
|
535
|
-
float32x4_t
|
|
536
|
-
float32x4_t
|
|
537
|
-
float32x4_t
|
|
538
|
-
float32x4_t cov_zx_b_f32x4 = zeros_f32x4, cov_zy_b_f32x4 = zeros_f32x4, cov_zz_b_f32x4 = zeros_f32x4;
|
|
539
|
-
|
|
540
|
-
// Variance of A accumulators
|
|
541
|
-
float32x4_t variance_a_a_f32x4 = zeros_f32x4;
|
|
542
|
-
float32x4_t variance_a_b_f32x4 = zeros_f32x4;
|
|
377
|
+
// bf16 representation of 1.0 is 0x3F80; splatted across 8 lanes for BFDOT-based horizontal sums.
|
|
378
|
+
// The `nk_u16x8_splat_` wrapper prevents GCC from lowering this to `fmov v.8h, #1.0` (a FEAT_FP16
|
|
379
|
+
// encoding) that fails to assemble under a `+bf16`-only pragma.
|
|
380
|
+
bfloat16x8_t const ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
|
|
381
|
+
|
|
382
|
+
// Centroid numerators, norm-squared, and 3x3 cross-covariance accumulators.
|
|
383
|
+
float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
|
|
384
|
+
float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
|
|
385
|
+
float32x4_t covariance_xx_f32x4 = zeros_f32x4, covariance_xy_f32x4 = zeros_f32x4, covariance_xz_f32x4 = zeros_f32x4;
|
|
386
|
+
float32x4_t covariance_yx_f32x4 = zeros_f32x4, covariance_yy_f32x4 = zeros_f32x4, covariance_yz_f32x4 = zeros_f32x4;
|
|
387
|
+
float32x4_t covariance_zx_f32x4 = zeros_f32x4, covariance_zy_f32x4 = zeros_f32x4, covariance_zz_f32x4 = zeros_f32x4;
|
|
388
|
+
float32x4_t norm_squared_a_f32x4 = zeros_f32x4, norm_squared_b_f32x4 = zeros_f32x4;
|
|
543
389
|
|
|
544
390
|
nk_size_t i = 0;
|
|
545
|
-
float32x4_t a1_x_f32x4, a1_y_f32x4, a1_z_f32x4, b1_x_f32x4, b1_y_f32x4, b1_z_f32x4;
|
|
546
|
-
float32x4_t a2_x_f32x4, a2_y_f32x4, a2_z_f32x4, b2_x_f32x4, b2_y_f32x4, b2_z_f32x4;
|
|
547
391
|
|
|
548
|
-
// Main loop: 8
|
|
392
|
+
// Main loop: 8 triplets per iteration via vld3q_u16 + vbfdotq_f32.
|
|
393
|
+
// Each vld3q_u16 de-interleaves 24 bf16 values into three 8-lane channel vectors.
|
|
394
|
+
// vbfdotq_f32(acc, p, q) computes per-32bit-lane: acc[l] += p[2l]*q[2l] + p[2l+1]*q[2l+1]
|
|
395
|
+
// with bf16 inputs and f32 accumulation. Summing the 4 lanes at the end yields the scalar.
|
|
549
396
|
for (; i + 8 <= n; i += 8) {
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
cov_zy_b_f32x4 = vfmaq_f32(cov_zy_b_f32x4, a2_z_f32x4, b2_y_f32x4);
|
|
586
|
-
cov_zz_a_f32x4 = vfmaq_f32(cov_zz_a_f32x4, a1_z_f32x4, b1_z_f32x4);
|
|
587
|
-
cov_zz_b_f32x4 = vfmaq_f32(cov_zz_b_f32x4, a2_z_f32x4, b2_z_f32x4);
|
|
588
|
-
|
|
589
|
-
// Variance of A
|
|
590
|
-
variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_x_f32x4, a1_x_f32x4);
|
|
591
|
-
variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_x_f32x4, a2_x_f32x4);
|
|
592
|
-
variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_y_f32x4, a1_y_f32x4);
|
|
593
|
-
variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_y_f32x4, a2_y_f32x4);
|
|
594
|
-
variance_a_a_f32x4 = vfmaq_f32(variance_a_a_f32x4, a1_z_f32x4, a1_z_f32x4);
|
|
595
|
-
variance_a_b_f32x4 = vfmaq_f32(variance_a_b_f32x4, a2_z_f32x4, a2_z_f32x4);
|
|
397
|
+
uint16x8x3_t a_xyz_u16x8x3 = vld3q_u16((nk_u16_t const *)(a + i * 3));
|
|
398
|
+
uint16x8x3_t b_xyz_u16x8x3 = vld3q_u16((nk_u16_t const *)(b + i * 3));
|
|
399
|
+
bfloat16x8_t a_x_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[0]);
|
|
400
|
+
bfloat16x8_t a_y_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[1]);
|
|
401
|
+
bfloat16x8_t a_z_bf16x8 = vreinterpretq_bf16_u16(a_xyz_u16x8x3.val[2]);
|
|
402
|
+
bfloat16x8_t b_x_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[0]);
|
|
403
|
+
bfloat16x8_t b_y_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[1]);
|
|
404
|
+
bfloat16x8_t b_z_bf16x8 = vreinterpretq_bf16_u16(b_xyz_u16x8x3.val[2]);
|
|
405
|
+
|
|
406
|
+
// Centroid numerators: Σ channel values (pairwise via BFDOT against bf16 1.0).
|
|
407
|
+
sum_a_x_f32x4 = vbfdotq_f32(sum_a_x_f32x4, a_x_bf16x8, ones_bf16x8);
|
|
408
|
+
sum_a_y_f32x4 = vbfdotq_f32(sum_a_y_f32x4, a_y_bf16x8, ones_bf16x8);
|
|
409
|
+
sum_a_z_f32x4 = vbfdotq_f32(sum_a_z_f32x4, a_z_bf16x8, ones_bf16x8);
|
|
410
|
+
sum_b_x_f32x4 = vbfdotq_f32(sum_b_x_f32x4, b_x_bf16x8, ones_bf16x8);
|
|
411
|
+
sum_b_y_f32x4 = vbfdotq_f32(sum_b_y_f32x4, b_y_bf16x8, ones_bf16x8);
|
|
412
|
+
sum_b_z_f32x4 = vbfdotq_f32(sum_b_z_f32x4, b_z_bf16x8, ones_bf16x8);
|
|
413
|
+
|
|
414
|
+
// 3x3 cross-covariance H cells: Σ a_j · b_k.
|
|
415
|
+
covariance_xx_f32x4 = vbfdotq_f32(covariance_xx_f32x4, a_x_bf16x8, b_x_bf16x8);
|
|
416
|
+
covariance_xy_f32x4 = vbfdotq_f32(covariance_xy_f32x4, a_x_bf16x8, b_y_bf16x8);
|
|
417
|
+
covariance_xz_f32x4 = vbfdotq_f32(covariance_xz_f32x4, a_x_bf16x8, b_z_bf16x8);
|
|
418
|
+
covariance_yx_f32x4 = vbfdotq_f32(covariance_yx_f32x4, a_y_bf16x8, b_x_bf16x8);
|
|
419
|
+
covariance_yy_f32x4 = vbfdotq_f32(covariance_yy_f32x4, a_y_bf16x8, b_y_bf16x8);
|
|
420
|
+
covariance_yz_f32x4 = vbfdotq_f32(covariance_yz_f32x4, a_y_bf16x8, b_z_bf16x8);
|
|
421
|
+
covariance_zx_f32x4 = vbfdotq_f32(covariance_zx_f32x4, a_z_bf16x8, b_x_bf16x8);
|
|
422
|
+
covariance_zy_f32x4 = vbfdotq_f32(covariance_zy_f32x4, a_z_bf16x8, b_y_bf16x8);
|
|
423
|
+
covariance_zz_f32x4 = vbfdotq_f32(covariance_zz_f32x4, a_z_bf16x8, b_z_bf16x8);
|
|
424
|
+
|
|
425
|
+
// Norm-squared per point set: Σ (x² + y² + z²).
|
|
426
|
+
norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_x_bf16x8, a_x_bf16x8);
|
|
427
|
+
norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_y_bf16x8, a_y_bf16x8);
|
|
428
|
+
norm_squared_a_f32x4 = vbfdotq_f32(norm_squared_a_f32x4, a_z_bf16x8, a_z_bf16x8);
|
|
429
|
+
norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_x_bf16x8, b_x_bf16x8);
|
|
430
|
+
norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_y_bf16x8, b_y_bf16x8);
|
|
431
|
+
norm_squared_b_f32x4 = vbfdotq_f32(norm_squared_b_f32x4, b_z_bf16x8, b_z_bf16x8);
|
|
596
432
|
}
|
|
597
433
|
|
|
598
|
-
// 4-point
|
|
434
|
+
// 4-point and partial (1-3) tails: keep the widen+fmaq path on f32x4 channel vectors.
|
|
435
|
+
// These branches run at most once each, so we skip another vbfdotq variant for them.
|
|
436
|
+
float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
|
|
599
437
|
for (; i + 4 <= n; i += 4) {
|
|
600
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &
|
|
601
|
-
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
438
|
+
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
|
|
439
|
+
nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
|
|
440
|
+
sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
|
|
441
|
+
sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
|
|
442
|
+
sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
|
|
443
|
+
sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
|
|
444
|
+
sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
|
|
445
|
+
sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
|
|
446
|
+
covariance_xx_f32x4 = vfmaq_f32(covariance_xx_f32x4, a_x_f32x4, b_x_f32x4);
|
|
447
|
+
covariance_xy_f32x4 = vfmaq_f32(covariance_xy_f32x4, a_x_f32x4, b_y_f32x4);
|
|
448
|
+
covariance_xz_f32x4 = vfmaq_f32(covariance_xz_f32x4, a_x_f32x4, b_z_f32x4);
|
|
449
|
+
covariance_yx_f32x4 = vfmaq_f32(covariance_yx_f32x4, a_y_f32x4, b_x_f32x4);
|
|
450
|
+
covariance_yy_f32x4 = vfmaq_f32(covariance_yy_f32x4, a_y_f32x4, b_y_f32x4);
|
|
451
|
+
covariance_yz_f32x4 = vfmaq_f32(covariance_yz_f32x4, a_y_f32x4, b_z_f32x4);
|
|
452
|
+
covariance_zx_f32x4 = vfmaq_f32(covariance_zx_f32x4, a_z_f32x4, b_x_f32x4);
|
|
453
|
+
covariance_zy_f32x4 = vfmaq_f32(covariance_zy_f32x4, a_z_f32x4, b_y_f32x4);
|
|
454
|
+
covariance_zz_f32x4 = vfmaq_f32(covariance_zz_f32x4, a_z_f32x4, b_z_f32x4);
|
|
455
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_f32x4, a_x_f32x4);
|
|
456
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_f32x4, a_y_f32x4);
|
|
457
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_f32x4, a_z_f32x4);
|
|
458
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_f32x4, b_x_f32x4);
|
|
459
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_f32x4, b_y_f32x4);
|
|
460
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_f32x4, b_z_f32x4);
|
|
620
461
|
}
|
|
621
|
-
|
|
622
|
-
// Partial tail: handle remaining 1-3 points with vectorized partial deinterleave
|
|
623
462
|
if (i < n) {
|
|
624
|
-
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &
|
|
625
|
-
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
463
|
+
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
|
|
464
|
+
nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
|
|
465
|
+
sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
|
|
466
|
+
sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
|
|
467
|
+
sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
|
|
468
|
+
sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
|
|
469
|
+
sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
|
|
470
|
+
sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
|
|
471
|
+
covariance_xx_f32x4 = vfmaq_f32(covariance_xx_f32x4, a_x_f32x4, b_x_f32x4);
|
|
472
|
+
covariance_xy_f32x4 = vfmaq_f32(covariance_xy_f32x4, a_x_f32x4, b_y_f32x4);
|
|
473
|
+
covariance_xz_f32x4 = vfmaq_f32(covariance_xz_f32x4, a_x_f32x4, b_z_f32x4);
|
|
474
|
+
covariance_yx_f32x4 = vfmaq_f32(covariance_yx_f32x4, a_y_f32x4, b_x_f32x4);
|
|
475
|
+
covariance_yy_f32x4 = vfmaq_f32(covariance_yy_f32x4, a_y_f32x4, b_y_f32x4);
|
|
476
|
+
covariance_yz_f32x4 = vfmaq_f32(covariance_yz_f32x4, a_y_f32x4, b_z_f32x4);
|
|
477
|
+
covariance_zx_f32x4 = vfmaq_f32(covariance_zx_f32x4, a_z_f32x4, b_x_f32x4);
|
|
478
|
+
covariance_zy_f32x4 = vfmaq_f32(covariance_zy_f32x4, a_z_f32x4, b_y_f32x4);
|
|
479
|
+
covariance_zz_f32x4 = vfmaq_f32(covariance_zz_f32x4, a_z_f32x4, b_z_f32x4);
|
|
480
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_f32x4, a_x_f32x4);
|
|
481
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_f32x4, a_y_f32x4);
|
|
482
|
+
norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_f32x4, a_z_f32x4);
|
|
483
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_f32x4, b_x_f32x4);
|
|
484
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_f32x4, b_y_f32x4);
|
|
485
|
+
norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_f32x4, b_z_f32x4);
|
|
644
486
|
}
|
|
645
487
|
|
|
646
|
-
// Combine dual accumulators
|
|
647
|
-
float32x4_t sum_a_x_f32x4 = vaddq_f32(sum_a_x_a_f32x4, sum_a_x_b_f32x4);
|
|
648
|
-
float32x4_t sum_a_y_f32x4 = vaddq_f32(sum_a_y_a_f32x4, sum_a_y_b_f32x4);
|
|
649
|
-
float32x4_t sum_a_z_f32x4 = vaddq_f32(sum_a_z_a_f32x4, sum_a_z_b_f32x4);
|
|
650
|
-
float32x4_t sum_b_x_f32x4 = vaddq_f32(sum_b_x_a_f32x4, sum_b_x_b_f32x4);
|
|
651
|
-
float32x4_t sum_b_y_f32x4 = vaddq_f32(sum_b_y_a_f32x4, sum_b_y_b_f32x4);
|
|
652
|
-
float32x4_t sum_b_z_f32x4 = vaddq_f32(sum_b_z_a_f32x4, sum_b_z_b_f32x4);
|
|
653
|
-
float32x4_t cov_xx_f32x4 = vaddq_f32(cov_xx_a_f32x4, cov_xx_b_f32x4);
|
|
654
|
-
float32x4_t cov_xy_f32x4 = vaddq_f32(cov_xy_a_f32x4, cov_xy_b_f32x4);
|
|
655
|
-
float32x4_t cov_xz_f32x4 = vaddq_f32(cov_xz_a_f32x4, cov_xz_b_f32x4);
|
|
656
|
-
float32x4_t cov_yx_f32x4 = vaddq_f32(cov_yx_a_f32x4, cov_yx_b_f32x4);
|
|
657
|
-
float32x4_t cov_yy_f32x4 = vaddq_f32(cov_yy_a_f32x4, cov_yy_b_f32x4);
|
|
658
|
-
float32x4_t cov_yz_f32x4 = vaddq_f32(cov_yz_a_f32x4, cov_yz_b_f32x4);
|
|
659
|
-
float32x4_t cov_zx_f32x4 = vaddq_f32(cov_zx_a_f32x4, cov_zx_b_f32x4);
|
|
660
|
-
float32x4_t cov_zy_f32x4 = vaddq_f32(cov_zy_a_f32x4, cov_zy_b_f32x4);
|
|
661
|
-
float32x4_t cov_zz_f32x4 = vaddq_f32(cov_zz_a_f32x4, cov_zz_b_f32x4);
|
|
662
|
-
float32x4_t variance_a_f32x4 = vaddq_f32(variance_a_a_f32x4, variance_a_b_f32x4);
|
|
663
|
-
|
|
664
488
|
// Reduce vector accumulators
|
|
665
489
|
nk_f32_t sum_a_x = vaddvq_f32(sum_a_x_f32x4);
|
|
666
490
|
nk_f32_t sum_a_y = vaddvq_f32(sum_a_y_f32x4);
|
|
@@ -669,16 +493,17 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
669
493
|
nk_f32_t sum_b_y = vaddvq_f32(sum_b_y_f32x4);
|
|
670
494
|
nk_f32_t sum_b_z = vaddvq_f32(sum_b_z_f32x4);
|
|
671
495
|
|
|
672
|
-
nk_f32_t covariance_x_x = vaddvq_f32(
|
|
673
|
-
nk_f32_t covariance_x_y = vaddvq_f32(
|
|
674
|
-
nk_f32_t covariance_x_z = vaddvq_f32(
|
|
675
|
-
nk_f32_t covariance_y_x = vaddvq_f32(
|
|
676
|
-
nk_f32_t covariance_y_y = vaddvq_f32(
|
|
677
|
-
nk_f32_t covariance_y_z = vaddvq_f32(
|
|
678
|
-
nk_f32_t covariance_z_x = vaddvq_f32(
|
|
679
|
-
nk_f32_t covariance_z_y = vaddvq_f32(
|
|
680
|
-
nk_f32_t covariance_z_z = vaddvq_f32(
|
|
681
|
-
nk_f32_t
|
|
496
|
+
nk_f32_t covariance_x_x = vaddvq_f32(covariance_xx_f32x4);
|
|
497
|
+
nk_f32_t covariance_x_y = vaddvq_f32(covariance_xy_f32x4);
|
|
498
|
+
nk_f32_t covariance_x_z = vaddvq_f32(covariance_xz_f32x4);
|
|
499
|
+
nk_f32_t covariance_y_x = vaddvq_f32(covariance_yx_f32x4);
|
|
500
|
+
nk_f32_t covariance_y_y = vaddvq_f32(covariance_yy_f32x4);
|
|
501
|
+
nk_f32_t covariance_y_z = vaddvq_f32(covariance_yz_f32x4);
|
|
502
|
+
nk_f32_t covariance_z_x = vaddvq_f32(covariance_zx_f32x4);
|
|
503
|
+
nk_f32_t covariance_z_y = vaddvq_f32(covariance_zy_f32x4);
|
|
504
|
+
nk_f32_t covariance_z_z = vaddvq_f32(covariance_zz_f32x4);
|
|
505
|
+
nk_f32_t norm_squared_a = vaddvq_f32(norm_squared_a_f32x4);
|
|
506
|
+
nk_f32_t norm_squared_b = vaddvq_f32(norm_squared_b_f32x4);
|
|
682
507
|
|
|
683
508
|
// Compute centroids
|
|
684
509
|
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
@@ -692,9 +517,15 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
692
517
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
693
518
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
694
519
|
|
|
695
|
-
//
|
|
696
|
-
nk_f32_t
|
|
697
|
-
|
|
520
|
+
// Centered norm-squared via parallel-axis identity; clamp at zero for numeric safety.
|
|
521
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a -
|
|
522
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
523
|
+
centroid_a_z * centroid_a_z);
|
|
524
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b -
|
|
525
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
526
|
+
centroid_b_z * centroid_b_z);
|
|
527
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
528
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
698
529
|
|
|
699
530
|
// Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
|
|
700
531
|
covariance_x_x -= n * centroid_a_x * centroid_b_x;
|
|
@@ -710,49 +541,78 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
710
541
|
// Compute SVD
|
|
711
542
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
712
543
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
713
|
-
nk_f32_t svd_u[9], svd_s[9], svd_v[9];
|
|
714
|
-
nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
|
|
715
|
-
|
|
716
|
-
// R = V * Uᵀ
|
|
717
|
-
nk_f32_t r[9];
|
|
718
|
-
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
719
|
-
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
720
|
-
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
721
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
722
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
723
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
724
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
725
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
726
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
727
|
-
|
|
728
|
-
// Handle reflection and compute scale: c = trace(D × S) / variance(a)
|
|
729
|
-
// D = diag(1, 1, det(R)), svd_s contains proper positive singular values on diagonal
|
|
730
|
-
nk_f32_t rotation_det = nk_det3x3_f32_(r);
|
|
731
|
-
nk_f32_t sign_det = rotation_det < 0 ? -1.0f : 1.0f;
|
|
732
|
-
nk_f32_t trace_scaled_s = svd_s[0] + svd_s[4] + sign_det * svd_s[8];
|
|
733
|
-
nk_f32_t c = trace_scaled_s / ((nk_f32_t)n * variance_a);
|
|
734
|
-
if (scale) *scale = c;
|
|
735
544
|
|
|
736
|
-
if (
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
545
|
+
// Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
|
|
546
|
+
nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
547
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
548
|
+
cross_covariance[8] * cross_covariance[8];
|
|
549
|
+
nk_f32_t covariance_offdiagonal_norm_squared =
|
|
550
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
551
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
552
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
553
|
+
nk_f32_t optimal_rotation[9];
|
|
554
|
+
nk_f32_t trace_rotation_covariance;
|
|
555
|
+
nk_f32_t c;
|
|
556
|
+
if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
|
|
557
|
+
cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
|
|
558
|
+
optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
|
|
559
|
+
optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
|
|
560
|
+
optimal_rotation[8] = 1;
|
|
561
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
562
|
+
c = centered_norm_squared_a > 0.0f ? trace_rotation_covariance / centered_norm_squared_a : 0.0f;
|
|
747
563
|
}
|
|
564
|
+
else {
|
|
565
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
566
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
567
|
+
|
|
568
|
+
// R = V * Uᵀ
|
|
569
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
570
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
571
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
572
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
573
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
574
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
575
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
576
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
577
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
578
|
+
|
|
579
|
+
// Handle reflection and compute scale: c = trace(D · S) / ‖a-ā‖²
|
|
580
|
+
// D = diag(1, 1, det(R)), svd_diagonal contains proper positive singular values on diagonal
|
|
581
|
+
nk_f32_t rotation_det = nk_det3x3_f32_(optimal_rotation);
|
|
582
|
+
nk_f32_t sign_det = rotation_det < 0 ? -1.0f : 1.0f;
|
|
583
|
+
nk_f32_t trace_scaled_s = svd_diagonal[0] + svd_diagonal[4] + sign_det * svd_diagonal[8];
|
|
584
|
+
c = centered_norm_squared_a > 0.0f ? trace_scaled_s / centered_norm_squared_a : 0.0f;
|
|
585
|
+
|
|
586
|
+
if (rotation_det < 0) {
|
|
587
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
588
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
589
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
590
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
591
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
592
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
593
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
594
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
595
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
596
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
597
|
+
}
|
|
598
|
+
|
|
599
|
+
trace_rotation_covariance =
|
|
600
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
601
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
602
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
603
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
604
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
605
|
+
}
|
|
606
|
+
if (scale) *scale = c;
|
|
748
607
|
|
|
749
608
|
// Output rotation matrix
|
|
750
609
|
if (rotation)
|
|
751
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
610
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
752
611
|
|
|
753
|
-
//
|
|
754
|
-
nk_f32_t sum_squared =
|
|
755
|
-
|
|
612
|
+
// Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
|
|
613
|
+
nk_f32_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
|
|
614
|
+
2.0f * c * trace_rotation_covariance;
|
|
615
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
756
616
|
*result = nk_f32_sqrt_neon(sum_squared * inv_n);
|
|
757
617
|
}
|
|
758
618
|
|