numkong 7.5.0 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. package/binding.gyp +18 -0
  2. package/c/dispatch_e5m2.c +23 -3
  3. package/include/numkong/capabilities.h +1 -1
  4. package/include/numkong/cast/README.md +3 -0
  5. package/include/numkong/cast/haswell.h +28 -64
  6. package/include/numkong/cast/serial.h +17 -0
  7. package/include/numkong/cast/skylake.h +67 -52
  8. package/include/numkong/cast.h +1 -0
  9. package/include/numkong/dot/README.md +1 -0
  10. package/include/numkong/dot/haswell.h +92 -13
  11. package/include/numkong/dot/serial.h +15 -0
  12. package/include/numkong/dot/skylake.h +61 -14
  13. package/include/numkong/dots/README.md +2 -0
  14. package/include/numkong/dots/graniteamx.h +434 -0
  15. package/include/numkong/dots/haswell.h +28 -28
  16. package/include/numkong/dots/sapphireamx.h +1 -1
  17. package/include/numkong/dots/serial.h +23 -8
  18. package/include/numkong/dots/skylake.h +28 -23
  19. package/include/numkong/dots.h +12 -0
  20. package/include/numkong/each/serial.h +18 -1
  21. package/include/numkong/geospatial/serial.h +14 -3
  22. package/include/numkong/maxsim/serial.h +15 -0
  23. package/include/numkong/mesh/README.md +50 -44
  24. package/include/numkong/mesh/genoa.h +462 -0
  25. package/include/numkong/mesh/haswell.h +806 -933
  26. package/include/numkong/mesh/neon.h +871 -943
  27. package/include/numkong/mesh/neonbfdot.h +382 -522
  28. package/include/numkong/mesh/neonfhm.h +676 -0
  29. package/include/numkong/mesh/rvv.h +404 -319
  30. package/include/numkong/mesh/serial.h +204 -162
  31. package/include/numkong/mesh/skylake.h +1029 -1585
  32. package/include/numkong/mesh/v128relaxed.h +403 -377
  33. package/include/numkong/mesh.h +38 -0
  34. package/include/numkong/reduce/serial.h +15 -1
  35. package/include/numkong/sparse/serial.h +17 -2
  36. package/include/numkong/spatial/genoa.h +0 -68
  37. package/include/numkong/spatial/haswell.h +98 -56
  38. package/include/numkong/spatial/serial.h +15 -0
  39. package/include/numkong/spatial/skylake.h +114 -54
  40. package/include/numkong/spatial.h +0 -12
  41. package/include/numkong/spatials/graniteamx.h +128 -0
  42. package/include/numkong/spatials/serial.h +18 -1
  43. package/include/numkong/spatials/skylake.h +2 -2
  44. package/include/numkong/spatials.h +17 -0
  45. package/include/numkong/tensor.hpp +107 -23
  46. package/javascript/numkong.c +3 -2
  47. package/package.json +7 -7
  48. package/wasm/numkong.wasm +0 -0
@@ -81,10 +81,6 @@ NK_INTERNAL nk_f64_t nk_reduce_stable_f64x2_neon_(float64x2_t values_f64x2) {
81
81
  return sum + compensation;
82
82
  }
83
83
 
84
- NK_INTERNAL void nk_rotation_from_svd_f64_neon_(nk_f64_t const *svd_u, nk_f64_t const *svd_v, nk_f64_t *rotation) {
85
- nk_rotation_from_svd_f64_serial_(svd_u, svd_v, rotation);
86
- }
87
-
88
84
  NK_INTERNAL void nk_accumulate_square_f64x2_neon_(float64x2_t *sum_f64x2, float64x2_t *compensation_f64x2,
89
85
  float64x2_t values_f64x2) {
90
86
  float64x2_t product_f64x2 = vmulq_f64(values_f64x2, values_f64x2);
@@ -97,260 +93,6 @@ NK_INTERNAL void nk_accumulate_square_f64x2_neon_(float64x2_t *sum_f64x2, float6
97
93
  *compensation_f64x2 = vaddq_f64(*compensation_f64x2, vaddq_f64(sum_error_f64x2, product_error_f64x2));
98
94
  }
99
95
 
100
- NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_neon_( //
101
- nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
102
- nk_f64_t centroid_a_y, nk_f64_t centroid_a_z, nk_f64_t centroid_b_x, nk_f64_t centroid_b_y, nk_f64_t centroid_b_z) {
103
- float64x2_t scaled_rotation_x_x_f64x2 = vdupq_n_f64(scale * r[0]);
104
- float64x2_t scaled_rotation_x_y_f64x2 = vdupq_n_f64(scale * r[1]);
105
- float64x2_t scaled_rotation_x_z_f64x2 = vdupq_n_f64(scale * r[2]);
106
- float64x2_t scaled_rotation_y_x_f64x2 = vdupq_n_f64(scale * r[3]);
107
- float64x2_t scaled_rotation_y_y_f64x2 = vdupq_n_f64(scale * r[4]);
108
- float64x2_t scaled_rotation_y_z_f64x2 = vdupq_n_f64(scale * r[5]);
109
- float64x2_t scaled_rotation_z_x_f64x2 = vdupq_n_f64(scale * r[6]);
110
- float64x2_t scaled_rotation_z_y_f64x2 = vdupq_n_f64(scale * r[7]);
111
- float64x2_t scaled_rotation_z_z_f64x2 = vdupq_n_f64(scale * r[8]);
112
- float64x2_t centroid_a_x_f64x2 = vdupq_n_f64(centroid_a_x), centroid_a_y_f64x2 = vdupq_n_f64(centroid_a_y);
113
- float64x2_t centroid_a_z_f64x2 = vdupq_n_f64(centroid_a_z), centroid_b_x_f64x2 = vdupq_n_f64(centroid_b_x);
114
- float64x2_t centroid_b_y_f64x2 = vdupq_n_f64(centroid_b_y), centroid_b_z_f64x2 = vdupq_n_f64(centroid_b_z);
115
- float64x2_t sum_squared_low_f64x2 = vdupq_n_f64(0.0), sum_squared_high_f64x2 = vdupq_n_f64(0.0);
116
- nk_size_t index = 0;
117
-
118
- for (; index + 4 <= n; index += 4) {
119
- float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
120
- nk_deinterleave_f32x4_neon_(a + index * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4),
121
- nk_deinterleave_f32x4_neon_(b + index * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
122
-
123
- float64x2_t centered_a_x_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_x_f32x4)), centroid_a_x_f64x2);
124
- float64x2_t centered_a_x_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_x_f32x4), centroid_a_x_f64x2);
125
- float64x2_t centered_a_y_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_y_f32x4)), centroid_a_y_f64x2);
126
- float64x2_t centered_a_y_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_y_f32x4), centroid_a_y_f64x2);
127
- float64x2_t centered_a_z_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_z_f32x4)), centroid_a_z_f64x2);
128
- float64x2_t centered_a_z_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_z_f32x4), centroid_a_z_f64x2);
129
- float64x2_t centered_b_x_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(b_x_f32x4)), centroid_b_x_f64x2);
130
- float64x2_t centered_b_x_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(b_x_f32x4), centroid_b_x_f64x2);
131
- float64x2_t centered_b_y_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(b_y_f32x4)), centroid_b_y_f64x2);
132
- float64x2_t centered_b_y_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(b_y_f32x4), centroid_b_y_f64x2);
133
- float64x2_t centered_b_z_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(b_z_f32x4)), centroid_b_z_f64x2);
134
- float64x2_t centered_b_z_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(b_z_f32x4), centroid_b_z_f64x2);
135
-
136
- float64x2_t rotated_a_x_low_f64x2 = vfmaq_f64(
137
- vfmaq_f64(vmulq_f64(scaled_rotation_x_x_f64x2, centered_a_x_low_f64x2), scaled_rotation_x_y_f64x2,
138
- centered_a_y_low_f64x2),
139
- scaled_rotation_x_z_f64x2, centered_a_z_low_f64x2);
140
- float64x2_t rotated_a_x_high_f64x2 = vfmaq_f64(
141
- vfmaq_f64(vmulq_f64(scaled_rotation_x_x_f64x2, centered_a_x_high_f64x2), scaled_rotation_x_y_f64x2,
142
- centered_a_y_high_f64x2),
143
- scaled_rotation_x_z_f64x2, centered_a_z_high_f64x2);
144
- float64x2_t rotated_a_y_low_f64x2 = vfmaq_f64(
145
- vfmaq_f64(vmulq_f64(scaled_rotation_y_x_f64x2, centered_a_x_low_f64x2), scaled_rotation_y_y_f64x2,
146
- centered_a_y_low_f64x2),
147
- scaled_rotation_y_z_f64x2, centered_a_z_low_f64x2);
148
- float64x2_t rotated_a_y_high_f64x2 = vfmaq_f64(
149
- vfmaq_f64(vmulq_f64(scaled_rotation_y_x_f64x2, centered_a_x_high_f64x2), scaled_rotation_y_y_f64x2,
150
- centered_a_y_high_f64x2),
151
- scaled_rotation_y_z_f64x2, centered_a_z_high_f64x2);
152
- float64x2_t rotated_a_z_low_f64x2 = vfmaq_f64(
153
- vfmaq_f64(vmulq_f64(scaled_rotation_z_x_f64x2, centered_a_x_low_f64x2), scaled_rotation_z_y_f64x2,
154
- centered_a_y_low_f64x2),
155
- scaled_rotation_z_z_f64x2, centered_a_z_low_f64x2);
156
- float64x2_t rotated_a_z_high_f64x2 = vfmaq_f64(
157
- vfmaq_f64(vmulq_f64(scaled_rotation_z_x_f64x2, centered_a_x_high_f64x2), scaled_rotation_z_y_f64x2,
158
- centered_a_y_high_f64x2),
159
- scaled_rotation_z_z_f64x2, centered_a_z_high_f64x2);
160
-
161
- float64x2_t delta_x_low_f64x2 = vsubq_f64(rotated_a_x_low_f64x2, centered_b_x_low_f64x2);
162
- float64x2_t delta_x_high_f64x2 = vsubq_f64(rotated_a_x_high_f64x2, centered_b_x_high_f64x2);
163
- float64x2_t delta_y_low_f64x2 = vsubq_f64(rotated_a_y_low_f64x2, centered_b_y_low_f64x2);
164
- float64x2_t delta_y_high_f64x2 = vsubq_f64(rotated_a_y_high_f64x2, centered_b_y_high_f64x2);
165
- float64x2_t delta_z_low_f64x2 = vsubq_f64(rotated_a_z_low_f64x2, centered_b_z_low_f64x2);
166
- float64x2_t delta_z_high_f64x2 = vsubq_f64(rotated_a_z_high_f64x2, centered_b_z_high_f64x2);
167
-
168
- sum_squared_low_f64x2 = vfmaq_f64(sum_squared_low_f64x2, delta_x_low_f64x2, delta_x_low_f64x2),
169
- sum_squared_high_f64x2 = vfmaq_f64(sum_squared_high_f64x2, delta_x_high_f64x2, delta_x_high_f64x2);
170
- sum_squared_low_f64x2 = vfmaq_f64(sum_squared_low_f64x2, delta_y_low_f64x2, delta_y_low_f64x2),
171
- sum_squared_high_f64x2 = vfmaq_f64(sum_squared_high_f64x2, delta_y_high_f64x2, delta_y_high_f64x2);
172
- sum_squared_low_f64x2 = vfmaq_f64(sum_squared_low_f64x2, delta_z_low_f64x2, delta_z_low_f64x2),
173
- sum_squared_high_f64x2 = vfmaq_f64(sum_squared_high_f64x2, delta_z_high_f64x2, delta_z_high_f64x2);
174
- }
175
-
176
- nk_f64_t sum_squared = vaddvq_f64(vaddq_f64(sum_squared_low_f64x2, sum_squared_high_f64x2));
177
- for (; index < n; ++index) {
178
- nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
179
- centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
180
- centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
181
- nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
182
- centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
183
- centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
184
- nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
185
- rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
186
- rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
187
- nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
188
- delta_z = rotated_a_z - centered_b_z;
189
- sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
190
- }
191
-
192
- return sum_squared;
193
- }
194
-
195
- /* Compute sum of squared distances for f64 after applying rotation (and optional scale).
196
- *
197
- * Optimization: 2x loop unrolling with multiple accumulators hides FMA latency (3-7 cycles).
198
- */
199
- NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_neon_(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *r,
200
- nk_f64_t scale, nk_f64_t centroid_a_x, nk_f64_t centroid_a_y,
201
- nk_f64_t centroid_a_z, nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
202
- nk_f64_t centroid_b_z) {
203
- // Broadcast scaled rotation matrix elements
204
- float64x2_t scaled_rotation_x_x_f64x2 = vdupq_n_f64(scale * r[0]);
205
- float64x2_t scaled_rotation_x_y_f64x2 = vdupq_n_f64(scale * r[1]);
206
- float64x2_t scaled_rotation_x_z_f64x2 = vdupq_n_f64(scale * r[2]);
207
- float64x2_t scaled_rotation_y_x_f64x2 = vdupq_n_f64(scale * r[3]);
208
- float64x2_t scaled_rotation_y_y_f64x2 = vdupq_n_f64(scale * r[4]);
209
- float64x2_t scaled_rotation_y_z_f64x2 = vdupq_n_f64(scale * r[5]);
210
- float64x2_t scaled_rotation_z_x_f64x2 = vdupq_n_f64(scale * r[6]);
211
- float64x2_t scaled_rotation_z_y_f64x2 = vdupq_n_f64(scale * r[7]);
212
- float64x2_t scaled_rotation_z_z_f64x2 = vdupq_n_f64(scale * r[8]);
213
-
214
- // Broadcast centroids
215
- float64x2_t centroid_a_x_f64x2 = vdupq_n_f64(centroid_a_x);
216
- float64x2_t centroid_a_y_f64x2 = vdupq_n_f64(centroid_a_y);
217
- float64x2_t centroid_a_z_f64x2 = vdupq_n_f64(centroid_a_z);
218
- float64x2_t centroid_b_x_f64x2 = vdupq_n_f64(centroid_b_x);
219
- float64x2_t centroid_b_y_f64x2 = vdupq_n_f64(centroid_b_y);
220
- float64x2_t centroid_b_z_f64x2 = vdupq_n_f64(centroid_b_z);
221
-
222
- // Two independent accumulators to hide FMA latency
223
- float64x2_t sum_squared_a_f64x2 = vdupq_n_f64(0), sum_squared_a_compensation_f64x2 = vdupq_n_f64(0);
224
- float64x2_t sum_squared_b_f64x2 = vdupq_n_f64(0), sum_squared_b_compensation_f64x2 = vdupq_n_f64(0);
225
- nk_size_t j = 0;
226
-
227
- // Main loop: process 4 points per iteration (2x unrolled, 2 points per batch)
228
- for (; j + 4 <= n; j += 4) {
229
- // First batch of 2 points
230
- float64x2_t a1_x_f64x2, a1_y_f64x2, a1_z_f64x2, b1_x_f64x2, b1_y_f64x2, b1_z_f64x2;
231
- nk_deinterleave_f64x2_neon_(a + j * 3, &a1_x_f64x2, &a1_y_f64x2, &a1_z_f64x2);
232
- nk_deinterleave_f64x2_neon_(b + j * 3, &b1_x_f64x2, &b1_y_f64x2, &b1_z_f64x2);
233
-
234
- // Second batch of 2 points
235
- float64x2_t a2_x_f64x2, a2_y_f64x2, a2_z_f64x2, b2_x_f64x2, b2_y_f64x2, b2_z_f64x2;
236
- nk_deinterleave_f64x2_neon_(a + (j + 2) * 3, &a2_x_f64x2, &a2_y_f64x2, &a2_z_f64x2);
237
- nk_deinterleave_f64x2_neon_(b + (j + 2) * 3, &b2_x_f64x2, &b2_y_f64x2, &b2_z_f64x2);
238
-
239
- // Center first batch
240
- float64x2_t centered_a1_x_f64x2 = vsubq_f64(a1_x_f64x2, centroid_a_x_f64x2);
241
- float64x2_t centered_a1_y_f64x2 = vsubq_f64(a1_y_f64x2, centroid_a_y_f64x2);
242
- float64x2_t centered_a1_z_f64x2 = vsubq_f64(a1_z_f64x2, centroid_a_z_f64x2);
243
- float64x2_t centered_b1_x_f64x2 = vsubq_f64(b1_x_f64x2, centroid_b_x_f64x2);
244
- float64x2_t centered_b1_y_f64x2 = vsubq_f64(b1_y_f64x2, centroid_b_y_f64x2);
245
- float64x2_t centered_b1_z_f64x2 = vsubq_f64(b1_z_f64x2, centroid_b_z_f64x2);
246
-
247
- // Center second batch
248
- float64x2_t centered_a2_x_f64x2 = vsubq_f64(a2_x_f64x2, centroid_a_x_f64x2);
249
- float64x2_t centered_a2_y_f64x2 = vsubq_f64(a2_y_f64x2, centroid_a_y_f64x2);
250
- float64x2_t centered_a2_z_f64x2 = vsubq_f64(a2_z_f64x2, centroid_a_z_f64x2);
251
- float64x2_t centered_b2_x_f64x2 = vsubq_f64(b2_x_f64x2, centroid_b_x_f64x2);
252
- float64x2_t centered_b2_y_f64x2 = vsubq_f64(b2_y_f64x2, centroid_b_y_f64x2);
253
- float64x2_t centered_b2_z_f64x2 = vsubq_f64(b2_z_f64x2, centroid_b_z_f64x2);
254
-
255
- // Rotate and scale first batch
256
- float64x2_t rotated_a1_x_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_x_x_f64x2, centered_a1_x_f64x2),
257
- scaled_rotation_x_y_f64x2, centered_a1_y_f64x2),
258
- scaled_rotation_x_z_f64x2, centered_a1_z_f64x2);
259
- float64x2_t rotated_a1_y_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_y_x_f64x2, centered_a1_x_f64x2),
260
- scaled_rotation_y_y_f64x2, centered_a1_y_f64x2),
261
- scaled_rotation_y_z_f64x2, centered_a1_z_f64x2);
262
- float64x2_t rotated_a1_z_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_z_x_f64x2, centered_a1_x_f64x2),
263
- scaled_rotation_z_y_f64x2, centered_a1_y_f64x2),
264
- scaled_rotation_z_z_f64x2, centered_a1_z_f64x2);
265
-
266
- // Rotate and scale second batch
267
- float64x2_t rotated_a2_x_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_x_x_f64x2, centered_a2_x_f64x2),
268
- scaled_rotation_x_y_f64x2, centered_a2_y_f64x2),
269
- scaled_rotation_x_z_f64x2, centered_a2_z_f64x2);
270
- float64x2_t rotated_a2_y_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_y_x_f64x2, centered_a2_x_f64x2),
271
- scaled_rotation_y_y_f64x2, centered_a2_y_f64x2),
272
- scaled_rotation_y_z_f64x2, centered_a2_z_f64x2);
273
- float64x2_t rotated_a2_z_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_z_x_f64x2, centered_a2_x_f64x2),
274
- scaled_rotation_z_y_f64x2, centered_a2_y_f64x2),
275
- scaled_rotation_z_z_f64x2, centered_a2_z_f64x2);
276
-
277
- // Deltas
278
- float64x2_t delta1_x_f64x2 = vsubq_f64(rotated_a1_x_f64x2, centered_b1_x_f64x2);
279
- float64x2_t delta1_y_f64x2 = vsubq_f64(rotated_a1_y_f64x2, centered_b1_y_f64x2);
280
- float64x2_t delta1_z_f64x2 = vsubq_f64(rotated_a1_z_f64x2, centered_b1_z_f64x2);
281
- float64x2_t delta2_x_f64x2 = vsubq_f64(rotated_a2_x_f64x2, centered_b2_x_f64x2);
282
- float64x2_t delta2_y_f64x2 = vsubq_f64(rotated_a2_y_f64x2, centered_b2_y_f64x2);
283
- float64x2_t delta2_z_f64x2 = vsubq_f64(rotated_a2_z_f64x2, centered_b2_z_f64x2);
284
-
285
- // Accumulate to independent accumulators (interleaved for latency hiding)
286
- nk_accumulate_square_f64x2_neon_(&sum_squared_a_f64x2, &sum_squared_a_compensation_f64x2, delta1_x_f64x2);
287
- nk_accumulate_square_f64x2_neon_(&sum_squared_b_f64x2, &sum_squared_b_compensation_f64x2, delta2_x_f64x2);
288
- nk_accumulate_square_f64x2_neon_(&sum_squared_a_f64x2, &sum_squared_a_compensation_f64x2, delta1_y_f64x2);
289
- nk_accumulate_square_f64x2_neon_(&sum_squared_b_f64x2, &sum_squared_b_compensation_f64x2, delta2_y_f64x2);
290
- nk_accumulate_square_f64x2_neon_(&sum_squared_a_f64x2, &sum_squared_a_compensation_f64x2, delta1_z_f64x2);
291
- nk_accumulate_square_f64x2_neon_(&sum_squared_b_f64x2, &sum_squared_b_compensation_f64x2, delta2_z_f64x2);
292
- }
293
-
294
- // Handle remaining 2 points
295
- if (j + 2 <= n) {
296
- float64x2_t a_x_f64x2, a_y_f64x2, a_z_f64x2, b_x_f64x2, b_y_f64x2, b_z_f64x2;
297
- nk_deinterleave_f64x2_neon_(a + j * 3, &a_x_f64x2, &a_y_f64x2, &a_z_f64x2);
298
- nk_deinterleave_f64x2_neon_(b + j * 3, &b_x_f64x2, &b_y_f64x2, &b_z_f64x2);
299
-
300
- float64x2_t centered_a_x_f64x2 = vsubq_f64(a_x_f64x2, centroid_a_x_f64x2);
301
- float64x2_t centered_a_y_f64x2 = vsubq_f64(a_y_f64x2, centroid_a_y_f64x2);
302
- float64x2_t centered_a_z_f64x2 = vsubq_f64(a_z_f64x2, centroid_a_z_f64x2);
303
- float64x2_t centered_b_x_f64x2 = vsubq_f64(b_x_f64x2, centroid_b_x_f64x2);
304
- float64x2_t centered_b_y_f64x2 = vsubq_f64(b_y_f64x2, centroid_b_y_f64x2);
305
- float64x2_t centered_b_z_f64x2 = vsubq_f64(b_z_f64x2, centroid_b_z_f64x2);
306
-
307
- float64x2_t rotated_a_x_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_x_x_f64x2, centered_a_x_f64x2),
308
- scaled_rotation_x_y_f64x2, centered_a_y_f64x2),
309
- scaled_rotation_x_z_f64x2, centered_a_z_f64x2);
310
- float64x2_t rotated_a_y_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_y_x_f64x2, centered_a_x_f64x2),
311
- scaled_rotation_y_y_f64x2, centered_a_y_f64x2),
312
- scaled_rotation_y_z_f64x2, centered_a_z_f64x2);
313
- float64x2_t rotated_a_z_f64x2 = vfmaq_f64(vfmaq_f64(vmulq_f64(scaled_rotation_z_x_f64x2, centered_a_x_f64x2),
314
- scaled_rotation_z_y_f64x2, centered_a_y_f64x2),
315
- scaled_rotation_z_z_f64x2, centered_a_z_f64x2);
316
-
317
- float64x2_t delta_x_f64x2 = vsubq_f64(rotated_a_x_f64x2, centered_b_x_f64x2);
318
- float64x2_t delta_y_f64x2 = vsubq_f64(rotated_a_y_f64x2, centered_b_y_f64x2);
319
- float64x2_t delta_z_f64x2 = vsubq_f64(rotated_a_z_f64x2, centered_b_z_f64x2);
320
-
321
- nk_accumulate_square_f64x2_neon_(&sum_squared_a_f64x2, &sum_squared_a_compensation_f64x2, delta_x_f64x2);
322
- nk_accumulate_square_f64x2_neon_(&sum_squared_a_f64x2, &sum_squared_a_compensation_f64x2, delta_y_f64x2);
323
- nk_accumulate_square_f64x2_neon_(&sum_squared_a_f64x2, &sum_squared_a_compensation_f64x2, delta_z_f64x2);
324
- j += 2;
325
- }
326
-
327
- // Combine accumulators and reduce
328
- float64x2_t sum_squared_f64x2 = vaddq_f64(sum_squared_a_f64x2, sum_squared_b_f64x2);
329
- float64x2_t sum_squared_compensation_f64x2 = vaddq_f64(sum_squared_a_compensation_f64x2,
330
- sum_squared_b_compensation_f64x2);
331
- nk_f64_t sum_squared = nk_dot_stable_sum_f64x2_neon_(sum_squared_f64x2, sum_squared_compensation_f64x2);
332
- nk_f64_t sum_squared_compensation = 0.0;
333
-
334
- // Scalar tail
335
- for (; j < n; ++j) {
336
- nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
337
- pa_z = a[j * 3 + 2] - centroid_a_z;
338
- nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
339
- pb_z = b[j * 3 + 2] - centroid_b_z;
340
-
341
- nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
342
- ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
343
- ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
344
-
345
- nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
346
- nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
347
- nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
348
- nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
349
- }
350
-
351
- return sum_squared + sum_squared_compensation;
352
- }
353
-
354
96
  NK_PUBLIC void nk_rmsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
355
97
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
356
98
  if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
@@ -463,15 +205,17 @@ NK_PUBLIC void nk_kabsch_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_
463
205
  float64x2_t sum_b_z_low_f64x2 = zero_f64x2, sum_b_z_high_f64x2 = zero_f64x2;
464
206
 
465
207
  // Covariance accumulators (f64, lower/upper halves)
466
- float64x2_t cov_xx_low_f64x2 = zero_f64x2, cov_xx_high_f64x2 = zero_f64x2;
467
- float64x2_t cov_xy_low_f64x2 = zero_f64x2, cov_xy_high_f64x2 = zero_f64x2;
468
- float64x2_t cov_xz_low_f64x2 = zero_f64x2, cov_xz_high_f64x2 = zero_f64x2;
469
- float64x2_t cov_yx_low_f64x2 = zero_f64x2, cov_yx_high_f64x2 = zero_f64x2;
470
- float64x2_t cov_yy_low_f64x2 = zero_f64x2, cov_yy_high_f64x2 = zero_f64x2;
471
- float64x2_t cov_yz_low_f64x2 = zero_f64x2, cov_yz_high_f64x2 = zero_f64x2;
472
- float64x2_t cov_zx_low_f64x2 = zero_f64x2, cov_zx_high_f64x2 = zero_f64x2;
473
- float64x2_t cov_zy_low_f64x2 = zero_f64x2, cov_zy_high_f64x2 = zero_f64x2;
474
- float64x2_t cov_zz_low_f64x2 = zero_f64x2, cov_zz_high_f64x2 = zero_f64x2;
208
+ float64x2_t covariance_xx_low_f64x2 = zero_f64x2, covariance_xx_high_f64x2 = zero_f64x2;
209
+ float64x2_t covariance_xy_low_f64x2 = zero_f64x2, covariance_xy_high_f64x2 = zero_f64x2;
210
+ float64x2_t covariance_xz_low_f64x2 = zero_f64x2, covariance_xz_high_f64x2 = zero_f64x2;
211
+ float64x2_t covariance_yx_low_f64x2 = zero_f64x2, covariance_yx_high_f64x2 = zero_f64x2;
212
+ float64x2_t covariance_yy_low_f64x2 = zero_f64x2, covariance_yy_high_f64x2 = zero_f64x2;
213
+ float64x2_t covariance_yz_low_f64x2 = zero_f64x2, covariance_yz_high_f64x2 = zero_f64x2;
214
+ float64x2_t covariance_zx_low_f64x2 = zero_f64x2, covariance_zx_high_f64x2 = zero_f64x2;
215
+ float64x2_t covariance_zy_low_f64x2 = zero_f64x2, covariance_zy_high_f64x2 = zero_f64x2;
216
+ float64x2_t covariance_zz_low_f64x2 = zero_f64x2, covariance_zz_high_f64x2 = zero_f64x2;
217
+ float64x2_t norm_squared_a_low_f64x2 = zero_f64x2, norm_squared_a_high_f64x2 = zero_f64x2;
218
+ float64x2_t norm_squared_b_low_f64x2 = zero_f64x2, norm_squared_b_high_f64x2 = zero_f64x2;
475
219
 
476
220
  nk_size_t index = 0;
477
221
  for (; index + 4 <= n; index += 4) {
@@ -507,24 +251,36 @@ NK_PUBLIC void nk_kabsch_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_
507
251
  sum_b_z_high_f64x2 = vaddq_f64(sum_b_z_high_f64x2, b_z_high_f64x2);
508
252
 
509
253
  // Accumulate raw outer products (uncentered)
510
- cov_xx_low_f64x2 = vfmaq_f64(cov_xx_low_f64x2, a_x_low_f64x2, b_x_low_f64x2),
511
- cov_xx_high_f64x2 = vfmaq_f64(cov_xx_high_f64x2, a_x_high_f64x2, b_x_high_f64x2);
512
- cov_xy_low_f64x2 = vfmaq_f64(cov_xy_low_f64x2, a_x_low_f64x2, b_y_low_f64x2),
513
- cov_xy_high_f64x2 = vfmaq_f64(cov_xy_high_f64x2, a_x_high_f64x2, b_y_high_f64x2);
514
- cov_xz_low_f64x2 = vfmaq_f64(cov_xz_low_f64x2, a_x_low_f64x2, b_z_low_f64x2),
515
- cov_xz_high_f64x2 = vfmaq_f64(cov_xz_high_f64x2, a_x_high_f64x2, b_z_high_f64x2);
516
- cov_yx_low_f64x2 = vfmaq_f64(cov_yx_low_f64x2, a_y_low_f64x2, b_x_low_f64x2),
517
- cov_yx_high_f64x2 = vfmaq_f64(cov_yx_high_f64x2, a_y_high_f64x2, b_x_high_f64x2);
518
- cov_yy_low_f64x2 = vfmaq_f64(cov_yy_low_f64x2, a_y_low_f64x2, b_y_low_f64x2),
519
- cov_yy_high_f64x2 = vfmaq_f64(cov_yy_high_f64x2, a_y_high_f64x2, b_y_high_f64x2);
520
- cov_yz_low_f64x2 = vfmaq_f64(cov_yz_low_f64x2, a_y_low_f64x2, b_z_low_f64x2),
521
- cov_yz_high_f64x2 = vfmaq_f64(cov_yz_high_f64x2, a_y_high_f64x2, b_z_high_f64x2);
522
- cov_zx_low_f64x2 = vfmaq_f64(cov_zx_low_f64x2, a_z_low_f64x2, b_x_low_f64x2),
523
- cov_zx_high_f64x2 = vfmaq_f64(cov_zx_high_f64x2, a_z_high_f64x2, b_x_high_f64x2);
524
- cov_zy_low_f64x2 = vfmaq_f64(cov_zy_low_f64x2, a_z_low_f64x2, b_y_low_f64x2),
525
- cov_zy_high_f64x2 = vfmaq_f64(cov_zy_high_f64x2, a_z_high_f64x2, b_y_high_f64x2);
526
- cov_zz_low_f64x2 = vfmaq_f64(cov_zz_low_f64x2, a_z_low_f64x2, b_z_low_f64x2),
527
- cov_zz_high_f64x2 = vfmaq_f64(cov_zz_high_f64x2, a_z_high_f64x2, b_z_high_f64x2);
254
+ covariance_xx_low_f64x2 = vfmaq_f64(covariance_xx_low_f64x2, a_x_low_f64x2, b_x_low_f64x2),
255
+ covariance_xx_high_f64x2 = vfmaq_f64(covariance_xx_high_f64x2, a_x_high_f64x2, b_x_high_f64x2);
256
+ covariance_xy_low_f64x2 = vfmaq_f64(covariance_xy_low_f64x2, a_x_low_f64x2, b_y_low_f64x2),
257
+ covariance_xy_high_f64x2 = vfmaq_f64(covariance_xy_high_f64x2, a_x_high_f64x2, b_y_high_f64x2);
258
+ covariance_xz_low_f64x2 = vfmaq_f64(covariance_xz_low_f64x2, a_x_low_f64x2, b_z_low_f64x2),
259
+ covariance_xz_high_f64x2 = vfmaq_f64(covariance_xz_high_f64x2, a_x_high_f64x2, b_z_high_f64x2);
260
+ covariance_yx_low_f64x2 = vfmaq_f64(covariance_yx_low_f64x2, a_y_low_f64x2, b_x_low_f64x2),
261
+ covariance_yx_high_f64x2 = vfmaq_f64(covariance_yx_high_f64x2, a_y_high_f64x2, b_x_high_f64x2);
262
+ covariance_yy_low_f64x2 = vfmaq_f64(covariance_yy_low_f64x2, a_y_low_f64x2, b_y_low_f64x2),
263
+ covariance_yy_high_f64x2 = vfmaq_f64(covariance_yy_high_f64x2, a_y_high_f64x2, b_y_high_f64x2);
264
+ covariance_yz_low_f64x2 = vfmaq_f64(covariance_yz_low_f64x2, a_y_low_f64x2, b_z_low_f64x2),
265
+ covariance_yz_high_f64x2 = vfmaq_f64(covariance_yz_high_f64x2, a_y_high_f64x2, b_z_high_f64x2);
266
+ covariance_zx_low_f64x2 = vfmaq_f64(covariance_zx_low_f64x2, a_z_low_f64x2, b_x_low_f64x2),
267
+ covariance_zx_high_f64x2 = vfmaq_f64(covariance_zx_high_f64x2, a_z_high_f64x2, b_x_high_f64x2);
268
+ covariance_zy_low_f64x2 = vfmaq_f64(covariance_zy_low_f64x2, a_z_low_f64x2, b_y_low_f64x2),
269
+ covariance_zy_high_f64x2 = vfmaq_f64(covariance_zy_high_f64x2, a_z_high_f64x2, b_y_high_f64x2);
270
+ covariance_zz_low_f64x2 = vfmaq_f64(covariance_zz_low_f64x2, a_z_low_f64x2, b_z_low_f64x2),
271
+ covariance_zz_high_f64x2 = vfmaq_f64(covariance_zz_high_f64x2, a_z_high_f64x2, b_z_high_f64x2);
272
+ norm_squared_a_low_f64x2 = vfmaq_f64(norm_squared_a_low_f64x2, a_x_low_f64x2, a_x_low_f64x2);
273
+ norm_squared_a_high_f64x2 = vfmaq_f64(norm_squared_a_high_f64x2, a_x_high_f64x2, a_x_high_f64x2);
274
+ norm_squared_a_low_f64x2 = vfmaq_f64(norm_squared_a_low_f64x2, a_y_low_f64x2, a_y_low_f64x2);
275
+ norm_squared_a_high_f64x2 = vfmaq_f64(norm_squared_a_high_f64x2, a_y_high_f64x2, a_y_high_f64x2);
276
+ norm_squared_a_low_f64x2 = vfmaq_f64(norm_squared_a_low_f64x2, a_z_low_f64x2, a_z_low_f64x2);
277
+ norm_squared_a_high_f64x2 = vfmaq_f64(norm_squared_a_high_f64x2, a_z_high_f64x2, a_z_high_f64x2);
278
+ norm_squared_b_low_f64x2 = vfmaq_f64(norm_squared_b_low_f64x2, b_x_low_f64x2, b_x_low_f64x2);
279
+ norm_squared_b_high_f64x2 = vfmaq_f64(norm_squared_b_high_f64x2, b_x_high_f64x2, b_x_high_f64x2);
280
+ norm_squared_b_low_f64x2 = vfmaq_f64(norm_squared_b_low_f64x2, b_y_low_f64x2, b_y_low_f64x2);
281
+ norm_squared_b_high_f64x2 = vfmaq_f64(norm_squared_b_high_f64x2, b_y_high_f64x2, b_y_high_f64x2);
282
+ norm_squared_b_low_f64x2 = vfmaq_f64(norm_squared_b_low_f64x2, b_z_low_f64x2, b_z_low_f64x2);
283
+ norm_squared_b_high_f64x2 = vfmaq_f64(norm_squared_b_high_f64x2, b_z_high_f64x2, b_z_high_f64x2);
528
284
  }
529
285
 
530
286
  // Reduce centroid accumulators
@@ -536,15 +292,17 @@ NK_PUBLIC void nk_kabsch_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_
536
292
  nk_f64_t sum_b_z = vaddvq_f64(vaddq_f64(sum_b_z_low_f64x2, sum_b_z_high_f64x2));
537
293
 
538
294
  // Reduce covariance accumulators
539
- nk_f64_t covariance_x_x = vaddvq_f64(vaddq_f64(cov_xx_low_f64x2, cov_xx_high_f64x2));
540
- nk_f64_t covariance_x_y = vaddvq_f64(vaddq_f64(cov_xy_low_f64x2, cov_xy_high_f64x2));
541
- nk_f64_t covariance_x_z = vaddvq_f64(vaddq_f64(cov_xz_low_f64x2, cov_xz_high_f64x2));
542
- nk_f64_t covariance_y_x = vaddvq_f64(vaddq_f64(cov_yx_low_f64x2, cov_yx_high_f64x2));
543
- nk_f64_t covariance_y_y = vaddvq_f64(vaddq_f64(cov_yy_low_f64x2, cov_yy_high_f64x2));
544
- nk_f64_t covariance_y_z = vaddvq_f64(vaddq_f64(cov_yz_low_f64x2, cov_yz_high_f64x2));
545
- nk_f64_t covariance_z_x = vaddvq_f64(vaddq_f64(cov_zx_low_f64x2, cov_zx_high_f64x2));
546
- nk_f64_t covariance_z_y = vaddvq_f64(vaddq_f64(cov_zy_low_f64x2, cov_zy_high_f64x2));
547
- nk_f64_t covariance_z_z = vaddvq_f64(vaddq_f64(cov_zz_low_f64x2, cov_zz_high_f64x2));
295
+ nk_f64_t covariance_x_x = vaddvq_f64(vaddq_f64(covariance_xx_low_f64x2, covariance_xx_high_f64x2));
296
+ nk_f64_t covariance_x_y = vaddvq_f64(vaddq_f64(covariance_xy_low_f64x2, covariance_xy_high_f64x2));
297
+ nk_f64_t covariance_x_z = vaddvq_f64(vaddq_f64(covariance_xz_low_f64x2, covariance_xz_high_f64x2));
298
+ nk_f64_t covariance_y_x = vaddvq_f64(vaddq_f64(covariance_yx_low_f64x2, covariance_yx_high_f64x2));
299
+ nk_f64_t covariance_y_y = vaddvq_f64(vaddq_f64(covariance_yy_low_f64x2, covariance_yy_high_f64x2));
300
+ nk_f64_t covariance_y_z = vaddvq_f64(vaddq_f64(covariance_yz_low_f64x2, covariance_yz_high_f64x2));
301
+ nk_f64_t covariance_z_x = vaddvq_f64(vaddq_f64(covariance_zx_low_f64x2, covariance_zx_high_f64x2));
302
+ nk_f64_t covariance_z_y = vaddvq_f64(vaddq_f64(covariance_zy_low_f64x2, covariance_zy_high_f64x2));
303
+ nk_f64_t covariance_z_z = vaddvq_f64(vaddq_f64(covariance_zz_low_f64x2, covariance_zz_high_f64x2));
304
+ nk_f64_t norm_squared_a = vaddvq_f64(vaddq_f64(norm_squared_a_low_f64x2, norm_squared_a_high_f64x2));
305
+ nk_f64_t norm_squared_b = vaddvq_f64(vaddq_f64(norm_squared_b_low_f64x2, norm_squared_b_high_f64x2));
548
306
 
549
307
  // Scalar tail
550
308
  for (; index < n; ++index) {
@@ -555,6 +313,8 @@ NK_PUBLIC void nk_kabsch_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_
555
313
  covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
556
314
  covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
557
315
  covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
316
+ norm_squared_a += ax * ax + ay * ay + az * az;
317
+ norm_squared_b += bx * bx + by * by + bz * bz;
558
318
  }
559
319
 
560
320
  // Compute centroids
@@ -569,50 +329,85 @@ NK_PUBLIC void nk_kabsch_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_
569
329
  b_centroid[2] = (nk_f32_t)centroid_b_z;
570
330
 
571
331
  // Apply centering correction: H_centered = sum(a * bᵀ) - n * centroid_a * centroid_bᵀ
572
- nk_f64_t h[9];
573
- h[0] = covariance_x_x - (nk_f64_t)n * centroid_a_x * centroid_b_x;
574
- h[1] = covariance_x_y - (nk_f64_t)n * centroid_a_x * centroid_b_y;
575
- h[2] = covariance_x_z - (nk_f64_t)n * centroid_a_x * centroid_b_z;
576
- h[3] = covariance_y_x - (nk_f64_t)n * centroid_a_y * centroid_b_x;
577
- h[4] = covariance_y_y - (nk_f64_t)n * centroid_a_y * centroid_b_y;
578
- h[5] = covariance_y_z - (nk_f64_t)n * centroid_a_y * centroid_b_z;
579
- h[6] = covariance_z_x - (nk_f64_t)n * centroid_a_z * centroid_b_x;
580
- h[7] = covariance_z_y - (nk_f64_t)n * centroid_a_z * centroid_b_y;
581
- h[8] = covariance_z_z - (nk_f64_t)n * centroid_a_z * centroid_b_z;
582
-
583
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
584
- nk_svd3x3_f64_(h, svd_u, svd_s, svd_v);
585
-
586
- nk_f64_t r[9];
587
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
588
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
589
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
590
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
591
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
592
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
593
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
594
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
595
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
596
-
597
- if (nk_det3x3_f64_(r) < 0) {
598
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
599
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
600
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
601
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
602
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
603
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
604
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
605
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
606
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
607
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
332
+ nk_f64_t cross_covariance[9];
333
+ cross_covariance[0] = covariance_x_x - (nk_f64_t)n * centroid_a_x * centroid_b_x;
334
+ cross_covariance[1] = covariance_x_y - (nk_f64_t)n * centroid_a_x * centroid_b_y;
335
+ cross_covariance[2] = covariance_x_z - (nk_f64_t)n * centroid_a_x * centroid_b_z;
336
+ cross_covariance[3] = covariance_y_x - (nk_f64_t)n * centroid_a_y * centroid_b_x;
337
+ cross_covariance[4] = covariance_y_y - (nk_f64_t)n * centroid_a_y * centroid_b_y;
338
+ cross_covariance[5] = covariance_y_z - (nk_f64_t)n * centroid_a_y * centroid_b_z;
339
+ cross_covariance[6] = covariance_z_x - (nk_f64_t)n * centroid_a_z * centroid_b_x;
340
+ cross_covariance[7] = covariance_z_y - (nk_f64_t)n * centroid_a_z * centroid_b_y;
341
+ cross_covariance[8] = covariance_z_z - (nk_f64_t)n * centroid_a_z * centroid_b_z;
342
+
343
+ // Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
344
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
345
+ cross_covariance[4] * cross_covariance[4] +
346
+ cross_covariance[8] * cross_covariance[8];
347
+ nk_f64_t covariance_offdiagonal_norm_squared =
348
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
349
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
350
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
351
+ nk_f64_t optimal_rotation[9];
352
+ nk_f64_t trace_rotation_covariance;
353
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
354
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
355
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
356
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
357
+ optimal_rotation[8] = 1;
358
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
359
+ }
360
+ else {
361
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
362
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
363
+
364
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
365
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
366
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
367
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
368
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
369
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
370
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
371
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
372
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
373
+
374
+ if (nk_det3x3_f64_(optimal_rotation) < 0) {
375
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
376
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
377
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
378
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
379
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
380
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
381
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
382
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
383
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
384
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
385
+ }
386
+
387
+ trace_rotation_covariance =
388
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
389
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
390
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
391
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
392
+ optimal_rotation[8] * cross_covariance[8];
608
393
  }
609
394
 
610
395
  if (rotation)
611
- for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)r[j];
396
+ for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
612
397
  if (scale) *scale = 1.0f;
613
- *result = nk_f64_sqrt_neon(nk_transformed_ssd_f32_neon_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y, centroid_a_z,
614
- centroid_b_x, centroid_b_y, centroid_b_z) /
615
- (nk_f64_t)n);
398
+
399
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
400
+ nk_f64_t centered_norm_squared_a = norm_squared_a -
401
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
402
+ centroid_a_z * centroid_a_z);
403
+ nk_f64_t centered_norm_squared_b = norm_squared_b -
404
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
405
+ centroid_b_z * centroid_b_z);
406
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
407
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
408
+ nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
409
+ if (sum_squared < 0.0) sum_squared = 0.0;
410
+ *result = nk_f64_sqrt_neon(sum_squared / (nk_f64_t)n);
616
411
  }
617
412
 
618
413
  NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
@@ -625,12 +420,20 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
625
420
  float64x2_t sum_a_x_b_f64x2 = zeros_f64x2, sum_a_y_b_f64x2 = zeros_f64x2, sum_a_z_b_f64x2 = zeros_f64x2;
626
421
  float64x2_t sum_b_x_b_f64x2 = zeros_f64x2, sum_b_y_b_f64x2 = zeros_f64x2, sum_b_z_b_f64x2 = zeros_f64x2;
627
422
 
628
- float64x2_t cov_xx_a_f64x2 = zeros_f64x2, cov_xy_a_f64x2 = zeros_f64x2, cov_xz_a_f64x2 = zeros_f64x2;
629
- float64x2_t cov_yx_a_f64x2 = zeros_f64x2, cov_yy_a_f64x2 = zeros_f64x2, cov_yz_a_f64x2 = zeros_f64x2;
630
- float64x2_t cov_zx_a_f64x2 = zeros_f64x2, cov_zy_a_f64x2 = zeros_f64x2, cov_zz_a_f64x2 = zeros_f64x2;
631
- float64x2_t cov_xx_b_f64x2 = zeros_f64x2, cov_xy_b_f64x2 = zeros_f64x2, cov_xz_b_f64x2 = zeros_f64x2;
632
- float64x2_t cov_yx_b_f64x2 = zeros_f64x2, cov_yy_b_f64x2 = zeros_f64x2, cov_yz_b_f64x2 = zeros_f64x2;
633
- float64x2_t cov_zx_b_f64x2 = zeros_f64x2, cov_zy_b_f64x2 = zeros_f64x2, cov_zz_b_f64x2 = zeros_f64x2;
423
+ float64x2_t covariance_xx_a_f64x2 = zeros_f64x2, covariance_xy_a_f64x2 = zeros_f64x2,
424
+ covariance_xz_a_f64x2 = zeros_f64x2;
425
+ float64x2_t covariance_yx_a_f64x2 = zeros_f64x2, covariance_yy_a_f64x2 = zeros_f64x2,
426
+ covariance_yz_a_f64x2 = zeros_f64x2;
427
+ float64x2_t covariance_zx_a_f64x2 = zeros_f64x2, covariance_zy_a_f64x2 = zeros_f64x2,
428
+ covariance_zz_a_f64x2 = zeros_f64x2;
429
+ float64x2_t covariance_xx_b_f64x2 = zeros_f64x2, covariance_xy_b_f64x2 = zeros_f64x2,
430
+ covariance_xz_b_f64x2 = zeros_f64x2;
431
+ float64x2_t covariance_yx_b_f64x2 = zeros_f64x2, covariance_yy_b_f64x2 = zeros_f64x2,
432
+ covariance_yz_b_f64x2 = zeros_f64x2;
433
+ float64x2_t covariance_zx_b_f64x2 = zeros_f64x2, covariance_zy_b_f64x2 = zeros_f64x2,
434
+ covariance_zz_b_f64x2 = zeros_f64x2;
435
+ float64x2_t norm_squared_a_a_f64x2 = zeros_f64x2, norm_squared_a_b_f64x2 = zeros_f64x2;
436
+ float64x2_t norm_squared_b_a_f64x2 = zeros_f64x2, norm_squared_b_b_f64x2 = zeros_f64x2;
634
437
 
635
438
  nk_size_t i = 0;
636
439
  float64x2_t a1_x_f64x2, a1_y_f64x2, a1_z_f64x2, b1_x_f64x2, b1_y_f64x2, b1_z_f64x2;
@@ -642,6 +445,18 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
642
445
  nk_deinterleave_f64x2_neon_(b + i * 3, &b1_x_f64x2, &b1_y_f64x2, &b1_z_f64x2);
643
446
  nk_deinterleave_f64x2_neon_(a + (i + 2) * 3, &a2_x_f64x2, &a2_y_f64x2, &a2_z_f64x2);
644
447
  nk_deinterleave_f64x2_neon_(b + (i + 2) * 3, &b2_x_f64x2, &b2_y_f64x2, &b2_z_f64x2);
448
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_x_f64x2, a1_x_f64x2);
449
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_y_f64x2, a1_y_f64x2);
450
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_z_f64x2, a1_z_f64x2);
451
+ norm_squared_a_b_f64x2 = vfmaq_f64(norm_squared_a_b_f64x2, a2_x_f64x2, a2_x_f64x2);
452
+ norm_squared_a_b_f64x2 = vfmaq_f64(norm_squared_a_b_f64x2, a2_y_f64x2, a2_y_f64x2);
453
+ norm_squared_a_b_f64x2 = vfmaq_f64(norm_squared_a_b_f64x2, a2_z_f64x2, a2_z_f64x2);
454
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_x_f64x2, b1_x_f64x2);
455
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_y_f64x2, b1_y_f64x2);
456
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_z_f64x2, b1_z_f64x2);
457
+ norm_squared_b_b_f64x2 = vfmaq_f64(norm_squared_b_b_f64x2, b2_x_f64x2, b2_x_f64x2);
458
+ norm_squared_b_b_f64x2 = vfmaq_f64(norm_squared_b_b_f64x2, b2_y_f64x2, b2_y_f64x2);
459
+ norm_squared_b_b_f64x2 = vfmaq_f64(norm_squared_b_b_f64x2, b2_z_f64x2, b2_z_f64x2);
645
460
 
646
461
  // Interleaved accumulation
647
462
  sum_a_x_a_f64x2 = vaddq_f64(sum_a_x_a_f64x2, a1_x_f64x2);
@@ -657,24 +472,24 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
657
472
  sum_b_z_a_f64x2 = vaddq_f64(sum_b_z_a_f64x2, b1_z_f64x2);
658
473
  sum_b_z_b_f64x2 = vaddq_f64(sum_b_z_b_f64x2, b2_z_f64x2);
659
474
 
660
- cov_xx_a_f64x2 = vfmaq_f64(cov_xx_a_f64x2, a1_x_f64x2, b1_x_f64x2);
661
- cov_xx_b_f64x2 = vfmaq_f64(cov_xx_b_f64x2, a2_x_f64x2, b2_x_f64x2);
662
- cov_xy_a_f64x2 = vfmaq_f64(cov_xy_a_f64x2, a1_x_f64x2, b1_y_f64x2);
663
- cov_xy_b_f64x2 = vfmaq_f64(cov_xy_b_f64x2, a2_x_f64x2, b2_y_f64x2);
664
- cov_xz_a_f64x2 = vfmaq_f64(cov_xz_a_f64x2, a1_x_f64x2, b1_z_f64x2);
665
- cov_xz_b_f64x2 = vfmaq_f64(cov_xz_b_f64x2, a2_x_f64x2, b2_z_f64x2);
666
- cov_yx_a_f64x2 = vfmaq_f64(cov_yx_a_f64x2, a1_y_f64x2, b1_x_f64x2);
667
- cov_yx_b_f64x2 = vfmaq_f64(cov_yx_b_f64x2, a2_y_f64x2, b2_x_f64x2);
668
- cov_yy_a_f64x2 = vfmaq_f64(cov_yy_a_f64x2, a1_y_f64x2, b1_y_f64x2);
669
- cov_yy_b_f64x2 = vfmaq_f64(cov_yy_b_f64x2, a2_y_f64x2, b2_y_f64x2);
670
- cov_yz_a_f64x2 = vfmaq_f64(cov_yz_a_f64x2, a1_y_f64x2, b1_z_f64x2);
671
- cov_yz_b_f64x2 = vfmaq_f64(cov_yz_b_f64x2, a2_y_f64x2, b2_z_f64x2);
672
- cov_zx_a_f64x2 = vfmaq_f64(cov_zx_a_f64x2, a1_z_f64x2, b1_x_f64x2);
673
- cov_zx_b_f64x2 = vfmaq_f64(cov_zx_b_f64x2, a2_z_f64x2, b2_x_f64x2);
674
- cov_zy_a_f64x2 = vfmaq_f64(cov_zy_a_f64x2, a1_z_f64x2, b1_y_f64x2);
675
- cov_zy_b_f64x2 = vfmaq_f64(cov_zy_b_f64x2, a2_z_f64x2, b2_y_f64x2);
676
- cov_zz_a_f64x2 = vfmaq_f64(cov_zz_a_f64x2, a1_z_f64x2, b1_z_f64x2);
677
- cov_zz_b_f64x2 = vfmaq_f64(cov_zz_b_f64x2, a2_z_f64x2, b2_z_f64x2);
475
+ covariance_xx_a_f64x2 = vfmaq_f64(covariance_xx_a_f64x2, a1_x_f64x2, b1_x_f64x2);
476
+ covariance_xx_b_f64x2 = vfmaq_f64(covariance_xx_b_f64x2, a2_x_f64x2, b2_x_f64x2);
477
+ covariance_xy_a_f64x2 = vfmaq_f64(covariance_xy_a_f64x2, a1_x_f64x2, b1_y_f64x2);
478
+ covariance_xy_b_f64x2 = vfmaq_f64(covariance_xy_b_f64x2, a2_x_f64x2, b2_y_f64x2);
479
+ covariance_xz_a_f64x2 = vfmaq_f64(covariance_xz_a_f64x2, a1_x_f64x2, b1_z_f64x2);
480
+ covariance_xz_b_f64x2 = vfmaq_f64(covariance_xz_b_f64x2, a2_x_f64x2, b2_z_f64x2);
481
+ covariance_yx_a_f64x2 = vfmaq_f64(covariance_yx_a_f64x2, a1_y_f64x2, b1_x_f64x2);
482
+ covariance_yx_b_f64x2 = vfmaq_f64(covariance_yx_b_f64x2, a2_y_f64x2, b2_x_f64x2);
483
+ covariance_yy_a_f64x2 = vfmaq_f64(covariance_yy_a_f64x2, a1_y_f64x2, b1_y_f64x2);
484
+ covariance_yy_b_f64x2 = vfmaq_f64(covariance_yy_b_f64x2, a2_y_f64x2, b2_y_f64x2);
485
+ covariance_yz_a_f64x2 = vfmaq_f64(covariance_yz_a_f64x2, a1_y_f64x2, b1_z_f64x2);
486
+ covariance_yz_b_f64x2 = vfmaq_f64(covariance_yz_b_f64x2, a2_y_f64x2, b2_z_f64x2);
487
+ covariance_zx_a_f64x2 = vfmaq_f64(covariance_zx_a_f64x2, a1_z_f64x2, b1_x_f64x2);
488
+ covariance_zx_b_f64x2 = vfmaq_f64(covariance_zx_b_f64x2, a2_z_f64x2, b2_x_f64x2);
489
+ covariance_zy_a_f64x2 = vfmaq_f64(covariance_zy_a_f64x2, a1_z_f64x2, b1_y_f64x2);
490
+ covariance_zy_b_f64x2 = vfmaq_f64(covariance_zy_b_f64x2, a2_z_f64x2, b2_y_f64x2);
491
+ covariance_zz_a_f64x2 = vfmaq_f64(covariance_zz_a_f64x2, a1_z_f64x2, b1_z_f64x2);
492
+ covariance_zz_b_f64x2 = vfmaq_f64(covariance_zz_b_f64x2, a2_z_f64x2, b2_z_f64x2);
678
493
  }
679
494
 
680
495
  // 2-point tail
@@ -687,33 +502,41 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
687
502
  sum_b_x_a_f64x2 = vaddq_f64(sum_b_x_a_f64x2, b1_x_f64x2);
688
503
  sum_b_y_a_f64x2 = vaddq_f64(sum_b_y_a_f64x2, b1_y_f64x2);
689
504
  sum_b_z_a_f64x2 = vaddq_f64(sum_b_z_a_f64x2, b1_z_f64x2);
690
- cov_xx_a_f64x2 = vfmaq_f64(cov_xx_a_f64x2, a1_x_f64x2, b1_x_f64x2);
691
- cov_xy_a_f64x2 = vfmaq_f64(cov_xy_a_f64x2, a1_x_f64x2, b1_y_f64x2);
692
- cov_xz_a_f64x2 = vfmaq_f64(cov_xz_a_f64x2, a1_x_f64x2, b1_z_f64x2);
693
- cov_yx_a_f64x2 = vfmaq_f64(cov_yx_a_f64x2, a1_y_f64x2, b1_x_f64x2);
694
- cov_yy_a_f64x2 = vfmaq_f64(cov_yy_a_f64x2, a1_y_f64x2, b1_y_f64x2);
695
- cov_yz_a_f64x2 = vfmaq_f64(cov_yz_a_f64x2, a1_y_f64x2, b1_z_f64x2);
696
- cov_zx_a_f64x2 = vfmaq_f64(cov_zx_a_f64x2, a1_z_f64x2, b1_x_f64x2);
697
- cov_zy_a_f64x2 = vfmaq_f64(cov_zy_a_f64x2, a1_z_f64x2, b1_y_f64x2);
698
- cov_zz_a_f64x2 = vfmaq_f64(cov_zz_a_f64x2, a1_z_f64x2, b1_z_f64x2);
505
+ covariance_xx_a_f64x2 = vfmaq_f64(covariance_xx_a_f64x2, a1_x_f64x2, b1_x_f64x2);
506
+ covariance_xy_a_f64x2 = vfmaq_f64(covariance_xy_a_f64x2, a1_x_f64x2, b1_y_f64x2);
507
+ covariance_xz_a_f64x2 = vfmaq_f64(covariance_xz_a_f64x2, a1_x_f64x2, b1_z_f64x2);
508
+ covariance_yx_a_f64x2 = vfmaq_f64(covariance_yx_a_f64x2, a1_y_f64x2, b1_x_f64x2);
509
+ covariance_yy_a_f64x2 = vfmaq_f64(covariance_yy_a_f64x2, a1_y_f64x2, b1_y_f64x2);
510
+ covariance_yz_a_f64x2 = vfmaq_f64(covariance_yz_a_f64x2, a1_y_f64x2, b1_z_f64x2);
511
+ covariance_zx_a_f64x2 = vfmaq_f64(covariance_zx_a_f64x2, a1_z_f64x2, b1_x_f64x2);
512
+ covariance_zy_a_f64x2 = vfmaq_f64(covariance_zy_a_f64x2, a1_z_f64x2, b1_y_f64x2);
513
+ covariance_zz_a_f64x2 = vfmaq_f64(covariance_zz_a_f64x2, a1_z_f64x2, b1_z_f64x2);
514
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_x_f64x2, a1_x_f64x2);
515
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_y_f64x2, a1_y_f64x2);
516
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_z_f64x2, a1_z_f64x2);
517
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_x_f64x2, b1_x_f64x2);
518
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_y_f64x2, b1_y_f64x2);
519
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_z_f64x2, b1_z_f64x2);
699
520
  }
700
521
 
701
522
  // Combine dual accumulators
523
+ float64x2_t norm_squared_a_f64x2 = vaddq_f64(norm_squared_a_a_f64x2, norm_squared_a_b_f64x2);
524
+ float64x2_t norm_squared_b_f64x2 = vaddq_f64(norm_squared_b_a_f64x2, norm_squared_b_b_f64x2);
702
525
  float64x2_t sum_a_x_f64x2 = vaddq_f64(sum_a_x_a_f64x2, sum_a_x_b_f64x2);
703
526
  float64x2_t sum_a_y_f64x2 = vaddq_f64(sum_a_y_a_f64x2, sum_a_y_b_f64x2);
704
527
  float64x2_t sum_a_z_f64x2 = vaddq_f64(sum_a_z_a_f64x2, sum_a_z_b_f64x2);
705
528
  float64x2_t sum_b_x_f64x2 = vaddq_f64(sum_b_x_a_f64x2, sum_b_x_b_f64x2);
706
529
  float64x2_t sum_b_y_f64x2 = vaddq_f64(sum_b_y_a_f64x2, sum_b_y_b_f64x2);
707
530
  float64x2_t sum_b_z_f64x2 = vaddq_f64(sum_b_z_a_f64x2, sum_b_z_b_f64x2);
708
- float64x2_t cov_xx_f64x2 = vaddq_f64(cov_xx_a_f64x2, cov_xx_b_f64x2);
709
- float64x2_t cov_xy_f64x2 = vaddq_f64(cov_xy_a_f64x2, cov_xy_b_f64x2);
710
- float64x2_t cov_xz_f64x2 = vaddq_f64(cov_xz_a_f64x2, cov_xz_b_f64x2);
711
- float64x2_t cov_yx_f64x2 = vaddq_f64(cov_yx_a_f64x2, cov_yx_b_f64x2);
712
- float64x2_t cov_yy_f64x2 = vaddq_f64(cov_yy_a_f64x2, cov_yy_b_f64x2);
713
- float64x2_t cov_yz_f64x2 = vaddq_f64(cov_yz_a_f64x2, cov_yz_b_f64x2);
714
- float64x2_t cov_zx_f64x2 = vaddq_f64(cov_zx_a_f64x2, cov_zx_b_f64x2);
715
- float64x2_t cov_zy_f64x2 = vaddq_f64(cov_zy_a_f64x2, cov_zy_b_f64x2);
716
- float64x2_t cov_zz_f64x2 = vaddq_f64(cov_zz_a_f64x2, cov_zz_b_f64x2);
531
+ float64x2_t covariance_xx_f64x2 = vaddq_f64(covariance_xx_a_f64x2, covariance_xx_b_f64x2);
532
+ float64x2_t covariance_xy_f64x2 = vaddq_f64(covariance_xy_a_f64x2, covariance_xy_b_f64x2);
533
+ float64x2_t covariance_xz_f64x2 = vaddq_f64(covariance_xz_a_f64x2, covariance_xz_b_f64x2);
534
+ float64x2_t covariance_yx_f64x2 = vaddq_f64(covariance_yx_a_f64x2, covariance_yx_b_f64x2);
535
+ float64x2_t covariance_yy_f64x2 = vaddq_f64(covariance_yy_a_f64x2, covariance_yy_b_f64x2);
536
+ float64x2_t covariance_yz_f64x2 = vaddq_f64(covariance_yz_a_f64x2, covariance_yz_b_f64x2);
537
+ float64x2_t covariance_zx_f64x2 = vaddq_f64(covariance_zx_a_f64x2, covariance_zx_b_f64x2);
538
+ float64x2_t covariance_zy_f64x2 = vaddq_f64(covariance_zy_a_f64x2, covariance_zy_b_f64x2);
539
+ float64x2_t covariance_zz_f64x2 = vaddq_f64(covariance_zz_a_f64x2, covariance_zz_b_f64x2);
717
540
 
718
541
  // Reduce vector accumulators.
719
542
  nk_f64_t sum_a_x = nk_reduce_stable_f64x2_neon_(sum_a_x_f64x2), sum_a_x_compensation = 0.0;
@@ -723,15 +546,17 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
723
546
  nk_f64_t sum_b_y = nk_reduce_stable_f64x2_neon_(sum_b_y_f64x2), sum_b_y_compensation = 0.0;
724
547
  nk_f64_t sum_b_z = nk_reduce_stable_f64x2_neon_(sum_b_z_f64x2), sum_b_z_compensation = 0.0;
725
548
 
726
- nk_f64_t covariance_x_x = nk_reduce_stable_f64x2_neon_(cov_xx_f64x2), covariance_x_x_compensation = 0.0;
727
- nk_f64_t covariance_x_y = nk_reduce_stable_f64x2_neon_(cov_xy_f64x2), covariance_x_y_compensation = 0.0;
728
- nk_f64_t covariance_x_z = nk_reduce_stable_f64x2_neon_(cov_xz_f64x2), covariance_x_z_compensation = 0.0;
729
- nk_f64_t covariance_y_x = nk_reduce_stable_f64x2_neon_(cov_yx_f64x2), covariance_y_x_compensation = 0.0;
730
- nk_f64_t covariance_y_y = nk_reduce_stable_f64x2_neon_(cov_yy_f64x2), covariance_y_y_compensation = 0.0;
731
- nk_f64_t covariance_y_z = nk_reduce_stable_f64x2_neon_(cov_yz_f64x2), covariance_y_z_compensation = 0.0;
732
- nk_f64_t covariance_z_x = nk_reduce_stable_f64x2_neon_(cov_zx_f64x2), covariance_z_x_compensation = 0.0;
733
- nk_f64_t covariance_z_y = nk_reduce_stable_f64x2_neon_(cov_zy_f64x2), covariance_z_y_compensation = 0.0;
734
- nk_f64_t covariance_z_z = nk_reduce_stable_f64x2_neon_(cov_zz_f64x2), covariance_z_z_compensation = 0.0;
549
+ nk_f64_t covariance_x_x = nk_reduce_stable_f64x2_neon_(covariance_xx_f64x2), covariance_x_x_compensation = 0.0;
550
+ nk_f64_t covariance_x_y = nk_reduce_stable_f64x2_neon_(covariance_xy_f64x2), covariance_x_y_compensation = 0.0;
551
+ nk_f64_t covariance_x_z = nk_reduce_stable_f64x2_neon_(covariance_xz_f64x2), covariance_x_z_compensation = 0.0;
552
+ nk_f64_t covariance_y_x = nk_reduce_stable_f64x2_neon_(covariance_yx_f64x2), covariance_y_x_compensation = 0.0;
553
+ nk_f64_t covariance_y_y = nk_reduce_stable_f64x2_neon_(covariance_yy_f64x2), covariance_y_y_compensation = 0.0;
554
+ nk_f64_t covariance_y_z = nk_reduce_stable_f64x2_neon_(covariance_yz_f64x2), covariance_y_z_compensation = 0.0;
555
+ nk_f64_t covariance_z_x = nk_reduce_stable_f64x2_neon_(covariance_zx_f64x2), covariance_z_x_compensation = 0.0;
556
+ nk_f64_t covariance_z_y = nk_reduce_stable_f64x2_neon_(covariance_zy_f64x2), covariance_z_y_compensation = 0.0;
557
+ nk_f64_t covariance_z_z = nk_reduce_stable_f64x2_neon_(covariance_zz_f64x2), covariance_z_z_compensation = 0.0;
558
+ nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x2_neon_(norm_squared_a_f64x2), norm_squared_a_compensation = 0.0;
559
+ nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x2_neon_(norm_squared_b_f64x2), norm_squared_b_compensation = 0.0;
735
560
 
736
561
  // Scalar tail
737
562
  for (; i < n; ++i) {
@@ -752,6 +577,12 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
752
577
  nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx),
753
578
  nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by),
754
579
  nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
580
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
581
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
582
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
583
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
584
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
585
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
755
586
  }
756
587
 
757
588
  sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
@@ -762,6 +593,8 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
762
593
  covariance_y_z += covariance_y_z_compensation;
763
594
  covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
764
595
  covariance_z_z += covariance_z_z_compensation;
596
+ norm_squared_a_sum += norm_squared_a_compensation;
597
+ norm_squared_b_sum += norm_squared_b_compensation;
765
598
 
766
599
  // Compute centroids
767
600
  nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
@@ -784,27 +617,60 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
784
617
  // Compute SVD and optimal rotation
785
618
  nk_f64_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
786
619
  covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
787
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
788
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
789
-
790
- nk_f64_t r[9];
791
- nk_rotation_from_svd_f64_neon_(svd_u, svd_v, r);
792
620
 
793
- // Handle reflection: if det(R) < 0, negate third column of V and recompute R
794
- if (nk_det3x3_f64_(r) < 0) {
795
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
796
- nk_rotation_from_svd_f64_neon_(svd_u, svd_v, r);
621
+ // Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
622
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
623
+ cross_covariance[4] * cross_covariance[4] +
624
+ cross_covariance[8] * cross_covariance[8];
625
+ nk_f64_t covariance_offdiagonal_norm_squared =
626
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
627
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
628
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
629
+ nk_f64_t optimal_rotation[9];
630
+ nk_f64_t trace_rotation_covariance;
631
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
632
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
633
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
634
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
635
+ optimal_rotation[8] = 1;
636
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
637
+ }
638
+ else {
639
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
640
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
641
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
642
+
643
+ // Handle reflection: if det(R) < 0, negate third column of V and recompute R
644
+ if (nk_det3x3_f64_(optimal_rotation) < 0) {
645
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
646
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
647
+ }
648
+
649
+ trace_rotation_covariance =
650
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
651
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
652
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
653
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
654
+ optimal_rotation[8] * cross_covariance[8];
797
655
  }
798
656
 
799
657
  // Output rotation matrix and scale=1.0.
800
658
  if (rotation)
801
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
659
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
802
660
 
803
661
  if (scale) *scale = 1.0;
804
662
 
805
- // Compute RMSD after optimal rotation
806
- nk_f64_t sum_squared = nk_transformed_ssd_f64_neon_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y, centroid_a_z,
807
- centroid_b_x, centroid_b_y, centroid_b_z);
663
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
664
+ nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
665
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
666
+ centroid_a_z * centroid_a_z);
667
+ nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
668
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
669
+ centroid_b_z * centroid_b_z);
670
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
671
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
672
+ nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
673
+ if (sum_squared < 0.0) sum_squared = 0.0;
808
674
  *result = nk_f64_sqrt_neon(sum_squared * inv_n);
809
675
  }
810
676
 
@@ -821,18 +687,19 @@ NK_PUBLIC void nk_umeyama_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size
821
687
  float64x2_t sum_b_z_low_f64x2 = zero_f64x2, sum_b_z_high_f64x2 = zero_f64x2;
822
688
 
823
689
  // Covariance accumulators (f64, lower/upper halves)
824
- float64x2_t cov_xx_low_f64x2 = zero_f64x2, cov_xx_high_f64x2 = zero_f64x2;
825
- float64x2_t cov_xy_low_f64x2 = zero_f64x2, cov_xy_high_f64x2 = zero_f64x2;
826
- float64x2_t cov_xz_low_f64x2 = zero_f64x2, cov_xz_high_f64x2 = zero_f64x2;
827
- float64x2_t cov_yx_low_f64x2 = zero_f64x2, cov_yx_high_f64x2 = zero_f64x2;
828
- float64x2_t cov_yy_low_f64x2 = zero_f64x2, cov_yy_high_f64x2 = zero_f64x2;
829
- float64x2_t cov_yz_low_f64x2 = zero_f64x2, cov_yz_high_f64x2 = zero_f64x2;
830
- float64x2_t cov_zx_low_f64x2 = zero_f64x2, cov_zx_high_f64x2 = zero_f64x2;
831
- float64x2_t cov_zy_low_f64x2 = zero_f64x2, cov_zy_high_f64x2 = zero_f64x2;
832
- float64x2_t cov_zz_low_f64x2 = zero_f64x2, cov_zz_high_f64x2 = zero_f64x2;
833
-
834
- // Variance of A accumulator
835
- float64x2_t variance_low_f64x2 = zero_f64x2, variance_high_f64x2 = zero_f64x2;
690
+ float64x2_t covariance_xx_low_f64x2 = zero_f64x2, covariance_xx_high_f64x2 = zero_f64x2;
691
+ float64x2_t covariance_xy_low_f64x2 = zero_f64x2, covariance_xy_high_f64x2 = zero_f64x2;
692
+ float64x2_t covariance_xz_low_f64x2 = zero_f64x2, covariance_xz_high_f64x2 = zero_f64x2;
693
+ float64x2_t covariance_yx_low_f64x2 = zero_f64x2, covariance_yx_high_f64x2 = zero_f64x2;
694
+ float64x2_t covariance_yy_low_f64x2 = zero_f64x2, covariance_yy_high_f64x2 = zero_f64x2;
695
+ float64x2_t covariance_yz_low_f64x2 = zero_f64x2, covariance_yz_high_f64x2 = zero_f64x2;
696
+ float64x2_t covariance_zx_low_f64x2 = zero_f64x2, covariance_zx_high_f64x2 = zero_f64x2;
697
+ float64x2_t covariance_zy_low_f64x2 = zero_f64x2, covariance_zy_high_f64x2 = zero_f64x2;
698
+ float64x2_t covariance_zz_low_f64x2 = zero_f64x2, covariance_zz_high_f64x2 = zero_f64x2;
699
+
700
+ // Norm-squared accumulators for both point sets (used for Umeyama scale and folded SSD).
701
+ float64x2_t norm_squared_a_low_f64x2 = zero_f64x2, norm_squared_a_high_f64x2 = zero_f64x2;
702
+ float64x2_t norm_squared_b_low_f64x2 = zero_f64x2, norm_squared_b_high_f64x2 = zero_f64x2;
836
703
 
837
704
  nk_size_t index = 0;
838
705
  for (; index + 4 <= n; index += 4) {
@@ -868,32 +735,38 @@ NK_PUBLIC void nk_umeyama_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size
868
735
  sum_b_z_high_f64x2 = vaddq_f64(sum_b_z_high_f64x2, b_z_high_f64x2);
869
736
 
870
737
  // Accumulate raw outer products (uncentered)
871
- cov_xx_low_f64x2 = vfmaq_f64(cov_xx_low_f64x2, a_x_low_f64x2, b_x_low_f64x2),
872
- cov_xx_high_f64x2 = vfmaq_f64(cov_xx_high_f64x2, a_x_high_f64x2, b_x_high_f64x2);
873
- cov_xy_low_f64x2 = vfmaq_f64(cov_xy_low_f64x2, a_x_low_f64x2, b_y_low_f64x2),
874
- cov_xy_high_f64x2 = vfmaq_f64(cov_xy_high_f64x2, a_x_high_f64x2, b_y_high_f64x2);
875
- cov_xz_low_f64x2 = vfmaq_f64(cov_xz_low_f64x2, a_x_low_f64x2, b_z_low_f64x2),
876
- cov_xz_high_f64x2 = vfmaq_f64(cov_xz_high_f64x2, a_x_high_f64x2, b_z_high_f64x2);
877
- cov_yx_low_f64x2 = vfmaq_f64(cov_yx_low_f64x2, a_y_low_f64x2, b_x_low_f64x2),
878
- cov_yx_high_f64x2 = vfmaq_f64(cov_yx_high_f64x2, a_y_high_f64x2, b_x_high_f64x2);
879
- cov_yy_low_f64x2 = vfmaq_f64(cov_yy_low_f64x2, a_y_low_f64x2, b_y_low_f64x2),
880
- cov_yy_high_f64x2 = vfmaq_f64(cov_yy_high_f64x2, a_y_high_f64x2, b_y_high_f64x2);
881
- cov_yz_low_f64x2 = vfmaq_f64(cov_yz_low_f64x2, a_y_low_f64x2, b_z_low_f64x2),
882
- cov_yz_high_f64x2 = vfmaq_f64(cov_yz_high_f64x2, a_y_high_f64x2, b_z_high_f64x2);
883
- cov_zx_low_f64x2 = vfmaq_f64(cov_zx_low_f64x2, a_z_low_f64x2, b_x_low_f64x2),
884
- cov_zx_high_f64x2 = vfmaq_f64(cov_zx_high_f64x2, a_z_high_f64x2, b_x_high_f64x2);
885
- cov_zy_low_f64x2 = vfmaq_f64(cov_zy_low_f64x2, a_z_low_f64x2, b_y_low_f64x2),
886
- cov_zy_high_f64x2 = vfmaq_f64(cov_zy_high_f64x2, a_z_high_f64x2, b_y_high_f64x2);
887
- cov_zz_low_f64x2 = vfmaq_f64(cov_zz_low_f64x2, a_z_low_f64x2, b_z_low_f64x2),
888
- cov_zz_high_f64x2 = vfmaq_f64(cov_zz_high_f64x2, a_z_high_f64x2, b_z_high_f64x2);
889
-
890
- // Accumulate variance of A (sum of squared coordinates)
891
- variance_low_f64x2 = vfmaq_f64(variance_low_f64x2, a_x_low_f64x2, a_x_low_f64x2),
892
- variance_high_f64x2 = vfmaq_f64(variance_high_f64x2, a_x_high_f64x2, a_x_high_f64x2);
893
- variance_low_f64x2 = vfmaq_f64(variance_low_f64x2, a_y_low_f64x2, a_y_low_f64x2),
894
- variance_high_f64x2 = vfmaq_f64(variance_high_f64x2, a_y_high_f64x2, a_y_high_f64x2);
895
- variance_low_f64x2 = vfmaq_f64(variance_low_f64x2, a_z_low_f64x2, a_z_low_f64x2),
896
- variance_high_f64x2 = vfmaq_f64(variance_high_f64x2, a_z_high_f64x2, a_z_high_f64x2);
738
+ covariance_xx_low_f64x2 = vfmaq_f64(covariance_xx_low_f64x2, a_x_low_f64x2, b_x_low_f64x2),
739
+ covariance_xx_high_f64x2 = vfmaq_f64(covariance_xx_high_f64x2, a_x_high_f64x2, b_x_high_f64x2);
740
+ covariance_xy_low_f64x2 = vfmaq_f64(covariance_xy_low_f64x2, a_x_low_f64x2, b_y_low_f64x2),
741
+ covariance_xy_high_f64x2 = vfmaq_f64(covariance_xy_high_f64x2, a_x_high_f64x2, b_y_high_f64x2);
742
+ covariance_xz_low_f64x2 = vfmaq_f64(covariance_xz_low_f64x2, a_x_low_f64x2, b_z_low_f64x2),
743
+ covariance_xz_high_f64x2 = vfmaq_f64(covariance_xz_high_f64x2, a_x_high_f64x2, b_z_high_f64x2);
744
+ covariance_yx_low_f64x2 = vfmaq_f64(covariance_yx_low_f64x2, a_y_low_f64x2, b_x_low_f64x2),
745
+ covariance_yx_high_f64x2 = vfmaq_f64(covariance_yx_high_f64x2, a_y_high_f64x2, b_x_high_f64x2);
746
+ covariance_yy_low_f64x2 = vfmaq_f64(covariance_yy_low_f64x2, a_y_low_f64x2, b_y_low_f64x2),
747
+ covariance_yy_high_f64x2 = vfmaq_f64(covariance_yy_high_f64x2, a_y_high_f64x2, b_y_high_f64x2);
748
+ covariance_yz_low_f64x2 = vfmaq_f64(covariance_yz_low_f64x2, a_y_low_f64x2, b_z_low_f64x2),
749
+ covariance_yz_high_f64x2 = vfmaq_f64(covariance_yz_high_f64x2, a_y_high_f64x2, b_z_high_f64x2);
750
+ covariance_zx_low_f64x2 = vfmaq_f64(covariance_zx_low_f64x2, a_z_low_f64x2, b_x_low_f64x2),
751
+ covariance_zx_high_f64x2 = vfmaq_f64(covariance_zx_high_f64x2, a_z_high_f64x2, b_x_high_f64x2);
752
+ covariance_zy_low_f64x2 = vfmaq_f64(covariance_zy_low_f64x2, a_z_low_f64x2, b_y_low_f64x2),
753
+ covariance_zy_high_f64x2 = vfmaq_f64(covariance_zy_high_f64x2, a_z_high_f64x2, b_y_high_f64x2);
754
+ covariance_zz_low_f64x2 = vfmaq_f64(covariance_zz_low_f64x2, a_z_low_f64x2, b_z_low_f64x2),
755
+ covariance_zz_high_f64x2 = vfmaq_f64(covariance_zz_high_f64x2, a_z_high_f64x2, b_z_high_f64x2);
756
+
757
+ // Accumulate norm-squared of A and B (sum of squared coordinates per point set).
758
+ norm_squared_a_low_f64x2 = vfmaq_f64(norm_squared_a_low_f64x2, a_x_low_f64x2, a_x_low_f64x2),
759
+ norm_squared_a_high_f64x2 = vfmaq_f64(norm_squared_a_high_f64x2, a_x_high_f64x2, a_x_high_f64x2);
760
+ norm_squared_a_low_f64x2 = vfmaq_f64(norm_squared_a_low_f64x2, a_y_low_f64x2, a_y_low_f64x2),
761
+ norm_squared_a_high_f64x2 = vfmaq_f64(norm_squared_a_high_f64x2, a_y_high_f64x2, a_y_high_f64x2);
762
+ norm_squared_a_low_f64x2 = vfmaq_f64(norm_squared_a_low_f64x2, a_z_low_f64x2, a_z_low_f64x2),
763
+ norm_squared_a_high_f64x2 = vfmaq_f64(norm_squared_a_high_f64x2, a_z_high_f64x2, a_z_high_f64x2);
764
+ norm_squared_b_low_f64x2 = vfmaq_f64(norm_squared_b_low_f64x2, b_x_low_f64x2, b_x_low_f64x2),
765
+ norm_squared_b_high_f64x2 = vfmaq_f64(norm_squared_b_high_f64x2, b_x_high_f64x2, b_x_high_f64x2);
766
+ norm_squared_b_low_f64x2 = vfmaq_f64(norm_squared_b_low_f64x2, b_y_low_f64x2, b_y_low_f64x2),
767
+ norm_squared_b_high_f64x2 = vfmaq_f64(norm_squared_b_high_f64x2, b_y_high_f64x2, b_y_high_f64x2);
768
+ norm_squared_b_low_f64x2 = vfmaq_f64(norm_squared_b_low_f64x2, b_z_low_f64x2, b_z_low_f64x2),
769
+ norm_squared_b_high_f64x2 = vfmaq_f64(norm_squared_b_high_f64x2, b_z_high_f64x2, b_z_high_f64x2);
897
770
  }
898
771
 
899
772
  // Reduce centroid accumulators
@@ -905,16 +778,17 @@ NK_PUBLIC void nk_umeyama_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size
905
778
  nk_f64_t sum_b_z = vaddvq_f64(vaddq_f64(sum_b_z_low_f64x2, sum_b_z_high_f64x2));
906
779
 
907
780
  // Reduce covariance accumulators
908
- nk_f64_t covariance_x_x = vaddvq_f64(vaddq_f64(cov_xx_low_f64x2, cov_xx_high_f64x2));
909
- nk_f64_t covariance_x_y = vaddvq_f64(vaddq_f64(cov_xy_low_f64x2, cov_xy_high_f64x2));
910
- nk_f64_t covariance_x_z = vaddvq_f64(vaddq_f64(cov_xz_low_f64x2, cov_xz_high_f64x2));
911
- nk_f64_t covariance_y_x = vaddvq_f64(vaddq_f64(cov_yx_low_f64x2, cov_yx_high_f64x2));
912
- nk_f64_t covariance_y_y = vaddvq_f64(vaddq_f64(cov_yy_low_f64x2, cov_yy_high_f64x2));
913
- nk_f64_t covariance_y_z = vaddvq_f64(vaddq_f64(cov_yz_low_f64x2, cov_yz_high_f64x2));
914
- nk_f64_t covariance_z_x = vaddvq_f64(vaddq_f64(cov_zx_low_f64x2, cov_zx_high_f64x2));
915
- nk_f64_t covariance_z_y = vaddvq_f64(vaddq_f64(cov_zy_low_f64x2, cov_zy_high_f64x2));
916
- nk_f64_t covariance_z_z = vaddvq_f64(vaddq_f64(cov_zz_low_f64x2, cov_zz_high_f64x2));
917
- nk_f64_t sum_sq_a = vaddvq_f64(vaddq_f64(variance_low_f64x2, variance_high_f64x2));
781
+ nk_f64_t covariance_x_x = vaddvq_f64(vaddq_f64(covariance_xx_low_f64x2, covariance_xx_high_f64x2));
782
+ nk_f64_t covariance_x_y = vaddvq_f64(vaddq_f64(covariance_xy_low_f64x2, covariance_xy_high_f64x2));
783
+ nk_f64_t covariance_x_z = vaddvq_f64(vaddq_f64(covariance_xz_low_f64x2, covariance_xz_high_f64x2));
784
+ nk_f64_t covariance_y_x = vaddvq_f64(vaddq_f64(covariance_yx_low_f64x2, covariance_yx_high_f64x2));
785
+ nk_f64_t covariance_y_y = vaddvq_f64(vaddq_f64(covariance_yy_low_f64x2, covariance_yy_high_f64x2));
786
+ nk_f64_t covariance_y_z = vaddvq_f64(vaddq_f64(covariance_yz_low_f64x2, covariance_yz_high_f64x2));
787
+ nk_f64_t covariance_z_x = vaddvq_f64(vaddq_f64(covariance_zx_low_f64x2, covariance_zx_high_f64x2));
788
+ nk_f64_t covariance_z_y = vaddvq_f64(vaddq_f64(covariance_zy_low_f64x2, covariance_zy_high_f64x2));
789
+ nk_f64_t covariance_z_z = vaddvq_f64(vaddq_f64(covariance_zz_low_f64x2, covariance_zz_high_f64x2));
790
+ nk_f64_t norm_squared_a_sum = vaddvq_f64(vaddq_f64(norm_squared_a_low_f64x2, norm_squared_a_high_f64x2));
791
+ nk_f64_t norm_squared_b_sum = vaddvq_f64(vaddq_f64(norm_squared_b_low_f64x2, norm_squared_b_high_f64x2));
918
792
 
919
793
  // Scalar tail
920
794
  for (; index < n; ++index) {
@@ -925,7 +799,8 @@ NK_PUBLIC void nk_umeyama_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size
925
799
  covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
926
800
  covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
927
801
  covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
928
- sum_sq_a += ax * ax + ay * ay + az * az;
802
+ norm_squared_a_sum += ax * ax + ay * ay + az * az;
803
+ norm_squared_b_sum += bx * bx + by * by + bz * bz;
929
804
  }
930
805
 
931
806
  // Compute centroids
@@ -939,57 +814,94 @@ NK_PUBLIC void nk_umeyama_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size
939
814
  b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
940
815
  b_centroid[2] = (nk_f32_t)centroid_b_z;
941
816
 
942
- // Compute variance of A (centered): var = sum(a^2)/n - centroid^2
943
- nk_f64_t variance_a = sum_sq_a * inv_n -
944
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
817
+ // Centered norm-squared via parallel-axis identity; clamp at zero for numeric safety.
818
+ nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
819
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
820
+ centroid_a_z * centroid_a_z);
821
+ nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
822
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
823
+ centroid_b_z * centroid_b_z);
824
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
825
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
945
826
 
946
827
  // Apply centering correction: H_centered = sum(a * bᵀ) - n * centroid_a * centroid_bᵀ
947
- nk_f64_t h[9];
948
- h[0] = covariance_x_x - (nk_f64_t)n * centroid_a_x * centroid_b_x;
949
- h[1] = covariance_x_y - (nk_f64_t)n * centroid_a_x * centroid_b_y;
950
- h[2] = covariance_x_z - (nk_f64_t)n * centroid_a_x * centroid_b_z;
951
- h[3] = covariance_y_x - (nk_f64_t)n * centroid_a_y * centroid_b_x;
952
- h[4] = covariance_y_y - (nk_f64_t)n * centroid_a_y * centroid_b_y;
953
- h[5] = covariance_y_z - (nk_f64_t)n * centroid_a_y * centroid_b_z;
954
- h[6] = covariance_z_x - (nk_f64_t)n * centroid_a_z * centroid_b_x;
955
- h[7] = covariance_z_y - (nk_f64_t)n * centroid_a_z * centroid_b_y;
956
- h[8] = covariance_z_z - (nk_f64_t)n * centroid_a_z * centroid_b_z;
957
-
958
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
959
- nk_svd3x3_f64_(h, svd_u, svd_s, svd_v);
960
-
961
- nk_f64_t r[9];
962
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
963
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
964
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
965
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
966
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
967
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
968
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
969
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
970
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
971
-
972
- nk_f64_t det = nk_det3x3_f64_(r), sign_correction = det < 0 ? -1.0 : 1.0;
973
- if (det < 0) {
974
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
975
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
976
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
977
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
978
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
979
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
980
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
981
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
982
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
983
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
828
+ nk_f64_t cross_covariance[9];
829
+ cross_covariance[0] = covariance_x_x - (nk_f64_t)n * centroid_a_x * centroid_b_x;
830
+ cross_covariance[1] = covariance_x_y - (nk_f64_t)n * centroid_a_x * centroid_b_y;
831
+ cross_covariance[2] = covariance_x_z - (nk_f64_t)n * centroid_a_x * centroid_b_z;
832
+ cross_covariance[3] = covariance_y_x - (nk_f64_t)n * centroid_a_y * centroid_b_x;
833
+ cross_covariance[4] = covariance_y_y - (nk_f64_t)n * centroid_a_y * centroid_b_y;
834
+ cross_covariance[5] = covariance_y_z - (nk_f64_t)n * centroid_a_y * centroid_b_z;
835
+ cross_covariance[6] = covariance_z_x - (nk_f64_t)n * centroid_a_z * centroid_b_x;
836
+ cross_covariance[7] = covariance_z_y - (nk_f64_t)n * centroid_a_z * centroid_b_y;
837
+ cross_covariance[8] = covariance_z_z - (nk_f64_t)n * centroid_a_z * centroid_b_z;
838
+
839
+ // Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
840
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
841
+ cross_covariance[4] * cross_covariance[4] +
842
+ cross_covariance[8] * cross_covariance[8];
843
+ nk_f64_t covariance_offdiagonal_norm_squared =
844
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
845
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
846
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
847
+ nk_f64_t optimal_rotation[9];
848
+ nk_f64_t trace_rotation_covariance;
849
+ nk_f64_t applied_scale;
850
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
851
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
852
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
853
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
854
+ optimal_rotation[8] = 1;
855
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
856
+ applied_scale = centered_norm_squared_a > 0.0 ? trace_rotation_covariance / centered_norm_squared_a : 0.0;
857
+ }
858
+ else {
859
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
860
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
861
+
862
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
863
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
864
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
865
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
866
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
867
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
868
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
869
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
870
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
871
+
872
+ nk_f64_t det = nk_det3x3_f64_(optimal_rotation), sign_correction = det < 0 ? -1.0 : 1.0;
873
+ if (det < 0) {
874
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
875
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
876
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
877
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
878
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
879
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
880
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
881
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
882
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
883
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
884
+ }
885
+
886
+ nk_f64_t trace_ds = svd_diagonal[0] + svd_diagonal[4] + sign_correction * svd_diagonal[8];
887
+ applied_scale = centered_norm_squared_a > 0.0 ? trace_ds / centered_norm_squared_a : 0.0;
888
+
889
+ trace_rotation_covariance =
890
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
891
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
892
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
893
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
894
+ optimal_rotation[8] * cross_covariance[8];
984
895
  }
985
-
986
- nk_f64_t applied_scale = (svd_s[0] + svd_s[4] + sign_correction * svd_s[8]) / ((nk_f64_t)n * variance_a);
987
896
  if (rotation)
988
- for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)r[j];
897
+ for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
989
898
  if (scale) *scale = (nk_f32_t)applied_scale;
990
- *result = nk_f64_sqrt_neon(nk_transformed_ssd_f32_neon_(a, b, n, r, applied_scale, centroid_a_x, centroid_a_y,
991
- centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z) /
992
- (nk_f64_t)n);
899
+
900
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
901
+ nk_f64_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
902
+ 2.0 * applied_scale * trace_rotation_covariance;
903
+ if (sum_squared < 0.0) sum_squared = 0.0;
904
+ *result = nk_f64_sqrt_neon(sum_squared / (nk_f64_t)n);
993
905
  }
994
906
 
995
907
  NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
@@ -1002,13 +914,20 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1002
914
  float64x2_t sum_a_x_b_f64x2 = zeros_f64x2, sum_a_y_b_f64x2 = zeros_f64x2, sum_a_z_b_f64x2 = zeros_f64x2;
1003
915
  float64x2_t sum_b_x_b_f64x2 = zeros_f64x2, sum_b_y_b_f64x2 = zeros_f64x2, sum_b_z_b_f64x2 = zeros_f64x2;
1004
916
 
1005
- float64x2_t cov_xx_a_f64x2 = zeros_f64x2, cov_xy_a_f64x2 = zeros_f64x2, cov_xz_a_f64x2 = zeros_f64x2;
1006
- float64x2_t cov_yx_a_f64x2 = zeros_f64x2, cov_yy_a_f64x2 = zeros_f64x2, cov_yz_a_f64x2 = zeros_f64x2;
1007
- float64x2_t cov_zx_a_f64x2 = zeros_f64x2, cov_zy_a_f64x2 = zeros_f64x2, cov_zz_a_f64x2 = zeros_f64x2;
1008
- float64x2_t cov_xx_b_f64x2 = zeros_f64x2, cov_xy_b_f64x2 = zeros_f64x2, cov_xz_b_f64x2 = zeros_f64x2;
1009
- float64x2_t cov_yx_b_f64x2 = zeros_f64x2, cov_yy_b_f64x2 = zeros_f64x2, cov_yz_b_f64x2 = zeros_f64x2;
1010
- float64x2_t cov_zx_b_f64x2 = zeros_f64x2, cov_zy_b_f64x2 = zeros_f64x2, cov_zz_b_f64x2 = zeros_f64x2;
1011
- float64x2_t variance_a_a_f64x2 = zeros_f64x2, variance_a_b_f64x2 = zeros_f64x2;
917
+ float64x2_t covariance_xx_a_f64x2 = zeros_f64x2, covariance_xy_a_f64x2 = zeros_f64x2,
918
+ covariance_xz_a_f64x2 = zeros_f64x2;
919
+ float64x2_t covariance_yx_a_f64x2 = zeros_f64x2, covariance_yy_a_f64x2 = zeros_f64x2,
920
+ covariance_yz_a_f64x2 = zeros_f64x2;
921
+ float64x2_t covariance_zx_a_f64x2 = zeros_f64x2, covariance_zy_a_f64x2 = zeros_f64x2,
922
+ covariance_zz_a_f64x2 = zeros_f64x2;
923
+ float64x2_t covariance_xx_b_f64x2 = zeros_f64x2, covariance_xy_b_f64x2 = zeros_f64x2,
924
+ covariance_xz_b_f64x2 = zeros_f64x2;
925
+ float64x2_t covariance_yx_b_f64x2 = zeros_f64x2, covariance_yy_b_f64x2 = zeros_f64x2,
926
+ covariance_yz_b_f64x2 = zeros_f64x2;
927
+ float64x2_t covariance_zx_b_f64x2 = zeros_f64x2, covariance_zy_b_f64x2 = zeros_f64x2,
928
+ covariance_zz_b_f64x2 = zeros_f64x2;
929
+ float64x2_t norm_squared_a_a_f64x2 = zeros_f64x2, norm_squared_a_b_f64x2 = zeros_f64x2;
930
+ float64x2_t norm_squared_b_a_f64x2 = zeros_f64x2, norm_squared_b_b_f64x2 = zeros_f64x2;
1012
931
 
1013
932
  nk_size_t i = 0;
1014
933
  float64x2_t a1_x_f64x2, a1_y_f64x2, a1_z_f64x2, b1_x_f64x2, b1_y_f64x2, b1_z_f64x2;
@@ -1035,31 +954,37 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1035
954
  sum_b_z_a_f64x2 = vaddq_f64(sum_b_z_a_f64x2, b1_z_f64x2);
1036
955
  sum_b_z_b_f64x2 = vaddq_f64(sum_b_z_b_f64x2, b2_z_f64x2);
1037
956
 
1038
- cov_xx_a_f64x2 = vfmaq_f64(cov_xx_a_f64x2, a1_x_f64x2, b1_x_f64x2);
1039
- cov_xx_b_f64x2 = vfmaq_f64(cov_xx_b_f64x2, a2_x_f64x2, b2_x_f64x2);
1040
- cov_xy_a_f64x2 = vfmaq_f64(cov_xy_a_f64x2, a1_x_f64x2, b1_y_f64x2);
1041
- cov_xy_b_f64x2 = vfmaq_f64(cov_xy_b_f64x2, a2_x_f64x2, b2_y_f64x2);
1042
- cov_xz_a_f64x2 = vfmaq_f64(cov_xz_a_f64x2, a1_x_f64x2, b1_z_f64x2);
1043
- cov_xz_b_f64x2 = vfmaq_f64(cov_xz_b_f64x2, a2_x_f64x2, b2_z_f64x2);
1044
- cov_yx_a_f64x2 = vfmaq_f64(cov_yx_a_f64x2, a1_y_f64x2, b1_x_f64x2);
1045
- cov_yx_b_f64x2 = vfmaq_f64(cov_yx_b_f64x2, a2_y_f64x2, b2_x_f64x2);
1046
- cov_yy_a_f64x2 = vfmaq_f64(cov_yy_a_f64x2, a1_y_f64x2, b1_y_f64x2);
1047
- cov_yy_b_f64x2 = vfmaq_f64(cov_yy_b_f64x2, a2_y_f64x2, b2_y_f64x2);
1048
- cov_yz_a_f64x2 = vfmaq_f64(cov_yz_a_f64x2, a1_y_f64x2, b1_z_f64x2);
1049
- cov_yz_b_f64x2 = vfmaq_f64(cov_yz_b_f64x2, a2_y_f64x2, b2_z_f64x2);
1050
- cov_zx_a_f64x2 = vfmaq_f64(cov_zx_a_f64x2, a1_z_f64x2, b1_x_f64x2);
1051
- cov_zx_b_f64x2 = vfmaq_f64(cov_zx_b_f64x2, a2_z_f64x2, b2_x_f64x2);
1052
- cov_zy_a_f64x2 = vfmaq_f64(cov_zy_a_f64x2, a1_z_f64x2, b1_y_f64x2);
1053
- cov_zy_b_f64x2 = vfmaq_f64(cov_zy_b_f64x2, a2_z_f64x2, b2_y_f64x2);
1054
- cov_zz_a_f64x2 = vfmaq_f64(cov_zz_a_f64x2, a1_z_f64x2, b1_z_f64x2);
1055
- cov_zz_b_f64x2 = vfmaq_f64(cov_zz_b_f64x2, a2_z_f64x2, b2_z_f64x2);
1056
-
1057
- variance_a_a_f64x2 = vfmaq_f64(variance_a_a_f64x2, a1_x_f64x2, a1_x_f64x2);
1058
- variance_a_b_f64x2 = vfmaq_f64(variance_a_b_f64x2, a2_x_f64x2, a2_x_f64x2);
1059
- variance_a_a_f64x2 = vfmaq_f64(variance_a_a_f64x2, a1_y_f64x2, a1_y_f64x2);
1060
- variance_a_b_f64x2 = vfmaq_f64(variance_a_b_f64x2, a2_y_f64x2, a2_y_f64x2);
1061
- variance_a_a_f64x2 = vfmaq_f64(variance_a_a_f64x2, a1_z_f64x2, a1_z_f64x2);
1062
- variance_a_b_f64x2 = vfmaq_f64(variance_a_b_f64x2, a2_z_f64x2, a2_z_f64x2);
957
+ covariance_xx_a_f64x2 = vfmaq_f64(covariance_xx_a_f64x2, a1_x_f64x2, b1_x_f64x2);
958
+ covariance_xx_b_f64x2 = vfmaq_f64(covariance_xx_b_f64x2, a2_x_f64x2, b2_x_f64x2);
959
+ covariance_xy_a_f64x2 = vfmaq_f64(covariance_xy_a_f64x2, a1_x_f64x2, b1_y_f64x2);
960
+ covariance_xy_b_f64x2 = vfmaq_f64(covariance_xy_b_f64x2, a2_x_f64x2, b2_y_f64x2);
961
+ covariance_xz_a_f64x2 = vfmaq_f64(covariance_xz_a_f64x2, a1_x_f64x2, b1_z_f64x2);
962
+ covariance_xz_b_f64x2 = vfmaq_f64(covariance_xz_b_f64x2, a2_x_f64x2, b2_z_f64x2);
963
+ covariance_yx_a_f64x2 = vfmaq_f64(covariance_yx_a_f64x2, a1_y_f64x2, b1_x_f64x2);
964
+ covariance_yx_b_f64x2 = vfmaq_f64(covariance_yx_b_f64x2, a2_y_f64x2, b2_x_f64x2);
965
+ covariance_yy_a_f64x2 = vfmaq_f64(covariance_yy_a_f64x2, a1_y_f64x2, b1_y_f64x2);
966
+ covariance_yy_b_f64x2 = vfmaq_f64(covariance_yy_b_f64x2, a2_y_f64x2, b2_y_f64x2);
967
+ covariance_yz_a_f64x2 = vfmaq_f64(covariance_yz_a_f64x2, a1_y_f64x2, b1_z_f64x2);
968
+ covariance_yz_b_f64x2 = vfmaq_f64(covariance_yz_b_f64x2, a2_y_f64x2, b2_z_f64x2);
969
+ covariance_zx_a_f64x2 = vfmaq_f64(covariance_zx_a_f64x2, a1_z_f64x2, b1_x_f64x2);
970
+ covariance_zx_b_f64x2 = vfmaq_f64(covariance_zx_b_f64x2, a2_z_f64x2, b2_x_f64x2);
971
+ covariance_zy_a_f64x2 = vfmaq_f64(covariance_zy_a_f64x2, a1_z_f64x2, b1_y_f64x2);
972
+ covariance_zy_b_f64x2 = vfmaq_f64(covariance_zy_b_f64x2, a2_z_f64x2, b2_y_f64x2);
973
+ covariance_zz_a_f64x2 = vfmaq_f64(covariance_zz_a_f64x2, a1_z_f64x2, b1_z_f64x2);
974
+ covariance_zz_b_f64x2 = vfmaq_f64(covariance_zz_b_f64x2, a2_z_f64x2, b2_z_f64x2);
975
+
976
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_x_f64x2, a1_x_f64x2);
977
+ norm_squared_a_b_f64x2 = vfmaq_f64(norm_squared_a_b_f64x2, a2_x_f64x2, a2_x_f64x2);
978
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_y_f64x2, a1_y_f64x2);
979
+ norm_squared_a_b_f64x2 = vfmaq_f64(norm_squared_a_b_f64x2, a2_y_f64x2, a2_y_f64x2);
980
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_z_f64x2, a1_z_f64x2);
981
+ norm_squared_a_b_f64x2 = vfmaq_f64(norm_squared_a_b_f64x2, a2_z_f64x2, a2_z_f64x2);
982
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_x_f64x2, b1_x_f64x2);
983
+ norm_squared_b_b_f64x2 = vfmaq_f64(norm_squared_b_b_f64x2, b2_x_f64x2, b2_x_f64x2);
984
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_y_f64x2, b1_y_f64x2);
985
+ norm_squared_b_b_f64x2 = vfmaq_f64(norm_squared_b_b_f64x2, b2_y_f64x2, b2_y_f64x2);
986
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_z_f64x2, b1_z_f64x2);
987
+ norm_squared_b_b_f64x2 = vfmaq_f64(norm_squared_b_b_f64x2, b2_z_f64x2, b2_z_f64x2);
1063
988
  }
1064
989
 
1065
990
  // 2-point tail
@@ -1072,18 +997,21 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1072
997
  sum_b_x_a_f64x2 = vaddq_f64(sum_b_x_a_f64x2, b1_x_f64x2);
1073
998
  sum_b_y_a_f64x2 = vaddq_f64(sum_b_y_a_f64x2, b1_y_f64x2);
1074
999
  sum_b_z_a_f64x2 = vaddq_f64(sum_b_z_a_f64x2, b1_z_f64x2);
1075
- cov_xx_a_f64x2 = vfmaq_f64(cov_xx_a_f64x2, a1_x_f64x2, b1_x_f64x2);
1076
- cov_xy_a_f64x2 = vfmaq_f64(cov_xy_a_f64x2, a1_x_f64x2, b1_y_f64x2);
1077
- cov_xz_a_f64x2 = vfmaq_f64(cov_xz_a_f64x2, a1_x_f64x2, b1_z_f64x2);
1078
- cov_yx_a_f64x2 = vfmaq_f64(cov_yx_a_f64x2, a1_y_f64x2, b1_x_f64x2);
1079
- cov_yy_a_f64x2 = vfmaq_f64(cov_yy_a_f64x2, a1_y_f64x2, b1_y_f64x2);
1080
- cov_yz_a_f64x2 = vfmaq_f64(cov_yz_a_f64x2, a1_y_f64x2, b1_z_f64x2);
1081
- cov_zx_a_f64x2 = vfmaq_f64(cov_zx_a_f64x2, a1_z_f64x2, b1_x_f64x2);
1082
- cov_zy_a_f64x2 = vfmaq_f64(cov_zy_a_f64x2, a1_z_f64x2, b1_y_f64x2);
1083
- cov_zz_a_f64x2 = vfmaq_f64(cov_zz_a_f64x2, a1_z_f64x2, b1_z_f64x2);
1084
- variance_a_a_f64x2 = vfmaq_f64(variance_a_a_f64x2, a1_x_f64x2, a1_x_f64x2);
1085
- variance_a_a_f64x2 = vfmaq_f64(variance_a_a_f64x2, a1_y_f64x2, a1_y_f64x2);
1086
- variance_a_a_f64x2 = vfmaq_f64(variance_a_a_f64x2, a1_z_f64x2, a1_z_f64x2);
1000
+ covariance_xx_a_f64x2 = vfmaq_f64(covariance_xx_a_f64x2, a1_x_f64x2, b1_x_f64x2);
1001
+ covariance_xy_a_f64x2 = vfmaq_f64(covariance_xy_a_f64x2, a1_x_f64x2, b1_y_f64x2);
1002
+ covariance_xz_a_f64x2 = vfmaq_f64(covariance_xz_a_f64x2, a1_x_f64x2, b1_z_f64x2);
1003
+ covariance_yx_a_f64x2 = vfmaq_f64(covariance_yx_a_f64x2, a1_y_f64x2, b1_x_f64x2);
1004
+ covariance_yy_a_f64x2 = vfmaq_f64(covariance_yy_a_f64x2, a1_y_f64x2, b1_y_f64x2);
1005
+ covariance_yz_a_f64x2 = vfmaq_f64(covariance_yz_a_f64x2, a1_y_f64x2, b1_z_f64x2);
1006
+ covariance_zx_a_f64x2 = vfmaq_f64(covariance_zx_a_f64x2, a1_z_f64x2, b1_x_f64x2);
1007
+ covariance_zy_a_f64x2 = vfmaq_f64(covariance_zy_a_f64x2, a1_z_f64x2, b1_y_f64x2);
1008
+ covariance_zz_a_f64x2 = vfmaq_f64(covariance_zz_a_f64x2, a1_z_f64x2, b1_z_f64x2);
1009
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_x_f64x2, a1_x_f64x2);
1010
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_y_f64x2, a1_y_f64x2);
1011
+ norm_squared_a_a_f64x2 = vfmaq_f64(norm_squared_a_a_f64x2, a1_z_f64x2, a1_z_f64x2);
1012
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_x_f64x2, b1_x_f64x2);
1013
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_y_f64x2, b1_y_f64x2);
1014
+ norm_squared_b_a_f64x2 = vfmaq_f64(norm_squared_b_a_f64x2, b1_z_f64x2, b1_z_f64x2);
1087
1015
  }
1088
1016
 
1089
1017
  // Combine dual accumulators
@@ -1093,16 +1021,17 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1093
1021
  float64x2_t sum_b_x_f64x2 = vaddq_f64(sum_b_x_a_f64x2, sum_b_x_b_f64x2);
1094
1022
  float64x2_t sum_b_y_f64x2 = vaddq_f64(sum_b_y_a_f64x2, sum_b_y_b_f64x2);
1095
1023
  float64x2_t sum_b_z_f64x2 = vaddq_f64(sum_b_z_a_f64x2, sum_b_z_b_f64x2);
1096
- float64x2_t cov_xx_f64x2 = vaddq_f64(cov_xx_a_f64x2, cov_xx_b_f64x2);
1097
- float64x2_t cov_xy_f64x2 = vaddq_f64(cov_xy_a_f64x2, cov_xy_b_f64x2);
1098
- float64x2_t cov_xz_f64x2 = vaddq_f64(cov_xz_a_f64x2, cov_xz_b_f64x2);
1099
- float64x2_t cov_yx_f64x2 = vaddq_f64(cov_yx_a_f64x2, cov_yx_b_f64x2);
1100
- float64x2_t cov_yy_f64x2 = vaddq_f64(cov_yy_a_f64x2, cov_yy_b_f64x2);
1101
- float64x2_t cov_yz_f64x2 = vaddq_f64(cov_yz_a_f64x2, cov_yz_b_f64x2);
1102
- float64x2_t cov_zx_f64x2 = vaddq_f64(cov_zx_a_f64x2, cov_zx_b_f64x2);
1103
- float64x2_t cov_zy_f64x2 = vaddq_f64(cov_zy_a_f64x2, cov_zy_b_f64x2);
1104
- float64x2_t cov_zz_f64x2 = vaddq_f64(cov_zz_a_f64x2, cov_zz_b_f64x2);
1105
- float64x2_t variance_a_f64x2 = vaddq_f64(variance_a_a_f64x2, variance_a_b_f64x2);
1024
+ float64x2_t covariance_xx_f64x2 = vaddq_f64(covariance_xx_a_f64x2, covariance_xx_b_f64x2);
1025
+ float64x2_t covariance_xy_f64x2 = vaddq_f64(covariance_xy_a_f64x2, covariance_xy_b_f64x2);
1026
+ float64x2_t covariance_xz_f64x2 = vaddq_f64(covariance_xz_a_f64x2, covariance_xz_b_f64x2);
1027
+ float64x2_t covariance_yx_f64x2 = vaddq_f64(covariance_yx_a_f64x2, covariance_yx_b_f64x2);
1028
+ float64x2_t covariance_yy_f64x2 = vaddq_f64(covariance_yy_a_f64x2, covariance_yy_b_f64x2);
1029
+ float64x2_t covariance_yz_f64x2 = vaddq_f64(covariance_yz_a_f64x2, covariance_yz_b_f64x2);
1030
+ float64x2_t covariance_zx_f64x2 = vaddq_f64(covariance_zx_a_f64x2, covariance_zx_b_f64x2);
1031
+ float64x2_t covariance_zy_f64x2 = vaddq_f64(covariance_zy_a_f64x2, covariance_zy_b_f64x2);
1032
+ float64x2_t covariance_zz_f64x2 = vaddq_f64(covariance_zz_a_f64x2, covariance_zz_b_f64x2);
1033
+ float64x2_t norm_squared_a_f64x2 = vaddq_f64(norm_squared_a_a_f64x2, norm_squared_a_b_f64x2);
1034
+ float64x2_t norm_squared_b_f64x2 = vaddq_f64(norm_squared_b_a_f64x2, norm_squared_b_b_f64x2);
1106
1035
 
1107
1036
  // Reduce vector accumulators.
1108
1037
  nk_f64_t sum_a_x = nk_reduce_stable_f64x2_neon_(sum_a_x_f64x2), sum_a_x_compensation = 0.0;
@@ -1111,16 +1040,17 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1111
1040
  nk_f64_t sum_b_x = nk_reduce_stable_f64x2_neon_(sum_b_x_f64x2), sum_b_x_compensation = 0.0;
1112
1041
  nk_f64_t sum_b_y = nk_reduce_stable_f64x2_neon_(sum_b_y_f64x2), sum_b_y_compensation = 0.0;
1113
1042
  nk_f64_t sum_b_z = nk_reduce_stable_f64x2_neon_(sum_b_z_f64x2), sum_b_z_compensation = 0.0;
1114
- nk_f64_t covariance_x_x = nk_reduce_stable_f64x2_neon_(cov_xx_f64x2), covariance_x_x_compensation = 0.0;
1115
- nk_f64_t covariance_x_y = nk_reduce_stable_f64x2_neon_(cov_xy_f64x2), covariance_x_y_compensation = 0.0;
1116
- nk_f64_t covariance_x_z = nk_reduce_stable_f64x2_neon_(cov_xz_f64x2), covariance_x_z_compensation = 0.0;
1117
- nk_f64_t covariance_y_x = nk_reduce_stable_f64x2_neon_(cov_yx_f64x2), covariance_y_x_compensation = 0.0;
1118
- nk_f64_t covariance_y_y = nk_reduce_stable_f64x2_neon_(cov_yy_f64x2), covariance_y_y_compensation = 0.0;
1119
- nk_f64_t covariance_y_z = nk_reduce_stable_f64x2_neon_(cov_yz_f64x2), covariance_y_z_compensation = 0.0;
1120
- nk_f64_t covariance_z_x = nk_reduce_stable_f64x2_neon_(cov_zx_f64x2), covariance_z_x_compensation = 0.0;
1121
- nk_f64_t covariance_z_y = nk_reduce_stable_f64x2_neon_(cov_zy_f64x2), covariance_z_y_compensation = 0.0;
1122
- nk_f64_t covariance_z_z = nk_reduce_stable_f64x2_neon_(cov_zz_f64x2), covariance_z_z_compensation = 0.0;
1123
- nk_f64_t sum_sq_a = nk_reduce_stable_f64x2_neon_(variance_a_f64x2), sum_sq_a_compensation = 0.0;
1043
+ nk_f64_t covariance_x_x = nk_reduce_stable_f64x2_neon_(covariance_xx_f64x2), covariance_x_x_compensation = 0.0;
1044
+ nk_f64_t covariance_x_y = nk_reduce_stable_f64x2_neon_(covariance_xy_f64x2), covariance_x_y_compensation = 0.0;
1045
+ nk_f64_t covariance_x_z = nk_reduce_stable_f64x2_neon_(covariance_xz_f64x2), covariance_x_z_compensation = 0.0;
1046
+ nk_f64_t covariance_y_x = nk_reduce_stable_f64x2_neon_(covariance_yx_f64x2), covariance_y_x_compensation = 0.0;
1047
+ nk_f64_t covariance_y_y = nk_reduce_stable_f64x2_neon_(covariance_yy_f64x2), covariance_y_y_compensation = 0.0;
1048
+ nk_f64_t covariance_y_z = nk_reduce_stable_f64x2_neon_(covariance_yz_f64x2), covariance_y_z_compensation = 0.0;
1049
+ nk_f64_t covariance_z_x = nk_reduce_stable_f64x2_neon_(covariance_zx_f64x2), covariance_z_x_compensation = 0.0;
1050
+ nk_f64_t covariance_z_y = nk_reduce_stable_f64x2_neon_(covariance_zy_f64x2), covariance_z_y_compensation = 0.0;
1051
+ nk_f64_t covariance_z_z = nk_reduce_stable_f64x2_neon_(covariance_zz_f64x2), covariance_z_z_compensation = 0.0;
1052
+ nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x2_neon_(norm_squared_a_f64x2), norm_squared_a_compensation = 0.0;
1053
+ nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x2_neon_(norm_squared_b_f64x2), norm_squared_b_compensation = 0.0;
1124
1054
 
1125
1055
  // Scalar tail
1126
1056
  for (; i < n; ++i) {
@@ -1141,9 +1071,12 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1141
1071
  nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx),
1142
1072
  nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by),
1143
1073
  nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
1144
- nk_accumulate_square_f64_(&sum_sq_a, &sum_sq_a_compensation, ax),
1145
- nk_accumulate_square_f64_(&sum_sq_a, &sum_sq_a_compensation, ay),
1146
- nk_accumulate_square_f64_(&sum_sq_a, &sum_sq_a_compensation, az);
1074
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax),
1075
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay),
1076
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
1077
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx),
1078
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by),
1079
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
1147
1080
  }
1148
1081
 
1149
1082
  sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
@@ -1154,7 +1087,8 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1154
1087
  covariance_y_z += covariance_y_z_compensation;
1155
1088
  covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
1156
1089
  covariance_z_z += covariance_z_z_compensation;
1157
- sum_sq_a += sum_sq_a_compensation;
1090
+ norm_squared_a_sum += norm_squared_a_compensation;
1091
+ norm_squared_b_sum += norm_squared_b_compensation;
1158
1092
 
1159
1093
  // Compute centroids
1160
1094
  nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
@@ -1163,9 +1097,15 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1163
1097
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1164
1098
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1165
1099
 
1166
- // Compute variance of A (centered)
1167
- nk_f64_t centroid_sq = centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z;
1168
- nk_f64_t var_a = sum_sq_a * inv_n - centroid_sq;
1100
+ // Centered norm-squared via parallel-axis identity; clamp at zero for numeric safety.
1101
+ nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
1102
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
1103
+ centroid_a_z * centroid_a_z);
1104
+ nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
1105
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
1106
+ centroid_b_z * centroid_b_z);
1107
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
1108
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
1169
1109
 
1170
1110
  // Apply centering correction: H_centered = H - n * centroid_a * centroid_bᵀ
1171
1111
  covariance_x_x -= (nk_f64_t)n * centroid_a_x * centroid_b_x;
@@ -1181,29 +1121,57 @@ NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size
1181
1121
  // Compute SVD
1182
1122
  nk_f64_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
1183
1123
  covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
1184
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
1185
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
1186
-
1187
- nk_f64_t r[9];
1188
- nk_rotation_from_svd_f64_neon_(svd_u, svd_v, r);
1189
1124
 
1190
- // Handle reflection and compute scale
1191
- nk_f64_t det = nk_det3x3_f64_(r);
1192
- nk_f64_t trace_d_s = svd_s[0] + svd_s[4] + (det < 0 ? -svd_s[8] : svd_s[8]);
1193
- nk_f64_t computed_scale = trace_d_s / ((nk_f64_t)n * var_a);
1194
-
1195
- if (det < 0) {
1196
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1197
- nk_rotation_from_svd_f64_neon_(svd_u, svd_v, r);
1125
+ // Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
1126
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
1127
+ cross_covariance[4] * cross_covariance[4] +
1128
+ cross_covariance[8] * cross_covariance[8];
1129
+ nk_f64_t covariance_offdiagonal_norm_squared =
1130
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
1131
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
1132
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
1133
+ nk_f64_t optimal_rotation[9];
1134
+ nk_f64_t trace_rotation_covariance;
1135
+ nk_f64_t computed_scale;
1136
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
1137
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
1138
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
1139
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
1140
+ optimal_rotation[8] = 1;
1141
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
1142
+ computed_scale = centered_norm_squared_a > 0.0 ? trace_rotation_covariance / centered_norm_squared_a : 0.0;
1143
+ }
1144
+ else {
1145
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
1146
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
1147
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
1148
+
1149
+ // Handle reflection and compute scale
1150
+ nk_f64_t det = nk_det3x3_f64_(optimal_rotation);
1151
+ nk_f64_t trace_d_s = svd_diagonal[0] + svd_diagonal[4] + (det < 0 ? -svd_diagonal[8] : svd_diagonal[8]);
1152
+ computed_scale = centered_norm_squared_a > 0.0 ? trace_d_s / centered_norm_squared_a : 0.0;
1153
+
1154
+ if (det < 0) {
1155
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
1156
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
1157
+ }
1158
+
1159
+ trace_rotation_covariance =
1160
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
1161
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
1162
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
1163
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
1164
+ optimal_rotation[8] * cross_covariance[8];
1198
1165
  }
1199
1166
 
1200
1167
  if (rotation)
1201
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1168
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
1202
1169
  if (scale) *scale = computed_scale;
1203
1170
 
1204
- // Compute RMSD after transformation
1205
- nk_f64_t sum_squared = nk_transformed_ssd_f64_neon_(a, b, n, r, computed_scale, centroid_a_x, centroid_a_y,
1206
- centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z);
1171
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
1172
+ nk_f64_t sum_squared = computed_scale * computed_scale * centered_norm_squared_a + centered_norm_squared_b -
1173
+ 2.0 * computed_scale * trace_rotation_covariance;
1174
+ if (sum_squared < 0.0) sum_squared = 0.0;
1207
1175
  *result = nk_f64_sqrt_neon(sum_squared * inv_n);
1208
1176
  }
1209
1177
 
@@ -1240,162 +1208,21 @@ NK_INTERNAL void nk_partial_deinterleave_f16_to_f32x4x2_neon_(nk_f16_t const *pt
1240
1208
  z_low_out, z_high_out);
1241
1209
  }
1242
1210
 
1243
- NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_neon_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *r,
1244
- nk_f32_t scale, nk_f32_t centroid_a_x, nk_f32_t centroid_a_y,
1245
- nk_f32_t centroid_a_z, nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
1246
- nk_f32_t centroid_b_z) {
1247
- // Compute sum of squared differences after rigid transformation.
1248
- // Used by Kabsch algorithm for RMSD computation after rotation is applied.
1249
- float32x4_t const centroid_a_x_f32x4 = vdupq_n_f32(centroid_a_x);
1250
- float32x4_t const centroid_a_y_f32x4 = vdupq_n_f32(centroid_a_y);
1251
- float32x4_t const centroid_a_z_f32x4 = vdupq_n_f32(centroid_a_z);
1252
- float32x4_t const centroid_b_x_f32x4 = vdupq_n_f32(centroid_b_x);
1253
- float32x4_t const centroid_b_y_f32x4 = vdupq_n_f32(centroid_b_y);
1254
- float32x4_t const centroid_b_z_f32x4 = vdupq_n_f32(centroid_b_z);
1255
- float32x4_t const scale_f32x4 = vdupq_n_f32(scale);
1256
-
1257
- // Load rotation matrix elements
1258
- float32x4_t const r00_f32x4 = vdupq_n_f32(r[0]), r01_f32x4 = vdupq_n_f32(r[1]), r02_f32x4 = vdupq_n_f32(r[2]);
1259
- float32x4_t const r10_f32x4 = vdupq_n_f32(r[3]), r11_f32x4 = vdupq_n_f32(r[4]), r12_f32x4 = vdupq_n_f32(r[5]);
1260
- float32x4_t const r20_f32x4 = vdupq_n_f32(r[6]), r21_f32x4 = vdupq_n_f32(r[7]), r22_f32x4 = vdupq_n_f32(r[8]);
1261
-
1262
- float32x4_t sum_squared_f32x4 = vdupq_n_f32(0);
1263
- float32x4_t a_x_low_f32x4, a_x_high_f32x4, a_y_low_f32x4, a_y_high_f32x4, a_z_low_f32x4, a_z_high_f32x4;
1264
- float32x4_t b_x_low_f32x4, b_x_high_f32x4, b_y_low_f32x4, b_y_high_f32x4, b_z_low_f32x4, b_z_high_f32x4;
1265
-
1266
- nk_size_t j = 0;
1267
- for (; j + 8 <= n; j += 8) {
1268
- nk_deinterleave_f16x8_to_f32x4x2_neon_(a + j * 3, &a_x_low_f32x4, &a_x_high_f32x4, &a_y_low_f32x4,
1269
- &a_y_high_f32x4, &a_z_low_f32x4, &a_z_high_f32x4);
1270
- nk_deinterleave_f16x8_to_f32x4x2_neon_(b + j * 3, &b_x_low_f32x4, &b_x_high_f32x4, &b_y_low_f32x4,
1271
- &b_y_high_f32x4, &b_z_low_f32x4, &b_z_high_f32x4);
1272
-
1273
- // Center points → low half
1274
- float32x4_t pa_x_f32x4 = vsubq_f32(a_x_low_f32x4, centroid_a_x_f32x4);
1275
- float32x4_t pa_y_f32x4 = vsubq_f32(a_y_low_f32x4, centroid_a_y_f32x4);
1276
- float32x4_t pa_z_f32x4 = vsubq_f32(a_z_low_f32x4, centroid_a_z_f32x4);
1277
- float32x4_t pb_x_f32x4 = vsubq_f32(b_x_low_f32x4, centroid_b_x_f32x4);
1278
- float32x4_t pb_y_f32x4 = vsubq_f32(b_y_low_f32x4, centroid_b_y_f32x4);
1279
- float32x4_t pb_z_f32x4 = vsubq_f32(b_z_low_f32x4, centroid_b_z_f32x4);
1280
- float32x4_t ra_x_f32x4 = vmulq_f32(
1281
- scale_f32x4,
1282
- vfmaq_f32(vfmaq_f32(vmulq_f32(r00_f32x4, pa_x_f32x4), r01_f32x4, pa_y_f32x4), r02_f32x4, pa_z_f32x4));
1283
- float32x4_t ra_y_f32x4 = vmulq_f32(
1284
- scale_f32x4,
1285
- vfmaq_f32(vfmaq_f32(vmulq_f32(r10_f32x4, pa_x_f32x4), r11_f32x4, pa_y_f32x4), r12_f32x4, pa_z_f32x4));
1286
- float32x4_t ra_z_f32x4 = vmulq_f32(
1287
- scale_f32x4,
1288
- vfmaq_f32(vfmaq_f32(vmulq_f32(r20_f32x4, pa_x_f32x4), r21_f32x4, pa_y_f32x4), r22_f32x4, pa_z_f32x4));
1289
- float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
1290
- float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
1291
- float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
1292
- sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_x_f32x4, delta_x_f32x4);
1293
- sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_y_f32x4, delta_y_f32x4);
1294
- sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_z_f32x4, delta_z_f32x4);
1295
-
1296
- // Center points → high half
1297
- pa_x_f32x4 = vsubq_f32(a_x_high_f32x4, centroid_a_x_f32x4);
1298
- pa_y_f32x4 = vsubq_f32(a_y_high_f32x4, centroid_a_y_f32x4);
1299
- pa_z_f32x4 = vsubq_f32(a_z_high_f32x4, centroid_a_z_f32x4);
1300
- pb_x_f32x4 = vsubq_f32(b_x_high_f32x4, centroid_b_x_f32x4);
1301
- pb_y_f32x4 = vsubq_f32(b_y_high_f32x4, centroid_b_y_f32x4);
1302
- pb_z_f32x4 = vsubq_f32(b_z_high_f32x4, centroid_b_z_f32x4);
1303
- ra_x_f32x4 = vmulq_f32(
1304
- scale_f32x4,
1305
- vfmaq_f32(vfmaq_f32(vmulq_f32(r00_f32x4, pa_x_f32x4), r01_f32x4, pa_y_f32x4), r02_f32x4, pa_z_f32x4));
1306
- ra_y_f32x4 = vmulq_f32(
1307
- scale_f32x4,
1308
- vfmaq_f32(vfmaq_f32(vmulq_f32(r10_f32x4, pa_x_f32x4), r11_f32x4, pa_y_f32x4), r12_f32x4, pa_z_f32x4));
1309
- ra_z_f32x4 = vmulq_f32(
1310
- scale_f32x4,
1311
- vfmaq_f32(vfmaq_f32(vmulq_f32(r20_f32x4, pa_x_f32x4), r21_f32x4, pa_y_f32x4), r22_f32x4, pa_z_f32x4));
1312
- delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
1313
- delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
1314
- delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
1315
- sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_x_f32x4, delta_x_f32x4);
1316
- sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_y_f32x4, delta_y_f32x4);
1317
- sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_z_f32x4, delta_z_f32x4);
1318
- }
1319
-
1320
- // Reduce to scalar
1321
- nk_f32_t sum_squared = vaddvq_f32(sum_squared_f32x4);
1322
-
1323
- if (j < n) {
1324
- nk_partial_deinterleave_f16_to_f32x4x2_neon_(a + j * 3, n - j, &a_x_low_f32x4, &a_x_high_f32x4, &a_y_low_f32x4,
1325
- &a_y_high_f32x4, &a_z_low_f32x4, &a_z_high_f32x4);
1326
- nk_partial_deinterleave_f16_to_f32x4x2_neon_(b + j * 3, n - j, &b_x_low_f32x4, &b_x_high_f32x4, &b_y_low_f32x4,
1327
- &b_y_high_f32x4, &b_z_low_f32x4, &b_z_high_f32x4);
1328
-
1329
- // Low half
1330
- float32x4_t pa_x_f32x4 = vsubq_f32(a_x_low_f32x4, centroid_a_x_f32x4);
1331
- float32x4_t pa_y_f32x4 = vsubq_f32(a_y_low_f32x4, centroid_a_y_f32x4);
1332
- float32x4_t pa_z_f32x4 = vsubq_f32(a_z_low_f32x4, centroid_a_z_f32x4);
1333
- float32x4_t pb_x_f32x4 = vsubq_f32(b_x_low_f32x4, centroid_b_x_f32x4);
1334
- float32x4_t pb_y_f32x4 = vsubq_f32(b_y_low_f32x4, centroid_b_y_f32x4);
1335
- float32x4_t pb_z_f32x4 = vsubq_f32(b_z_low_f32x4, centroid_b_z_f32x4);
1336
- float32x4_t ra_x_f32x4 = vmulq_f32(
1337
- scale_f32x4,
1338
- vfmaq_f32(vfmaq_f32(vmulq_f32(r00_f32x4, pa_x_f32x4), r01_f32x4, pa_y_f32x4), r02_f32x4, pa_z_f32x4));
1339
- float32x4_t ra_y_f32x4 = vmulq_f32(
1340
- scale_f32x4,
1341
- vfmaq_f32(vfmaq_f32(vmulq_f32(r10_f32x4, pa_x_f32x4), r11_f32x4, pa_y_f32x4), r12_f32x4, pa_z_f32x4));
1342
- float32x4_t ra_z_f32x4 = vmulq_f32(
1343
- scale_f32x4,
1344
- vfmaq_f32(vfmaq_f32(vmulq_f32(r20_f32x4, pa_x_f32x4), r21_f32x4, pa_y_f32x4), r22_f32x4, pa_z_f32x4));
1345
- float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
1346
- float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
1347
- float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
1348
- float32x4_t tail_sum_f32x4 = vmulq_f32(delta_x_f32x4, delta_x_f32x4);
1349
- tail_sum_f32x4 = vfmaq_f32(tail_sum_f32x4, delta_y_f32x4, delta_y_f32x4);
1350
- tail_sum_f32x4 = vfmaq_f32(tail_sum_f32x4, delta_z_f32x4, delta_z_f32x4);
1351
-
1352
- // High half
1353
- pa_x_f32x4 = vsubq_f32(a_x_high_f32x4, centroid_a_x_f32x4);
1354
- pa_y_f32x4 = vsubq_f32(a_y_high_f32x4, centroid_a_y_f32x4);
1355
- pa_z_f32x4 = vsubq_f32(a_z_high_f32x4, centroid_a_z_f32x4);
1356
- pb_x_f32x4 = vsubq_f32(b_x_high_f32x4, centroid_b_x_f32x4);
1357
- pb_y_f32x4 = vsubq_f32(b_y_high_f32x4, centroid_b_y_f32x4);
1358
- pb_z_f32x4 = vsubq_f32(b_z_high_f32x4, centroid_b_z_f32x4);
1359
- ra_x_f32x4 = vmulq_f32(
1360
- scale_f32x4,
1361
- vfmaq_f32(vfmaq_f32(vmulq_f32(r00_f32x4, pa_x_f32x4), r01_f32x4, pa_y_f32x4), r02_f32x4, pa_z_f32x4));
1362
- ra_y_f32x4 = vmulq_f32(
1363
- scale_f32x4,
1364
- vfmaq_f32(vfmaq_f32(vmulq_f32(r10_f32x4, pa_x_f32x4), r11_f32x4, pa_y_f32x4), r12_f32x4, pa_z_f32x4));
1365
- ra_z_f32x4 = vmulq_f32(
1366
- scale_f32x4,
1367
- vfmaq_f32(vfmaq_f32(vmulq_f32(r20_f32x4, pa_x_f32x4), r21_f32x4, pa_y_f32x4), r22_f32x4, pa_z_f32x4));
1368
- delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
1369
- delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
1370
- delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
1371
- tail_sum_f32x4 = vfmaq_f32(tail_sum_f32x4, delta_x_f32x4, delta_x_f32x4);
1372
- tail_sum_f32x4 = vfmaq_f32(tail_sum_f32x4, delta_y_f32x4, delta_y_f32x4);
1373
- tail_sum_f32x4 = vfmaq_f32(tail_sum_f32x4, delta_z_f32x4, delta_z_f32x4);
1374
- sum_squared += vaddvq_f32(tail_sum_f32x4);
1375
- }
1376
-
1377
- return sum_squared;
1378
- }
1379
-
1380
1211
  /**
1381
1212
  * @brief RMSD (Root Mean Square Deviation) computation using NEON FP16 with widening to FP32.
1382
- * Computes the RMS of distances between corresponding points after centroid alignment.
1213
+ * Matches the serial-RMSD contract: zero centroids, identity rotation, raw √(Σ‖a-b‖² / n).
1383
1214
  */
1384
1215
  NK_PUBLIC void nk_rmsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1385
1216
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1386
- // RMSD uses identity rotation and scale=1.0
1387
1217
  if (rotation)
1388
1218
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
1389
1219
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1390
1220
  if (scale) *scale = 1.0f;
1221
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
1222
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
1391
1223
 
1392
1224
  float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
1393
-
1394
- // Accumulators for centroids and squared differences (all in f32)
1395
- float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
1396
- float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
1397
1225
  float32x4_t sum_squared_x_f32x4 = zeros_f32x4, sum_squared_y_f32x4 = zeros_f32x4, sum_squared_z_f32x4 = zeros_f32x4;
1398
-
1399
1226
  float32x4_t a_x_low_f32x4, a_x_high_f32x4, a_y_low_f32x4, a_y_high_f32x4, a_z_low_f32x4, a_z_high_f32x4;
1400
1227
  float32x4_t b_x_low_f32x4, b_x_high_f32x4, b_y_low_f32x4, b_y_high_f32x4, b_z_low_f32x4, b_z_high_f32x4;
1401
1228
  nk_size_t i = 0;
@@ -1406,13 +1233,6 @@ NK_PUBLIC void nk_rmsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t
1406
1233
  nk_deinterleave_f16x8_to_f32x4x2_neon_(b + i * 3, &b_x_low_f32x4, &b_x_high_f32x4, &b_y_low_f32x4,
1407
1234
  &b_y_high_f32x4, &b_z_low_f32x4, &b_z_high_f32x4);
1408
1235
 
1409
- sum_a_x_f32x4 = vaddq_f32(vaddq_f32(sum_a_x_f32x4, a_x_low_f32x4), a_x_high_f32x4);
1410
- sum_a_y_f32x4 = vaddq_f32(vaddq_f32(sum_a_y_f32x4, a_y_low_f32x4), a_y_high_f32x4);
1411
- sum_a_z_f32x4 = vaddq_f32(vaddq_f32(sum_a_z_f32x4, a_z_low_f32x4), a_z_high_f32x4);
1412
- sum_b_x_f32x4 = vaddq_f32(vaddq_f32(sum_b_x_f32x4, b_x_low_f32x4), b_x_high_f32x4);
1413
- sum_b_y_f32x4 = vaddq_f32(vaddq_f32(sum_b_y_f32x4, b_y_low_f32x4), b_y_high_f32x4);
1414
- sum_b_z_f32x4 = vaddq_f32(vaddq_f32(sum_b_z_f32x4, b_z_low_f32x4), b_z_high_f32x4);
1415
-
1416
1236
  float32x4_t delta_x_f32x4 = vsubq_f32(a_x_low_f32x4, b_x_low_f32x4);
1417
1237
  float32x4_t delta_y_f32x4 = vsubq_f32(a_y_low_f32x4, b_y_low_f32x4);
1418
1238
  float32x4_t delta_z_f32x4 = vsubq_f32(a_z_low_f32x4, b_z_low_f32x4);
@@ -1434,13 +1254,6 @@ NK_PUBLIC void nk_rmsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t
1434
1254
  nk_partial_deinterleave_f16_to_f32x4x2_neon_(b + i * 3, n - i, &b_x_low_f32x4, &b_x_high_f32x4, &b_y_low_f32x4,
1435
1255
  &b_y_high_f32x4, &b_z_low_f32x4, &b_z_high_f32x4);
1436
1256
 
1437
- sum_a_x_f32x4 = vaddq_f32(vaddq_f32(sum_a_x_f32x4, a_x_low_f32x4), a_x_high_f32x4);
1438
- sum_a_y_f32x4 = vaddq_f32(vaddq_f32(sum_a_y_f32x4, a_y_low_f32x4), a_y_high_f32x4);
1439
- sum_a_z_f32x4 = vaddq_f32(vaddq_f32(sum_a_z_f32x4, a_z_low_f32x4), a_z_high_f32x4);
1440
- sum_b_x_f32x4 = vaddq_f32(vaddq_f32(sum_b_x_f32x4, b_x_low_f32x4), b_x_high_f32x4);
1441
- sum_b_y_f32x4 = vaddq_f32(vaddq_f32(sum_b_y_f32x4, b_y_low_f32x4), b_y_high_f32x4);
1442
- sum_b_z_f32x4 = vaddq_f32(vaddq_f32(sum_b_z_f32x4, b_z_low_f32x4), b_z_high_f32x4);
1443
-
1444
1257
  float32x4_t delta_x_f32x4 = vsubq_f32(a_x_low_f32x4, b_x_low_f32x4);
1445
1258
  float32x4_t delta_y_f32x4 = vsubq_f32(a_y_low_f32x4, b_y_low_f32x4);
1446
1259
  float32x4_t delta_z_f32x4 = vsubq_f32(a_z_low_f32x4, b_z_low_f32x4);
@@ -1456,37 +1269,9 @@ NK_PUBLIC void nk_rmsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t
1456
1269
  sum_squared_z_f32x4 = vfmaq_f32(sum_squared_z_f32x4, delta_z_f32x4, delta_z_f32x4);
1457
1270
  }
1458
1271
 
1459
- // Reduce vectors to scalars
1460
- nk_f32_t total_ax = vaddvq_f32(sum_a_x_f32x4);
1461
- nk_f32_t total_ay = vaddvq_f32(sum_a_y_f32x4);
1462
- nk_f32_t total_az = vaddvq_f32(sum_a_z_f32x4);
1463
- nk_f32_t total_bx = vaddvq_f32(sum_b_x_f32x4);
1464
- nk_f32_t total_by = vaddvq_f32(sum_b_y_f32x4);
1465
- nk_f32_t total_bz = vaddvq_f32(sum_b_z_f32x4);
1466
- nk_f32_t total_sq_x = vaddvq_f32(sum_squared_x_f32x4);
1467
- nk_f32_t total_sq_y = vaddvq_f32(sum_squared_y_f32x4);
1468
- nk_f32_t total_sq_z = vaddvq_f32(sum_squared_z_f32x4);
1469
-
1470
- // Compute centroids
1471
- nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1472
- nk_f32_t centroid_a_x = total_ax * inv_n;
1473
- nk_f32_t centroid_a_y = total_ay * inv_n;
1474
- nk_f32_t centroid_a_z = total_az * inv_n;
1475
- nk_f32_t centroid_b_x = total_bx * inv_n;
1476
- nk_f32_t centroid_b_y = total_by * inv_n;
1477
- nk_f32_t centroid_b_z = total_bz * inv_n;
1478
-
1479
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1480
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1481
-
1482
- // Compute RMSD
1483
- nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
1484
- nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
1485
- nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
1486
- nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
1487
- nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
1488
-
1489
- *result = nk_f32_sqrt_neon(sum_squared * inv_n - mean_diff_sq);
1272
+ nk_f32_t sum_squared = vaddvq_f32(sum_squared_x_f32x4) + vaddvq_f32(sum_squared_y_f32x4) +
1273
+ vaddvq_f32(sum_squared_z_f32x4);
1274
+ *result = nk_f32_sqrt_neon(sum_squared / (nk_f32_t)n);
1490
1275
  }
1491
1276
 
1492
1277
  /**
@@ -1503,9 +1288,10 @@ NK_PUBLIC void nk_kabsch_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_
1503
1288
  float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
1504
1289
 
1505
1290
  // Accumulators for covariance matrix (sum of outer products)
1506
- float32x4_t cov_xx_f32x4 = zeros_f32x4, cov_xy_f32x4 = zeros_f32x4, cov_xz_f32x4 = zeros_f32x4;
1507
- float32x4_t cov_yx_f32x4 = zeros_f32x4, cov_yy_f32x4 = zeros_f32x4, cov_yz_f32x4 = zeros_f32x4;
1508
- float32x4_t cov_zx_f32x4 = zeros_f32x4, cov_zy_f32x4 = zeros_f32x4, cov_zz_f32x4 = zeros_f32x4;
1291
+ float32x4_t covariance_xx_f32x4 = zeros_f32x4, covariance_xy_f32x4 = zeros_f32x4, covariance_xz_f32x4 = zeros_f32x4;
1292
+ float32x4_t covariance_yx_f32x4 = zeros_f32x4, covariance_yy_f32x4 = zeros_f32x4, covariance_yz_f32x4 = zeros_f32x4;
1293
+ float32x4_t covariance_zx_f32x4 = zeros_f32x4, covariance_zy_f32x4 = zeros_f32x4, covariance_zz_f32x4 = zeros_f32x4;
1294
+ float32x4_t norm_squared_a_f32x4 = zeros_f32x4, norm_squared_b_f32x4 = zeros_f32x4;
1509
1295
 
1510
1296
  nk_size_t i = 0;
1511
1297
  float32x4_t a_x_low_f32x4, a_x_high_f32x4, a_y_low_f32x4, a_y_high_f32x4, a_z_low_f32x4, a_z_high_f32x4;
@@ -1524,15 +1310,36 @@ NK_PUBLIC void nk_kabsch_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_
1524
1310
  sum_b_y_f32x4 = vaddq_f32(vaddq_f32(sum_b_y_f32x4, b_y_low_f32x4), b_y_high_f32x4);
1525
1311
  sum_b_z_f32x4 = vaddq_f32(vaddq_f32(sum_b_z_f32x4, b_z_low_f32x4), b_z_high_f32x4);
1526
1312
 
1527
- cov_xx_f32x4 = vfmaq_f32(vfmaq_f32(cov_xx_f32x4, a_x_low_f32x4, b_x_low_f32x4), a_x_high_f32x4, b_x_high_f32x4);
1528
- cov_xy_f32x4 = vfmaq_f32(vfmaq_f32(cov_xy_f32x4, a_x_low_f32x4, b_y_low_f32x4), a_x_high_f32x4, b_y_high_f32x4);
1529
- cov_xz_f32x4 = vfmaq_f32(vfmaq_f32(cov_xz_f32x4, a_x_low_f32x4, b_z_low_f32x4), a_x_high_f32x4, b_z_high_f32x4);
1530
- cov_yx_f32x4 = vfmaq_f32(vfmaq_f32(cov_yx_f32x4, a_y_low_f32x4, b_x_low_f32x4), a_y_high_f32x4, b_x_high_f32x4);
1531
- cov_yy_f32x4 = vfmaq_f32(vfmaq_f32(cov_yy_f32x4, a_y_low_f32x4, b_y_low_f32x4), a_y_high_f32x4, b_y_high_f32x4);
1532
- cov_yz_f32x4 = vfmaq_f32(vfmaq_f32(cov_yz_f32x4, a_y_low_f32x4, b_z_low_f32x4), a_y_high_f32x4, b_z_high_f32x4);
1533
- cov_zx_f32x4 = vfmaq_f32(vfmaq_f32(cov_zx_f32x4, a_z_low_f32x4, b_x_low_f32x4), a_z_high_f32x4, b_x_high_f32x4);
1534
- cov_zy_f32x4 = vfmaq_f32(vfmaq_f32(cov_zy_f32x4, a_z_low_f32x4, b_y_low_f32x4), a_z_high_f32x4, b_y_high_f32x4);
1535
- cov_zz_f32x4 = vfmaq_f32(vfmaq_f32(cov_zz_f32x4, a_z_low_f32x4, b_z_low_f32x4), a_z_high_f32x4, b_z_high_f32x4);
1313
+ covariance_xx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xx_f32x4, a_x_low_f32x4, b_x_low_f32x4), a_x_high_f32x4,
1314
+ b_x_high_f32x4);
1315
+ covariance_xy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xy_f32x4, a_x_low_f32x4, b_y_low_f32x4), a_x_high_f32x4,
1316
+ b_y_high_f32x4);
1317
+ covariance_xz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xz_f32x4, a_x_low_f32x4, b_z_low_f32x4), a_x_high_f32x4,
1318
+ b_z_high_f32x4);
1319
+ covariance_yx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yx_f32x4, a_y_low_f32x4, b_x_low_f32x4), a_y_high_f32x4,
1320
+ b_x_high_f32x4);
1321
+ covariance_yy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yy_f32x4, a_y_low_f32x4, b_y_low_f32x4), a_y_high_f32x4,
1322
+ b_y_high_f32x4);
1323
+ covariance_yz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yz_f32x4, a_y_low_f32x4, b_z_low_f32x4), a_y_high_f32x4,
1324
+ b_z_high_f32x4);
1325
+ covariance_zx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zx_f32x4, a_z_low_f32x4, b_x_low_f32x4), a_z_high_f32x4,
1326
+ b_x_high_f32x4);
1327
+ covariance_zy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zy_f32x4, a_z_low_f32x4, b_y_low_f32x4), a_z_high_f32x4,
1328
+ b_y_high_f32x4);
1329
+ covariance_zz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zz_f32x4, a_z_low_f32x4, b_z_low_f32x4), a_z_high_f32x4,
1330
+ b_z_high_f32x4);
1331
+ norm_squared_a_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_a_f32x4, a_x_low_f32x4, a_x_low_f32x4), a_x_high_f32x4,
1332
+ a_x_high_f32x4);
1333
+ norm_squared_a_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_a_f32x4, a_y_low_f32x4, a_y_low_f32x4), a_y_high_f32x4,
1334
+ a_y_high_f32x4);
1335
+ norm_squared_a_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_a_f32x4, a_z_low_f32x4, a_z_low_f32x4), a_z_high_f32x4,
1336
+ a_z_high_f32x4);
1337
+ norm_squared_b_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_b_f32x4, b_x_low_f32x4, b_x_low_f32x4), b_x_high_f32x4,
1338
+ b_x_high_f32x4);
1339
+ norm_squared_b_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_b_f32x4, b_y_low_f32x4, b_y_low_f32x4), b_y_high_f32x4,
1340
+ b_y_high_f32x4);
1341
+ norm_squared_b_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_b_f32x4, b_z_low_f32x4, b_z_low_f32x4), b_z_high_f32x4,
1342
+ b_z_high_f32x4);
1536
1343
  }
1537
1344
 
1538
1345
  if (i < n) {
@@ -1548,15 +1355,36 @@ NK_PUBLIC void nk_kabsch_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_
1548
1355
  sum_b_y_f32x4 = vaddq_f32(vaddq_f32(sum_b_y_f32x4, b_y_low_f32x4), b_y_high_f32x4);
1549
1356
  sum_b_z_f32x4 = vaddq_f32(vaddq_f32(sum_b_z_f32x4, b_z_low_f32x4), b_z_high_f32x4);
1550
1357
 
1551
- cov_xx_f32x4 = vfmaq_f32(vfmaq_f32(cov_xx_f32x4, a_x_low_f32x4, b_x_low_f32x4), a_x_high_f32x4, b_x_high_f32x4);
1552
- cov_xy_f32x4 = vfmaq_f32(vfmaq_f32(cov_xy_f32x4, a_x_low_f32x4, b_y_low_f32x4), a_x_high_f32x4, b_y_high_f32x4);
1553
- cov_xz_f32x4 = vfmaq_f32(vfmaq_f32(cov_xz_f32x4, a_x_low_f32x4, b_z_low_f32x4), a_x_high_f32x4, b_z_high_f32x4);
1554
- cov_yx_f32x4 = vfmaq_f32(vfmaq_f32(cov_yx_f32x4, a_y_low_f32x4, b_x_low_f32x4), a_y_high_f32x4, b_x_high_f32x4);
1555
- cov_yy_f32x4 = vfmaq_f32(vfmaq_f32(cov_yy_f32x4, a_y_low_f32x4, b_y_low_f32x4), a_y_high_f32x4, b_y_high_f32x4);
1556
- cov_yz_f32x4 = vfmaq_f32(vfmaq_f32(cov_yz_f32x4, a_y_low_f32x4, b_z_low_f32x4), a_y_high_f32x4, b_z_high_f32x4);
1557
- cov_zx_f32x4 = vfmaq_f32(vfmaq_f32(cov_zx_f32x4, a_z_low_f32x4, b_x_low_f32x4), a_z_high_f32x4, b_x_high_f32x4);
1558
- cov_zy_f32x4 = vfmaq_f32(vfmaq_f32(cov_zy_f32x4, a_z_low_f32x4, b_y_low_f32x4), a_z_high_f32x4, b_y_high_f32x4);
1559
- cov_zz_f32x4 = vfmaq_f32(vfmaq_f32(cov_zz_f32x4, a_z_low_f32x4, b_z_low_f32x4), a_z_high_f32x4, b_z_high_f32x4);
1358
+ covariance_xx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xx_f32x4, a_x_low_f32x4, b_x_low_f32x4), a_x_high_f32x4,
1359
+ b_x_high_f32x4);
1360
+ covariance_xy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xy_f32x4, a_x_low_f32x4, b_y_low_f32x4), a_x_high_f32x4,
1361
+ b_y_high_f32x4);
1362
+ covariance_xz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xz_f32x4, a_x_low_f32x4, b_z_low_f32x4), a_x_high_f32x4,
1363
+ b_z_high_f32x4);
1364
+ covariance_yx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yx_f32x4, a_y_low_f32x4, b_x_low_f32x4), a_y_high_f32x4,
1365
+ b_x_high_f32x4);
1366
+ covariance_yy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yy_f32x4, a_y_low_f32x4, b_y_low_f32x4), a_y_high_f32x4,
1367
+ b_y_high_f32x4);
1368
+ covariance_yz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yz_f32x4, a_y_low_f32x4, b_z_low_f32x4), a_y_high_f32x4,
1369
+ b_z_high_f32x4);
1370
+ covariance_zx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zx_f32x4, a_z_low_f32x4, b_x_low_f32x4), a_z_high_f32x4,
1371
+ b_x_high_f32x4);
1372
+ covariance_zy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zy_f32x4, a_z_low_f32x4, b_y_low_f32x4), a_z_high_f32x4,
1373
+ b_y_high_f32x4);
1374
+ covariance_zz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zz_f32x4, a_z_low_f32x4, b_z_low_f32x4), a_z_high_f32x4,
1375
+ b_z_high_f32x4);
1376
+ norm_squared_a_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_a_f32x4, a_x_low_f32x4, a_x_low_f32x4), a_x_high_f32x4,
1377
+ a_x_high_f32x4);
1378
+ norm_squared_a_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_a_f32x4, a_y_low_f32x4, a_y_low_f32x4), a_y_high_f32x4,
1379
+ a_y_high_f32x4);
1380
+ norm_squared_a_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_a_f32x4, a_z_low_f32x4, a_z_low_f32x4), a_z_high_f32x4,
1381
+ a_z_high_f32x4);
1382
+ norm_squared_b_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_b_f32x4, b_x_low_f32x4, b_x_low_f32x4), b_x_high_f32x4,
1383
+ b_x_high_f32x4);
1384
+ norm_squared_b_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_b_f32x4, b_y_low_f32x4, b_y_low_f32x4), b_y_high_f32x4,
1385
+ b_y_high_f32x4);
1386
+ norm_squared_b_f32x4 = vfmaq_f32(vfmaq_f32(norm_squared_b_f32x4, b_z_low_f32x4, b_z_low_f32x4), b_z_high_f32x4,
1387
+ b_z_high_f32x4);
1560
1388
  }
1561
1389
 
1562
1390
  // Reduce vector accumulators
@@ -1567,15 +1395,17 @@ NK_PUBLIC void nk_kabsch_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_
1567
1395
  nk_f32_t sum_b_y = vaddvq_f32(sum_b_y_f32x4);
1568
1396
  nk_f32_t sum_b_z = vaddvq_f32(sum_b_z_f32x4);
1569
1397
 
1570
- nk_f32_t covariance_x_x = vaddvq_f32(cov_xx_f32x4);
1571
- nk_f32_t covariance_x_y = vaddvq_f32(cov_xy_f32x4);
1572
- nk_f32_t covariance_x_z = vaddvq_f32(cov_xz_f32x4);
1573
- nk_f32_t covariance_y_x = vaddvq_f32(cov_yx_f32x4);
1574
- nk_f32_t covariance_y_y = vaddvq_f32(cov_yy_f32x4);
1575
- nk_f32_t covariance_y_z = vaddvq_f32(cov_yz_f32x4);
1576
- nk_f32_t covariance_z_x = vaddvq_f32(cov_zx_f32x4);
1577
- nk_f32_t covariance_z_y = vaddvq_f32(cov_zy_f32x4);
1578
- nk_f32_t covariance_z_z = vaddvq_f32(cov_zz_f32x4);
1398
+ nk_f32_t covariance_x_x = vaddvq_f32(covariance_xx_f32x4);
1399
+ nk_f32_t covariance_x_y = vaddvq_f32(covariance_xy_f32x4);
1400
+ nk_f32_t covariance_x_z = vaddvq_f32(covariance_xz_f32x4);
1401
+ nk_f32_t covariance_y_x = vaddvq_f32(covariance_yx_f32x4);
1402
+ nk_f32_t covariance_y_y = vaddvq_f32(covariance_yy_f32x4);
1403
+ nk_f32_t covariance_y_z = vaddvq_f32(covariance_yz_f32x4);
1404
+ nk_f32_t covariance_z_x = vaddvq_f32(covariance_zx_f32x4);
1405
+ nk_f32_t covariance_z_y = vaddvq_f32(covariance_zy_f32x4);
1406
+ nk_f32_t covariance_z_z = vaddvq_f32(covariance_zz_f32x4);
1407
+ nk_f32_t norm_squared_a = vaddvq_f32(norm_squared_a_f32x4);
1408
+ nk_f32_t norm_squared_b = vaddvq_f32(norm_squared_b_f32x4);
1579
1409
 
1580
1410
  // Compute centroids
1581
1411
  nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
@@ -1591,55 +1421,88 @@ NK_PUBLIC void nk_kabsch_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_
1591
1421
 
1592
1422
  // Compute centered covariance: H = (A - centroid_A)ᵀ * (B - centroid_B)
1593
1423
  // H = sum(a * bᵀ) - n * centroid_a * centroid_bᵀ
1594
- nk_f32_t h[9];
1595
- h[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
1596
- h[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
1597
- h[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
1598
- h[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
1599
- h[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
1600
- h[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
1601
- h[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
1602
- h[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
1603
- h[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
1604
-
1605
- // SVD of H = U * S * Vᵀ
1606
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
1607
- nk_svd3x3_f32_(h, svd_u, svd_s, svd_v);
1608
-
1609
- // R = V * Uᵀ
1610
- nk_f32_t r[9];
1611
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1612
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1613
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
1614
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
1615
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
1616
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
1617
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
1618
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
1619
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1620
-
1621
- // Handle reflection: if det(R) < 0, negate third column of V and recompute
1622
- nk_f32_t det_r = nk_det3x3_f32_(r);
1623
- if (det_r < 0) {
1624
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1625
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1626
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1627
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
1628
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
1629
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
1630
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
1631
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
1632
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
1633
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1424
+ nk_f32_t cross_covariance[9];
1425
+ cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
1426
+ cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
1427
+ cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
1428
+ cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
1429
+ cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
1430
+ cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
1431
+ cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
1432
+ cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
1433
+ cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
1434
+
1435
+ // Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
1436
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
1437
+ cross_covariance[4] * cross_covariance[4] +
1438
+ cross_covariance[8] * cross_covariance[8];
1439
+ nk_f32_t covariance_offdiagonal_norm_squared =
1440
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
1441
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
1442
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
1443
+ nk_f32_t optimal_rotation[9];
1444
+ nk_f32_t trace_rotation_covariance;
1445
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
1446
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
1447
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
1448
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
1449
+ optimal_rotation[8] = 1;
1450
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
1451
+ }
1452
+ else {
1453
+ // SVD of H = U * S * Vᵀ
1454
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
1455
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
1456
+
1457
+ // R = V * Uᵀ
1458
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
1459
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
1460
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
1461
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
1462
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
1463
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
1464
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
1465
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
1466
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
1467
+
1468
+ // Handle reflection: if det(R) < 0, negate third column of V and recompute
1469
+ nk_f32_t rotation_determinant = nk_det3x3_f32_(optimal_rotation);
1470
+ if (rotation_determinant < 0) {
1471
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
1472
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
1473
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
1474
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
1475
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
1476
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
1477
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
1478
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
1479
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
1480
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
1481
+ }
1482
+
1483
+ trace_rotation_covariance =
1484
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
1485
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
1486
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
1487
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
1488
+ optimal_rotation[8] * cross_covariance[8];
1634
1489
  }
1635
1490
 
1636
1491
  if (rotation)
1637
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1492
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
1638
1493
  if (scale) *scale = 1.0f;
1639
1494
 
1640
- // Compute RMSD after rotation
1641
- nk_f32_t sum_squared = nk_transformed_ssd_f16_neon_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
1642
- centroid_b_x, centroid_b_y, centroid_b_z);
1495
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
1496
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
1497
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
1498
+ centroid_a_z * centroid_a_z);
1499
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
1500
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
1501
+ centroid_b_z * centroid_b_z);
1502
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
1503
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
1504
+ nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
1505
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
1643
1506
  *result = nk_f32_sqrt_neon(sum_squared * inv_n);
1644
1507
  }
1645
1508
 
@@ -1650,10 +1513,10 @@ NK_PUBLIC void nk_umeyama_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size
1650
1513
 
1651
1514
  float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
1652
1515
  float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
1653
- float32x4_t cov_xx_f32x4 = zeros_f32x4, cov_xy_f32x4 = zeros_f32x4, cov_xz_f32x4 = zeros_f32x4;
1654
- float32x4_t cov_yx_f32x4 = zeros_f32x4, cov_yy_f32x4 = zeros_f32x4, cov_yz_f32x4 = zeros_f32x4;
1655
- float32x4_t cov_zx_f32x4 = zeros_f32x4, cov_zy_f32x4 = zeros_f32x4, cov_zz_f32x4 = zeros_f32x4;
1656
- float32x4_t variance_a_f32x4 = zeros_f32x4;
1516
+ float32x4_t covariance_xx_f32x4 = zeros_f32x4, covariance_xy_f32x4 = zeros_f32x4, covariance_xz_f32x4 = zeros_f32x4;
1517
+ float32x4_t covariance_yx_f32x4 = zeros_f32x4, covariance_yy_f32x4 = zeros_f32x4, covariance_yz_f32x4 = zeros_f32x4;
1518
+ float32x4_t covariance_zx_f32x4 = zeros_f32x4, covariance_zy_f32x4 = zeros_f32x4, covariance_zz_f32x4 = zeros_f32x4;
1519
+ float32x4_t norm_squared_a_f32x4 = zeros_f32x4, norm_squared_b_f32x4 = zeros_f32x4;
1657
1520
 
1658
1521
  nk_size_t i = 0;
1659
1522
  float32x4_t a_x_low_f32x4, a_x_high_f32x4, a_y_low_f32x4, a_y_high_f32x4, a_z_low_f32x4, a_z_high_f32x4;
@@ -1672,22 +1535,37 @@ NK_PUBLIC void nk_umeyama_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size
1672
1535
  sum_b_y_f32x4 = vaddq_f32(vaddq_f32(sum_b_y_f32x4, b_y_low_f32x4), b_y_high_f32x4);
1673
1536
  sum_b_z_f32x4 = vaddq_f32(vaddq_f32(sum_b_z_f32x4, b_z_low_f32x4), b_z_high_f32x4);
1674
1537
 
1675
- cov_xx_f32x4 = vfmaq_f32(vfmaq_f32(cov_xx_f32x4, a_x_low_f32x4, b_x_low_f32x4), a_x_high_f32x4, b_x_high_f32x4);
1676
- cov_xy_f32x4 = vfmaq_f32(vfmaq_f32(cov_xy_f32x4, a_x_low_f32x4, b_y_low_f32x4), a_x_high_f32x4, b_y_high_f32x4);
1677
- cov_xz_f32x4 = vfmaq_f32(vfmaq_f32(cov_xz_f32x4, a_x_low_f32x4, b_z_low_f32x4), a_x_high_f32x4, b_z_high_f32x4);
1678
- cov_yx_f32x4 = vfmaq_f32(vfmaq_f32(cov_yx_f32x4, a_y_low_f32x4, b_x_low_f32x4), a_y_high_f32x4, b_x_high_f32x4);
1679
- cov_yy_f32x4 = vfmaq_f32(vfmaq_f32(cov_yy_f32x4, a_y_low_f32x4, b_y_low_f32x4), a_y_high_f32x4, b_y_high_f32x4);
1680
- cov_yz_f32x4 = vfmaq_f32(vfmaq_f32(cov_yz_f32x4, a_y_low_f32x4, b_z_low_f32x4), a_y_high_f32x4, b_z_high_f32x4);
1681
- cov_zx_f32x4 = vfmaq_f32(vfmaq_f32(cov_zx_f32x4, a_z_low_f32x4, b_x_low_f32x4), a_z_high_f32x4, b_x_high_f32x4);
1682
- cov_zy_f32x4 = vfmaq_f32(vfmaq_f32(cov_zy_f32x4, a_z_low_f32x4, b_y_low_f32x4), a_z_high_f32x4, b_y_high_f32x4);
1683
- cov_zz_f32x4 = vfmaq_f32(vfmaq_f32(cov_zz_f32x4, a_z_low_f32x4, b_z_low_f32x4), a_z_high_f32x4, b_z_high_f32x4);
1684
-
1685
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_x_low_f32x4, a_x_low_f32x4);
1686
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_y_low_f32x4, a_y_low_f32x4);
1687
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_z_low_f32x4, a_z_low_f32x4);
1688
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_x_high_f32x4, a_x_high_f32x4);
1689
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_y_high_f32x4, a_y_high_f32x4);
1690
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_z_high_f32x4, a_z_high_f32x4);
1538
+ covariance_xx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xx_f32x4, a_x_low_f32x4, b_x_low_f32x4), a_x_high_f32x4,
1539
+ b_x_high_f32x4);
1540
+ covariance_xy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xy_f32x4, a_x_low_f32x4, b_y_low_f32x4), a_x_high_f32x4,
1541
+ b_y_high_f32x4);
1542
+ covariance_xz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xz_f32x4, a_x_low_f32x4, b_z_low_f32x4), a_x_high_f32x4,
1543
+ b_z_high_f32x4);
1544
+ covariance_yx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yx_f32x4, a_y_low_f32x4, b_x_low_f32x4), a_y_high_f32x4,
1545
+ b_x_high_f32x4);
1546
+ covariance_yy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yy_f32x4, a_y_low_f32x4, b_y_low_f32x4), a_y_high_f32x4,
1547
+ b_y_high_f32x4);
1548
+ covariance_yz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yz_f32x4, a_y_low_f32x4, b_z_low_f32x4), a_y_high_f32x4,
1549
+ b_z_high_f32x4);
1550
+ covariance_zx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zx_f32x4, a_z_low_f32x4, b_x_low_f32x4), a_z_high_f32x4,
1551
+ b_x_high_f32x4);
1552
+ covariance_zy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zy_f32x4, a_z_low_f32x4, b_y_low_f32x4), a_z_high_f32x4,
1553
+ b_y_high_f32x4);
1554
+ covariance_zz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zz_f32x4, a_z_low_f32x4, b_z_low_f32x4), a_z_high_f32x4,
1555
+ b_z_high_f32x4);
1556
+
1557
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_low_f32x4, a_x_low_f32x4);
1558
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_low_f32x4, a_y_low_f32x4);
1559
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_low_f32x4, a_z_low_f32x4);
1560
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_high_f32x4, a_x_high_f32x4);
1561
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_high_f32x4, a_y_high_f32x4);
1562
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_high_f32x4, a_z_high_f32x4);
1563
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_low_f32x4, b_x_low_f32x4);
1564
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_low_f32x4, b_y_low_f32x4);
1565
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_low_f32x4, b_z_low_f32x4);
1566
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_high_f32x4, b_x_high_f32x4);
1567
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_high_f32x4, b_y_high_f32x4);
1568
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_high_f32x4, b_z_high_f32x4);
1691
1569
  }
1692
1570
 
1693
1571
  if (i < n) {
@@ -1703,22 +1581,37 @@ NK_PUBLIC void nk_umeyama_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size
1703
1581
  sum_b_y_f32x4 = vaddq_f32(vaddq_f32(sum_b_y_f32x4, b_y_low_f32x4), b_y_high_f32x4);
1704
1582
  sum_b_z_f32x4 = vaddq_f32(vaddq_f32(sum_b_z_f32x4, b_z_low_f32x4), b_z_high_f32x4);
1705
1583
 
1706
- cov_xx_f32x4 = vfmaq_f32(vfmaq_f32(cov_xx_f32x4, a_x_low_f32x4, b_x_low_f32x4), a_x_high_f32x4, b_x_high_f32x4);
1707
- cov_xy_f32x4 = vfmaq_f32(vfmaq_f32(cov_xy_f32x4, a_x_low_f32x4, b_y_low_f32x4), a_x_high_f32x4, b_y_high_f32x4);
1708
- cov_xz_f32x4 = vfmaq_f32(vfmaq_f32(cov_xz_f32x4, a_x_low_f32x4, b_z_low_f32x4), a_x_high_f32x4, b_z_high_f32x4);
1709
- cov_yx_f32x4 = vfmaq_f32(vfmaq_f32(cov_yx_f32x4, a_y_low_f32x4, b_x_low_f32x4), a_y_high_f32x4, b_x_high_f32x4);
1710
- cov_yy_f32x4 = vfmaq_f32(vfmaq_f32(cov_yy_f32x4, a_y_low_f32x4, b_y_low_f32x4), a_y_high_f32x4, b_y_high_f32x4);
1711
- cov_yz_f32x4 = vfmaq_f32(vfmaq_f32(cov_yz_f32x4, a_y_low_f32x4, b_z_low_f32x4), a_y_high_f32x4, b_z_high_f32x4);
1712
- cov_zx_f32x4 = vfmaq_f32(vfmaq_f32(cov_zx_f32x4, a_z_low_f32x4, b_x_low_f32x4), a_z_high_f32x4, b_x_high_f32x4);
1713
- cov_zy_f32x4 = vfmaq_f32(vfmaq_f32(cov_zy_f32x4, a_z_low_f32x4, b_y_low_f32x4), a_z_high_f32x4, b_y_high_f32x4);
1714
- cov_zz_f32x4 = vfmaq_f32(vfmaq_f32(cov_zz_f32x4, a_z_low_f32x4, b_z_low_f32x4), a_z_high_f32x4, b_z_high_f32x4);
1715
-
1716
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_x_low_f32x4, a_x_low_f32x4);
1717
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_y_low_f32x4, a_y_low_f32x4);
1718
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_z_low_f32x4, a_z_low_f32x4);
1719
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_x_high_f32x4, a_x_high_f32x4);
1720
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_y_high_f32x4, a_y_high_f32x4);
1721
- variance_a_f32x4 = vfmaq_f32(variance_a_f32x4, a_z_high_f32x4, a_z_high_f32x4);
1584
+ covariance_xx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xx_f32x4, a_x_low_f32x4, b_x_low_f32x4), a_x_high_f32x4,
1585
+ b_x_high_f32x4);
1586
+ covariance_xy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xy_f32x4, a_x_low_f32x4, b_y_low_f32x4), a_x_high_f32x4,
1587
+ b_y_high_f32x4);
1588
+ covariance_xz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_xz_f32x4, a_x_low_f32x4, b_z_low_f32x4), a_x_high_f32x4,
1589
+ b_z_high_f32x4);
1590
+ covariance_yx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yx_f32x4, a_y_low_f32x4, b_x_low_f32x4), a_y_high_f32x4,
1591
+ b_x_high_f32x4);
1592
+ covariance_yy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yy_f32x4, a_y_low_f32x4, b_y_low_f32x4), a_y_high_f32x4,
1593
+ b_y_high_f32x4);
1594
+ covariance_yz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_yz_f32x4, a_y_low_f32x4, b_z_low_f32x4), a_y_high_f32x4,
1595
+ b_z_high_f32x4);
1596
+ covariance_zx_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zx_f32x4, a_z_low_f32x4, b_x_low_f32x4), a_z_high_f32x4,
1597
+ b_x_high_f32x4);
1598
+ covariance_zy_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zy_f32x4, a_z_low_f32x4, b_y_low_f32x4), a_z_high_f32x4,
1599
+ b_y_high_f32x4);
1600
+ covariance_zz_f32x4 = vfmaq_f32(vfmaq_f32(covariance_zz_f32x4, a_z_low_f32x4, b_z_low_f32x4), a_z_high_f32x4,
1601
+ b_z_high_f32x4);
1602
+
1603
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_low_f32x4, a_x_low_f32x4);
1604
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_low_f32x4, a_y_low_f32x4);
1605
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_low_f32x4, a_z_low_f32x4);
1606
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_x_high_f32x4, a_x_high_f32x4);
1607
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_y_high_f32x4, a_y_high_f32x4);
1608
+ norm_squared_a_f32x4 = vfmaq_f32(norm_squared_a_f32x4, a_z_high_f32x4, a_z_high_f32x4);
1609
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_low_f32x4, b_x_low_f32x4);
1610
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_low_f32x4, b_y_low_f32x4);
1611
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_low_f32x4, b_z_low_f32x4);
1612
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_x_high_f32x4, b_x_high_f32x4);
1613
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_y_high_f32x4, b_y_high_f32x4);
1614
+ norm_squared_b_f32x4 = vfmaq_f32(norm_squared_b_f32x4, b_z_high_f32x4, b_z_high_f32x4);
1722
1615
  }
1723
1616
 
1724
1617
  // Reduce vector accumulators
@@ -1728,16 +1621,17 @@ NK_PUBLIC void nk_umeyama_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size
1728
1621
  nk_f32_t sum_b_x = vaddvq_f32(sum_b_x_f32x4);
1729
1622
  nk_f32_t sum_b_y = vaddvq_f32(sum_b_y_f32x4);
1730
1623
  nk_f32_t sum_b_z = vaddvq_f32(sum_b_z_f32x4);
1731
- nk_f32_t covariance_x_x = vaddvq_f32(cov_xx_f32x4);
1732
- nk_f32_t covariance_x_y = vaddvq_f32(cov_xy_f32x4);
1733
- nk_f32_t covariance_x_z = vaddvq_f32(cov_xz_f32x4);
1734
- nk_f32_t covariance_y_x = vaddvq_f32(cov_yx_f32x4);
1735
- nk_f32_t covariance_y_y = vaddvq_f32(cov_yy_f32x4);
1736
- nk_f32_t covariance_y_z = vaddvq_f32(cov_yz_f32x4);
1737
- nk_f32_t covariance_z_x = vaddvq_f32(cov_zx_f32x4);
1738
- nk_f32_t covariance_z_y = vaddvq_f32(cov_zy_f32x4);
1739
- nk_f32_t covariance_z_z = vaddvq_f32(cov_zz_f32x4);
1740
- nk_f32_t variance_a_sum = vaddvq_f32(variance_a_f32x4);
1624
+ nk_f32_t covariance_x_x = vaddvq_f32(covariance_xx_f32x4);
1625
+ nk_f32_t covariance_x_y = vaddvq_f32(covariance_xy_f32x4);
1626
+ nk_f32_t covariance_x_z = vaddvq_f32(covariance_xz_f32x4);
1627
+ nk_f32_t covariance_y_x = vaddvq_f32(covariance_yx_f32x4);
1628
+ nk_f32_t covariance_y_y = vaddvq_f32(covariance_yy_f32x4);
1629
+ nk_f32_t covariance_y_z = vaddvq_f32(covariance_yz_f32x4);
1630
+ nk_f32_t covariance_z_x = vaddvq_f32(covariance_zx_f32x4);
1631
+ nk_f32_t covariance_z_y = vaddvq_f32(covariance_zy_f32x4);
1632
+ nk_f32_t covariance_z_z = vaddvq_f32(covariance_zz_f32x4);
1633
+ nk_f32_t norm_squared_a = vaddvq_f32(norm_squared_a_f32x4);
1634
+ nk_f32_t norm_squared_b = vaddvq_f32(norm_squared_b_f32x4);
1741
1635
 
1742
1636
  // Compute centroids
1743
1637
  nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
@@ -1747,63 +1641,97 @@ NK_PUBLIC void nk_umeyama_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size
1747
1641
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1748
1642
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1749
1643
 
1750
- // Compute centered covariance and variance
1751
- nk_f32_t variance_a = variance_a_sum * inv_n -
1752
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
1753
-
1754
- nk_f32_t h[9];
1755
- h[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
1756
- h[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
1757
- h[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
1758
- h[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
1759
- h[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
1760
- h[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
1761
- h[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
1762
- h[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
1763
- h[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
1764
-
1765
- // SVD of H = U * S * Vᵀ
1766
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
1767
- nk_svd3x3_f32_(h, svd_u, svd_s, svd_v);
1768
-
1769
- // R = V * Uᵀ
1770
- nk_f32_t r[9];
1771
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1772
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1773
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
1774
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
1775
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
1776
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
1777
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
1778
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
1779
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1780
-
1781
- // Handle reflection and compute scale: c = trace(D × S) / variance(a)
1782
- nk_f32_t det_r = nk_det3x3_f32_(r);
1783
- nk_f32_t sign_det = det_r < 0 ? -1.0f : 1.0f;
1784
- nk_f32_t trace_scaled_s = svd_s[0] + svd_s[4] + sign_det * svd_s[8];
1785
- nk_f32_t scale_factor = trace_scaled_s / ((nk_f32_t)n * variance_a);
1786
- if (scale) *scale = scale_factor;
1787
-
1788
- if (det_r < 0) {
1789
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1790
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1791
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1792
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
1793
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
1794
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
1795
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
1796
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
1797
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
1798
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1644
+ // Centered norm-squared via parallel-axis identity; clamp at zero for numeric safety.
1645
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
1646
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
1647
+ centroid_a_z * centroid_a_z);
1648
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
1649
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
1650
+ centroid_b_z * centroid_b_z);
1651
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
1652
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
1653
+
1654
+ nk_f32_t cross_covariance[9];
1655
+ cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
1656
+ cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
1657
+ cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
1658
+ cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
1659
+ cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
1660
+ cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
1661
+ cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
1662
+ cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
1663
+ cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
1664
+
1665
+ // Identity-dominant short-circuit: if H ≈ diag(positive entries), R = I and trace(R·H) = trace(H).
1666
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
1667
+ cross_covariance[4] * cross_covariance[4] +
1668
+ cross_covariance[8] * cross_covariance[8];
1669
+ nk_f32_t covariance_offdiagonal_norm_squared =
1670
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
1671
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
1672
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
1673
+ nk_f32_t optimal_rotation[9];
1674
+ nk_f32_t trace_rotation_covariance;
1675
+ nk_f32_t scale_factor;
1676
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
1677
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
1678
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
1679
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
1680
+ optimal_rotation[8] = 1;
1681
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
1682
+ scale_factor = centered_norm_squared_a > 0.0f ? trace_rotation_covariance / centered_norm_squared_a : 0.0f;
1799
1683
  }
1684
+ else {
1685
+ // SVD of H = U * S * Vᵀ
1686
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
1687
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
1688
+
1689
+ // R = V * Uᵀ
1690
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
1691
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
1692
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
1693
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
1694
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
1695
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
1696
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
1697
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
1698
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
1699
+
1700
+ // Handle reflection and compute scale: c = trace(D · S) / ‖a-ā‖²
1701
+ nk_f32_t rotation_determinant = nk_det3x3_f32_(optimal_rotation);
1702
+ nk_f32_t sign_det = rotation_determinant < 0 ? -1.0f : 1.0f;
1703
+ nk_f32_t trace_scaled_s = svd_diagonal[0] + svd_diagonal[4] + sign_det * svd_diagonal[8];
1704
+ scale_factor = centered_norm_squared_a > 0.0f ? trace_scaled_s / centered_norm_squared_a : 0.0f;
1705
+
1706
+ if (rotation_determinant < 0) {
1707
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
1708
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
1709
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
1710
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
1711
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
1712
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
1713
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
1714
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
1715
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
1716
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
1717
+ }
1718
+
1719
+ trace_rotation_covariance =
1720
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
1721
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
1722
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
1723
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
1724
+ optimal_rotation[8] * cross_covariance[8];
1725
+ }
1726
+ if (scale) *scale = scale_factor;
1800
1727
 
1801
1728
  if (rotation)
1802
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
1729
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
1803
1730
 
1804
- // Compute RMSD after similarity transform
1805
- nk_f32_t sum_squared = nk_transformed_ssd_f16_neon_(a, b, n, r, scale_factor, centroid_a_x, centroid_a_y,
1806
- centroid_a_z, centroid_b_x, centroid_b_y, centroid_b_z);
1731
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
1732
+ nk_f32_t sum_squared = scale_factor * scale_factor * centered_norm_squared_a + centered_norm_squared_b -
1733
+ 2.0f * scale_factor * trace_rotation_covariance;
1734
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
1807
1735
  *result = nk_f32_sqrt_neon(sum_squared * inv_n);
1808
1736
  }
1809
1737
  #if defined(__clang__)