numkong 7.4.5 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/README.md +1 -0
  2. package/binding.gyp +99 -5
  3. package/c/dispatch_e5m2.c +23 -3
  4. package/c/dispatch_f16.c +23 -0
  5. package/c/numkong.c +0 -13
  6. package/include/numkong/attention/sme.h +34 -31
  7. package/include/numkong/capabilities.h +2 -15
  8. package/include/numkong/cast/README.md +3 -0
  9. package/include/numkong/cast/haswell.h +28 -64
  10. package/include/numkong/cast/neon.h +15 -0
  11. package/include/numkong/cast/serial.h +17 -0
  12. package/include/numkong/cast/skylake.h +67 -52
  13. package/include/numkong/cast.h +1 -0
  14. package/include/numkong/curved/smef64.h +82 -62
  15. package/include/numkong/dot/README.md +1 -0
  16. package/include/numkong/dot/haswell.h +92 -13
  17. package/include/numkong/dot/rvvbf16.h +1 -1
  18. package/include/numkong/dot/rvvhalf.h +1 -1
  19. package/include/numkong/dot/serial.h +15 -0
  20. package/include/numkong/dot/skylake.h +61 -14
  21. package/include/numkong/dot/sve.h +6 -5
  22. package/include/numkong/dot/svebfdot.h +2 -1
  23. package/include/numkong/dot/svehalf.h +6 -5
  24. package/include/numkong/dot/svesdot.h +3 -2
  25. package/include/numkong/dots/README.md +2 -0
  26. package/include/numkong/dots/graniteamx.h +1167 -0
  27. package/include/numkong/dots/haswell.h +28 -28
  28. package/include/numkong/dots/sapphireamx.h +1 -1
  29. package/include/numkong/dots/serial.h +33 -11
  30. package/include/numkong/dots/skylake.h +28 -23
  31. package/include/numkong/dots/sme.h +172 -140
  32. package/include/numkong/dots/smebi32.h +14 -11
  33. package/include/numkong/dots/smef64.h +31 -26
  34. package/include/numkong/dots.h +41 -3
  35. package/include/numkong/each/serial.h +39 -0
  36. package/include/numkong/geospatial/haswell.h +1 -1
  37. package/include/numkong/geospatial/neon.h +1 -1
  38. package/include/numkong/geospatial/serial.h +15 -4
  39. package/include/numkong/geospatial/skylake.h +1 -1
  40. package/include/numkong/maxsim/serial.h +15 -0
  41. package/include/numkong/maxsim/sme.h +34 -33
  42. package/include/numkong/mesh/README.md +50 -44
  43. package/include/numkong/mesh/genoa.h +462 -0
  44. package/include/numkong/mesh/haswell.h +806 -933
  45. package/include/numkong/mesh/neon.h +871 -943
  46. package/include/numkong/mesh/neonbfdot.h +382 -522
  47. package/include/numkong/mesh/neonfhm.h +676 -0
  48. package/include/numkong/mesh/rvv.h +404 -319
  49. package/include/numkong/mesh/serial.h +225 -161
  50. package/include/numkong/mesh/skylake.h +1029 -1585
  51. package/include/numkong/mesh/v128relaxed.h +403 -377
  52. package/include/numkong/mesh.h +38 -0
  53. package/include/numkong/reduce/neon.h +29 -0
  54. package/include/numkong/reduce/neonbfdot.h +2 -2
  55. package/include/numkong/reduce/neonfhm.h +4 -4
  56. package/include/numkong/reduce/serial.h +15 -1
  57. package/include/numkong/reduce/sve.h +52 -0
  58. package/include/numkong/reduce.h +4 -0
  59. package/include/numkong/set/sve.h +6 -5
  60. package/include/numkong/sets/smebi32.h +35 -30
  61. package/include/numkong/sparse/serial.h +17 -2
  62. package/include/numkong/sparse/sve2.h +3 -2
  63. package/include/numkong/spatial/genoa.h +0 -68
  64. package/include/numkong/spatial/haswell.h +98 -56
  65. package/include/numkong/spatial/serial.h +15 -0
  66. package/include/numkong/spatial/skylake.h +114 -54
  67. package/include/numkong/spatial/sve.h +7 -6
  68. package/include/numkong/spatial/svebfdot.h +7 -4
  69. package/include/numkong/spatial/svehalf.h +5 -4
  70. package/include/numkong/spatial/svesdot.h +9 -8
  71. package/include/numkong/spatial.h +0 -12
  72. package/include/numkong/spatials/graniteamx.h +301 -0
  73. package/include/numkong/spatials/serial.h +39 -0
  74. package/include/numkong/spatials/skylake.h +2 -2
  75. package/include/numkong/spatials/sme.h +391 -350
  76. package/include/numkong/spatials/smef64.h +79 -70
  77. package/include/numkong/spatials.h +54 -4
  78. package/include/numkong/tensor.hpp +107 -23
  79. package/include/numkong/types.h +59 -0
  80. package/javascript/dist/cjs/numkong.js +13 -0
  81. package/javascript/dist/esm/numkong.js +13 -0
  82. package/javascript/numkong.c +59 -14
  83. package/javascript/numkong.ts +13 -0
  84. package/package.json +7 -7
  85. package/probes/probe.js +2 -2
  86. package/wasm/numkong.wasm +0 -0
@@ -88,10 +88,6 @@ NK_INTERNAL nk_f64_t nk_reduce_stable_f64x4_haswell_(__m256d values_f64x4) {
  return sum + compensation;
  }

- NK_INTERNAL void nk_rotation_from_svd_f64_haswell_(nk_f64_t const *svd_u, nk_f64_t const *svd_v, nk_f64_t *rotation) {
- nk_rotation_from_svd_f64_serial_(svd_u, svd_v, rotation);
- }
-
  NK_INTERNAL void nk_accumulate_square_f64x4_haswell_(__m256d *sum_f64x4, __m256d *compensation_f64x4,
  __m256d values_f64x4) {
  __m256d product_f64x4 = _mm256_mul_pd(values_f64x4, values_f64x4);
@@ -105,208 +101,6 @@ NK_INTERNAL void nk_accumulate_square_f64x4_haswell_(__m256d *sum_f64x4, __m256d
  *compensation_f64x4 = _mm256_add_pd(*compensation_f64x4, _mm256_add_pd(sum_error_f64x4, product_error_f64x4));
  }

- /* Compute sum of squared distances after applying rotation (and optional scale).
- * Used by kabsch (scale=1.0) and umeyama (scale=computed_scale).
- * Returns sum_squared, caller computes sqrt(sum_squared / n).
- */
- NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_haswell_(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
- nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
- nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
- nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
- nk_f64_t centroid_b_z) {
- __m256d scaled_rotation_x_x_f64x4 = _mm256_set1_pd(scale * r[0]);
- __m256d scaled_rotation_x_y_f64x4 = _mm256_set1_pd(scale * r[1]);
- __m256d scaled_rotation_x_z_f64x4 = _mm256_set1_pd(scale * r[2]);
- __m256d scaled_rotation_y_x_f64x4 = _mm256_set1_pd(scale * r[3]);
- __m256d scaled_rotation_y_y_f64x4 = _mm256_set1_pd(scale * r[4]);
- __m256d scaled_rotation_y_z_f64x4 = _mm256_set1_pd(scale * r[5]);
- __m256d scaled_rotation_z_x_f64x4 = _mm256_set1_pd(scale * r[6]);
- __m256d scaled_rotation_z_y_f64x4 = _mm256_set1_pd(scale * r[7]);
- __m256d scaled_rotation_z_z_f64x4 = _mm256_set1_pd(scale * r[8]);
- __m256d centroid_a_x_f64x4 = _mm256_set1_pd(centroid_a_x), centroid_a_y_f64x4 = _mm256_set1_pd(centroid_a_y);
- __m256d centroid_a_z_f64x4 = _mm256_set1_pd(centroid_a_z), centroid_b_x_f64x4 = _mm256_set1_pd(centroid_b_x);
- __m256d centroid_b_y_f64x4 = _mm256_set1_pd(centroid_b_y), centroid_b_z_f64x4 = _mm256_set1_pd(centroid_b_z);
- __m256d sum_squared_f64x4 = _mm256_setzero_pd();
- __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
- nk_size_t index = 0;
-
- for (; index + 8 <= n; index += 8) {
- nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
- nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
-
- __m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
- __m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
- __m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
- __m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
- __m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
- __m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
- __m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
- __m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
- __m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
- __m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
- __m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
- __m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
-
- __m256d centered_a_x_low_f64x4 = _mm256_sub_pd(a_x_low_f64x4, centroid_a_x_f64x4);
- __m256d centered_a_x_high_f64x4 = _mm256_sub_pd(a_x_high_f64x4, centroid_a_x_f64x4);
- __m256d centered_a_y_low_f64x4 = _mm256_sub_pd(a_y_low_f64x4, centroid_a_y_f64x4);
- __m256d centered_a_y_high_f64x4 = _mm256_sub_pd(a_y_high_f64x4, centroid_a_y_f64x4);
- __m256d centered_a_z_low_f64x4 = _mm256_sub_pd(a_z_low_f64x4, centroid_a_z_f64x4);
- __m256d centered_a_z_high_f64x4 = _mm256_sub_pd(a_z_high_f64x4, centroid_a_z_f64x4);
- __m256d centered_b_x_low_f64x4 = _mm256_sub_pd(b_x_low_f64x4, centroid_b_x_f64x4);
- __m256d centered_b_x_high_f64x4 = _mm256_sub_pd(b_x_high_f64x4, centroid_b_x_f64x4);
- __m256d centered_b_y_low_f64x4 = _mm256_sub_pd(b_y_low_f64x4, centroid_b_y_f64x4);
- __m256d centered_b_y_high_f64x4 = _mm256_sub_pd(b_y_high_f64x4, centroid_b_y_f64x4);
- __m256d centered_b_z_low_f64x4 = _mm256_sub_pd(b_z_low_f64x4, centroid_b_z_f64x4);
- __m256d centered_b_z_high_f64x4 = _mm256_sub_pd(b_z_high_f64x4, centroid_b_z_f64x4);
-
- __m256d rotated_a_x_low_f64x4 = _mm256_fmadd_pd(
- scaled_rotation_x_z_f64x4, centered_a_z_low_f64x4,
- _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_low_f64x4,
- _mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_low_f64x4)));
- __m256d rotated_a_x_high_f64x4 = _mm256_fmadd_pd(
- scaled_rotation_x_z_f64x4, centered_a_z_high_f64x4,
- _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_high_f64x4,
- _mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_high_f64x4)));
- __m256d rotated_a_y_low_f64x4 = _mm256_fmadd_pd(
- scaled_rotation_y_z_f64x4, centered_a_z_low_f64x4,
- _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_low_f64x4,
- _mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_low_f64x4)));
- __m256d rotated_a_y_high_f64x4 = _mm256_fmadd_pd(
- scaled_rotation_y_z_f64x4, centered_a_z_high_f64x4,
- _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_high_f64x4,
- _mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_high_f64x4)));
- __m256d rotated_a_z_low_f64x4 = _mm256_fmadd_pd(
- scaled_rotation_z_z_f64x4, centered_a_z_low_f64x4,
- _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_low_f64x4,
- _mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_low_f64x4)));
- __m256d rotated_a_z_high_f64x4 = _mm256_fmadd_pd(
- scaled_rotation_z_z_f64x4, centered_a_z_high_f64x4,
- _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_high_f64x4,
- _mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_high_f64x4)));
-
- __m256d delta_x_low_f64x4 = _mm256_sub_pd(rotated_a_x_low_f64x4, centered_b_x_low_f64x4);
- __m256d delta_x_high_f64x4 = _mm256_sub_pd(rotated_a_x_high_f64x4, centered_b_x_high_f64x4);
- __m256d delta_y_low_f64x4 = _mm256_sub_pd(rotated_a_y_low_f64x4, centered_b_y_low_f64x4);
- __m256d delta_y_high_f64x4 = _mm256_sub_pd(rotated_a_y_high_f64x4, centered_b_y_high_f64x4);
- __m256d delta_z_low_f64x4 = _mm256_sub_pd(rotated_a_z_low_f64x4, centered_b_z_low_f64x4);
- __m256d delta_z_high_f64x4 = _mm256_sub_pd(rotated_a_z_high_f64x4, centered_b_z_high_f64x4);
-
- __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_low_f64x4, delta_x_low_f64x4),
- _mm256_mul_pd(delta_x_high_f64x4, delta_x_high_f64x4));
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_low_f64x4, delta_y_low_f64x4, batch_sum_squared_f64x4);
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_high_f64x4, delta_y_high_f64x4, batch_sum_squared_f64x4);
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_low_f64x4, delta_z_low_f64x4, batch_sum_squared_f64x4);
- batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_high_f64x4, delta_z_high_f64x4, batch_sum_squared_f64x4);
- sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
- }
-
- nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
- for (; index < n; ++index) {
- nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
- centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
- centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
- nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
- centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
- centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
- nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
- rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
- rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
- nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
- delta_z = rotated_a_z - centered_b_z;
- sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
- }
-
- return sum_squared;
- }
-
- /* Compute sum of squared distances for f64 after applying rotation (and optional scale).
- * Rotation matrix, scale and data are all f64 for full precision.
- */
- NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_haswell_(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
- nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
- nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
- nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
- nk_f64_t centroid_b_z) {
- // Broadcast scaled rotation matrix elements
- __m256d scaled_rotation_x_x_f64x4 = _mm256_set1_pd(scale * r[0]);
- __m256d scaled_rotation_x_y_f64x4 = _mm256_set1_pd(scale * r[1]);
- __m256d scaled_rotation_x_z_f64x4 = _mm256_set1_pd(scale * r[2]);
- __m256d scaled_rotation_y_x_f64x4 = _mm256_set1_pd(scale * r[3]);
- __m256d scaled_rotation_y_y_f64x4 = _mm256_set1_pd(scale * r[4]);
- __m256d scaled_rotation_y_z_f64x4 = _mm256_set1_pd(scale * r[5]);
- __m256d scaled_rotation_z_x_f64x4 = _mm256_set1_pd(scale * r[6]);
- __m256d scaled_rotation_z_y_f64x4 = _mm256_set1_pd(scale * r[7]);
- __m256d scaled_rotation_z_z_f64x4 = _mm256_set1_pd(scale * r[8]);
-
- // Broadcast centroids
- __m256d centroid_a_x_f64x4 = _mm256_set1_pd(centroid_a_x);
- __m256d centroid_a_y_f64x4 = _mm256_set1_pd(centroid_a_y);
- __m256d centroid_a_z_f64x4 = _mm256_set1_pd(centroid_a_z);
- __m256d centroid_b_x_f64x4 = _mm256_set1_pd(centroid_b_x);
- __m256d centroid_b_y_f64x4 = _mm256_set1_pd(centroid_b_y);
- __m256d centroid_b_z_f64x4 = _mm256_set1_pd(centroid_b_z);
-
- __m256d sum_squared_f64x4 = _mm256_setzero_pd();
- __m256d sum_squared_compensation_f64x4 = _mm256_setzero_pd();
- __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
- nk_size_t j = 0;
-
- for (; j + 4 <= n; j += 4) {
- nk_deinterleave_f64x4_haswell_(a + j * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
- nk_deinterleave_f64x4_haswell_(b + j * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
-
- // Center points
- __m256d pa_x_f64x4 = _mm256_sub_pd(a_x_f64x4, centroid_a_x_f64x4);
- __m256d pa_y_f64x4 = _mm256_sub_pd(a_y_f64x4, centroid_a_y_f64x4);
- __m256d pa_z_f64x4 = _mm256_sub_pd(a_z_f64x4, centroid_a_z_f64x4);
- __m256d pb_x_f64x4 = _mm256_sub_pd(b_x_f64x4, centroid_b_x_f64x4);
- __m256d pb_y_f64x4 = _mm256_sub_pd(b_y_f64x4, centroid_b_y_f64x4);
- __m256d pb_z_f64x4 = _mm256_sub_pd(b_z_f64x4, centroid_b_z_f64x4);
-
- // Rotate and scale: ra = scale * R * pa
- __m256d ra_x_f64x4 = _mm256_fmadd_pd(scaled_rotation_x_z_f64x4, pa_z_f64x4,
- _mm256_fmadd_pd(scaled_rotation_x_y_f64x4, pa_y_f64x4,
- _mm256_mul_pd(scaled_rotation_x_x_f64x4, pa_x_f64x4)));
- __m256d ra_y_f64x4 = _mm256_fmadd_pd(scaled_rotation_y_z_f64x4, pa_z_f64x4,
- _mm256_fmadd_pd(scaled_rotation_y_y_f64x4, pa_y_f64x4,
- _mm256_mul_pd(scaled_rotation_y_x_f64x4, pa_x_f64x4)));
- __m256d ra_z_f64x4 = _mm256_fmadd_pd(scaled_rotation_z_z_f64x4, pa_z_f64x4,
- _mm256_fmadd_pd(scaled_rotation_z_y_f64x4, pa_y_f64x4,
- _mm256_mul_pd(scaled_rotation_z_x_f64x4, pa_x_f64x4)));
-
- // Delta and accumulate
- __m256d delta_x_f64x4 = _mm256_sub_pd(ra_x_f64x4, pb_x_f64x4);
- __m256d delta_y_f64x4 = _mm256_sub_pd(ra_y_f64x4, pb_y_f64x4);
- __m256d delta_z_f64x4 = _mm256_sub_pd(ra_z_f64x4, pb_z_f64x4);
-
- nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_x_f64x4);
- nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_y_f64x4);
- nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_z_f64x4);
- }
-
- nk_f64_t sum_squared = nk_dot_stable_sum_f64x4_haswell_(sum_squared_f64x4, sum_squared_compensation_f64x4);
- nk_f64_t sum_squared_compensation = 0.0;
-
- // Scalar tail
- for (; j < n; ++j) {
- nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
- pa_z = a[j * 3 + 2] - centroid_a_z;
- nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
- pb_z = b[j * 3 + 2] - centroid_b_z;
- nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
- ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
- ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
-
- nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
- nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
- nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
- nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
- }
-
- return sum_squared + sum_squared_compensation;
- }
-
  NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
  if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
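Both `nk_transformed_ssd_*` helpers could be deleted because 7.6.0 folds the residual into quantities the main loop already accumulates: per the new comments in this file, for centered point sets the sum of squared distances satisfies SSD = c²·Σ‖aᵢ‖² + Σ‖bᵢ‖² − 2c·trace(R·H), where H = Σᵢ aᵢbᵢᵀ is the cross-covariance. A minimal standalone sketch of that identity in scalar C (illustrative names, not the package's API):

```c
#include <stddef.h>

/* Sketch of the trace-identity folding that replaces nk_transformed_ssd_*.
 * For centered 3D point sets a, b (row-major, 3 doubles per point), rotation
 * r (row-major 3x3) and scale c, the explicit second pass equals the folded
 * one-pass form c*c*||a||^2 + ||b||^2 - 2*c*trace(R*H). Illustrative only. */
static double ssd_two_pass(double const *a, double const *b, size_t n, double const r[9], double c) {
    double ssd = 0.0;
    for (size_t i = 0; i < n; ++i) {
        double const *p = a + 3 * i, *q = b + 3 * i;
        double dx = c * (r[0] * p[0] + r[1] * p[1] + r[2] * p[2]) - q[0];
        double dy = c * (r[3] * p[0] + r[4] * p[1] + r[5] * p[2]) - q[1];
        double dz = c * (r[6] * p[0] + r[7] * p[1] + r[8] * p[2]) - q[2];
        ssd += dx * dx + dy * dy + dz * dz;
    }
    return ssd;
}

static double ssd_folded(double const *a, double const *b, size_t n, double const r[9], double c) {
    double norm_a = 0.0, norm_b = 0.0, h[9] = {0};
    for (size_t i = 0; i < n; ++i) {
        double const *p = a + 3 * i, *q = b + 3 * i;
        norm_a += p[0] * p[0] + p[1] * p[1] + p[2] * p[2];
        norm_b += q[0] * q[0] + q[1] * q[1] + q[2] * q[2];
        for (int row = 0; row < 3; ++row)
            for (int col = 0; col < 3; ++col) h[3 * row + col] += p[row] * q[col];
    }
    /* trace(R*H): row k of R dotted with column k of H. */
    double trace = r[0] * h[0] + r[1] * h[3] + r[2] * h[6] + r[3] * h[1] + r[4] * h[4] + r[5] * h[7] +
                   r[6] * h[2] + r[7] * h[5] + r[8] * h[8];
    return c * c * norm_a + norm_b - 2.0 * c * trace;
}
```

Up to rounding, `ssd_two_pass` and `ssd_folded` agree, which is why the release can drop the dedicated second pass over the points and reuse the covariance accumulators instead.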
@@ -441,6 +235,7 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
  __m256d covariance_11_f64x4 = _mm256_setzero_pd(), covariance_12_f64x4 = _mm256_setzero_pd();
  __m256d covariance_20_f64x4 = _mm256_setzero_pd(), covariance_21_f64x4 = _mm256_setzero_pd();
  __m256d covariance_22_f64x4 = _mm256_setzero_pd();
+ __m256d norm_squared_a_f64x4 = _mm256_setzero_pd(), norm_squared_b_f64x4 = _mm256_setzero_pd();
  __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
  nk_size_t index = 0;

@@ -494,6 +289,24 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
  covariance_22_f64x4 = _mm256_add_pd(
  covariance_22_f64x4,
  _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_z_high_f64x4)));
+ norm_squared_a_f64x4 = _mm256_add_pd(
+ norm_squared_a_f64x4,
+ _mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, a_x_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, a_x_high_f64x4)));
+ norm_squared_a_f64x4 = _mm256_add_pd(
+ norm_squared_a_f64x4,
+ _mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, a_y_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, a_y_high_f64x4)));
+ norm_squared_a_f64x4 = _mm256_add_pd(
+ norm_squared_a_f64x4,
+ _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, a_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, a_z_high_f64x4)));
+ norm_squared_b_f64x4 = _mm256_add_pd(
+ norm_squared_b_f64x4,
+ _mm256_add_pd(_mm256_mul_pd(b_x_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(b_x_high_f64x4, b_x_high_f64x4)));
+ norm_squared_b_f64x4 = _mm256_add_pd(
+ norm_squared_b_f64x4,
+ _mm256_add_pd(_mm256_mul_pd(b_y_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(b_y_high_f64x4, b_y_high_f64x4)));
+ norm_squared_b_f64x4 = _mm256_add_pd(
+ norm_squared_b_f64x4,
+ _mm256_add_pd(_mm256_mul_pd(b_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(b_z_high_f64x4, b_z_high_f64x4)));
  }

  nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
@@ -502,21 +315,25 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
  nk_f64_t sum_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
  nk_f64_t sum_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
  nk_f64_t sum_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
- nk_f64_t h[9] = {
+ nk_f64_t cross_covariance[9] = {
  nk_reduce_add_f64x4_haswell_(covariance_00_f64x4), nk_reduce_add_f64x4_haswell_(covariance_01_f64x4),
  nk_reduce_add_f64x4_haswell_(covariance_02_f64x4), nk_reduce_add_f64x4_haswell_(covariance_10_f64x4),
  nk_reduce_add_f64x4_haswell_(covariance_11_f64x4), nk_reduce_add_f64x4_haswell_(covariance_12_f64x4),
  nk_reduce_add_f64x4_haswell_(covariance_20_f64x4), nk_reduce_add_f64x4_haswell_(covariance_21_f64x4),
  nk_reduce_add_f64x4_haswell_(covariance_22_f64x4)};
+ nk_f64_t norm_squared_a = nk_reduce_add_f64x4_haswell_(norm_squared_a_f64x4);
+ nk_f64_t norm_squared_b = nk_reduce_add_f64x4_haswell_(norm_squared_b_f64x4);

  for (; index < n; ++index) {
  nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
  nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
  sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
  sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
- h[0] += a_x * b_x, h[1] += a_x * b_y, h[2] += a_x * b_z;
- h[3] += a_y * b_x, h[4] += a_y * b_y, h[5] += a_y * b_z;
- h[6] += a_z * b_x, h[7] += a_z * b_y, h[8] += a_z * b_z;
+ cross_covariance[0] += a_x * b_x, cross_covariance[1] += a_x * b_y, cross_covariance[2] += a_x * b_z;
+ cross_covariance[3] += a_y * b_x, cross_covariance[4] += a_y * b_y, cross_covariance[5] += a_y * b_z;
+ cross_covariance[6] += a_z * b_x, cross_covariance[7] += a_z * b_y, cross_covariance[8] += a_z * b_z;
+ norm_squared_a += a_x * a_x + a_y * a_y + a_z * a_z;
+ norm_squared_b += b_x * b_x + b_y * b_y + b_z * b_z;
  }

  nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
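Note that this tail keeps accumulating uncentered sums; the next hunk subtracts the `n * centroid` terms instead of re-reading the data. That rests on the usual shift identity, Σ(aᵢ−ā)(bᵢ−b̄) = Σaᵢbᵢ − n·ā·b̄, sketched below for a single coordinate (hypothetical helper, not numkong's API):

```c
#include <stddef.h>

/* One uncentered pass plus an O(1) correction equals a centered pass:
 * sum_i (a_i - mean_a) * (b_i - mean_b) == sum_i a_i * b_i - n * mean_a * mean_b. */
static double centered_product_sum(double const *a, double const *b, size_t n) {
    double sum_a = 0.0, sum_b = 0.0, sum_ab = 0.0;
    for (size_t i = 0; i < n; ++i) {
        sum_a += a[i];
        sum_b += b[i];
        sum_ab += a[i] * b[i];
    }
    double mean_a = sum_a / (double)n, mean_b = sum_b / (double)n;
    return sum_ab - (double)n * mean_a * mean_b;
}
```

The same identity with a == b yields the `centered_norm_squared_*` values, which is why the next hunk clamps them at zero: cancellation can leave a tiny negative result.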
@@ -529,41 +346,81 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
  b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
  b_centroid[2] = (nk_f32_t)centroid_b_z;

- h[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x, h[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y,
- h[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z, h[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x,
- h[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y, h[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z,
- h[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x, h[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y,
- h[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
-
- nk_f64_t cross_covariance[9] = {h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8]};
- nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
- if (nk_det3x3_f64_(r) < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ cross_covariance[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x,
+ cross_covariance[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y,
+ cross_covariance[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z,
+ cross_covariance[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x,
+ cross_covariance[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y,
+ cross_covariance[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z,
+ cross_covariance[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x,
+ cross_covariance[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y,
+ cross_covariance[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
+
+ nk_f64_t centered_norm_squared_a = norm_squared_a -
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f64_t centered_norm_squared_b = norm_squared_b -
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
+
+ // Identity-dominant short-circuit: R = I, trace(R * H) = H[0]+H[4]+H[8]. Skips SVD + two
+ // rotation_from_svd reconstructions when the inputs are already aligned.
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f64_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f64_t optimal_rotation[9];
+ nk_f64_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
+ optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
+ optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
+ optimal_rotation[8] = 1;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ }
+ else {
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ if (nk_det3x3_f64_(optimal_rotation) < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ }
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
  }

  if (rotation)
- for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)r[j];
- nk_f64_t sum_squared = nk_transformed_ssd_f32_haswell_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
+
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
+ nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
+ if (sum_squared < 0.0) sum_squared = 0.0;
  *result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
  }

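The short-circuit gate above compares the Frobenius mass of H's off-diagonal entries against its diagonal with a relative 1e-20 tolerance. Condensed as a standalone predicate (illustrative name, not the package's API):

```c
#include <stdbool.h>

/* If the cross-covariance H is essentially diagonal with a positive diagonal,
 * the optimal rotation is the identity and trace(R*H) collapses to
 * H[0] + H[4] + H[8], so the 3x3 SVD can be skipped. Illustrative only. */
static bool identity_dominant(double const h[9]) {
    double diagonal = h[0] * h[0] + h[4] * h[4] + h[8] * h[8];
    double off_diagonal =
        h[1] * h[1] + h[2] * h[2] + h[3] * h[3] + h[5] * h[5] + h[6] * h[6] + h[7] * h[7];
    return off_diagonal < 1e-20 * diagonal && h[0] > 0.0 && h[4] > 0.0 && h[8] > 0.0;
}
```

Because the tolerance is relative, the gate is scale-invariant, and already-aligned inputs skip both the SVD and the two rotation reconstructions.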
@@ -576,14 +433,15 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
  __m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;

  // Accumulators for covariance matrix (sum of outer products)
- __m256d cov_xx_f64x4 = zeros_f64x4, cov_xy_f64x4 = zeros_f64x4, cov_xz_f64x4 = zeros_f64x4;
- __m256d cov_yx_f64x4 = zeros_f64x4, cov_yy_f64x4 = zeros_f64x4, cov_yz_f64x4 = zeros_f64x4;
- __m256d cov_zx_f64x4 = zeros_f64x4, cov_zy_f64x4 = zeros_f64x4, cov_zz_f64x4 = zeros_f64x4;
+ __m256d covariance_xx_f64x4 = zeros_f64x4, covariance_xy_f64x4 = zeros_f64x4, covariance_xz_f64x4 = zeros_f64x4;
+ __m256d covariance_yx_f64x4 = zeros_f64x4, covariance_yy_f64x4 = zeros_f64x4, covariance_yz_f64x4 = zeros_f64x4;
+ __m256d covariance_zx_f64x4 = zeros_f64x4, covariance_zy_f64x4 = zeros_f64x4, covariance_zz_f64x4 = zeros_f64x4;
+ __m256d norm_squared_a_f64x4 = zeros_f64x4, norm_squared_b_f64x4 = zeros_f64x4;

  nk_size_t i = 0;
  __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;

- // Fused single-pass
+ // Fused single-pass (centroids + covariance + norm-squared for folded SSD)
  for (; i + 4 <= n; i += 4) {
  nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
  nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
@@ -595,15 +453,21 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
  sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
  sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);

- cov_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, cov_xx_f64x4);
- cov_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, cov_xy_f64x4);
- cov_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, cov_xz_f64x4);
- cov_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, cov_yx_f64x4);
- cov_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, cov_yy_f64x4);
- cov_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, cov_yz_f64x4);
- cov_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, cov_zx_f64x4);
- cov_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, cov_zy_f64x4);
- cov_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, cov_zz_f64x4);
+ covariance_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, covariance_xx_f64x4);
+ covariance_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, covariance_xy_f64x4);
+ covariance_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, covariance_xz_f64x4);
+ covariance_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, covariance_yx_f64x4);
+ covariance_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, covariance_yy_f64x4);
+ covariance_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, covariance_yz_f64x4);
+ covariance_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, covariance_zx_f64x4);
+ covariance_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, covariance_zy_f64x4);
+ covariance_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, covariance_zz_f64x4);
+ norm_squared_a_f64x4 = _mm256_fmadd_pd(a_x_f64x4, a_x_f64x4, norm_squared_a_f64x4);
+ norm_squared_a_f64x4 = _mm256_fmadd_pd(a_y_f64x4, a_y_f64x4, norm_squared_a_f64x4);
+ norm_squared_a_f64x4 = _mm256_fmadd_pd(a_z_f64x4, a_z_f64x4, norm_squared_a_f64x4);
+ norm_squared_b_f64x4 = _mm256_fmadd_pd(b_x_f64x4, b_x_f64x4, norm_squared_b_f64x4);
+ norm_squared_b_f64x4 = _mm256_fmadd_pd(b_y_f64x4, b_y_f64x4, norm_squared_b_f64x4);
+ norm_squared_b_f64x4 = _mm256_fmadd_pd(b_z_f64x4, b_z_f64x4, norm_squared_b_f64x4);
  }

  // Reduce vector accumulators
@@ -614,15 +478,19 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
  nk_f64_t sum_b_y = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), sum_b_y_compensation = 0.0;
  nk_f64_t sum_b_z = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), sum_b_z_compensation = 0.0;

- nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(cov_xx_f64x4), covariance_x_x_compensation = 0.0;
- nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(cov_xy_f64x4), covariance_x_y_compensation = 0.0;
- nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(cov_xz_f64x4), covariance_x_z_compensation = 0.0;
- nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(cov_yx_f64x4), covariance_y_x_compensation = 0.0;
- nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(cov_yy_f64x4), covariance_y_y_compensation = 0.0;
- nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(cov_yz_f64x4), covariance_y_z_compensation = 0.0;
- nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(cov_zx_f64x4), covariance_z_x_compensation = 0.0;
- nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(cov_zy_f64x4), covariance_z_y_compensation = 0.0;
- nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(cov_zz_f64x4), covariance_z_z_compensation = 0.0;
+ nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(covariance_xx_f64x4), covariance_x_x_compensation = 0.0;
+ nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(covariance_xy_f64x4), covariance_x_y_compensation = 0.0;
+ nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(covariance_xz_f64x4), covariance_x_z_compensation = 0.0;
+ nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(covariance_yx_f64x4), covariance_y_x_compensation = 0.0;
+ nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(covariance_yy_f64x4), covariance_y_y_compensation = 0.0;
+ nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(covariance_yz_f64x4), covariance_y_z_compensation = 0.0;
+ nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(covariance_zx_f64x4), covariance_z_x_compensation = 0.0;
+ nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(covariance_zy_f64x4), covariance_z_y_compensation = 0.0;
+ nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(covariance_zz_f64x4), covariance_z_z_compensation = 0.0;
+ nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x4_haswell_(norm_squared_a_f64x4),
+ norm_squared_a_compensation = 0.0;
+ nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x4_haswell_(norm_squared_b_f64x4),
+ norm_squared_b_compensation = 0.0;

  // Scalar tail
  for (; i < n; ++i) {
@@ -643,6 +511,12 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
  nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
  nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
  nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
  }

  sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
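The scalar tail leans on the `nk_accumulate_*` compensated helpers so that the uncentered sums stay accurate enough for the later `- n * centroid` correction. A plausible scalar reading of what such a helper does, mirroring the `product_error`/`sum_error` pattern visible in `nk_accumulate_square_f64x4_haswell_` above (assumed shape; the package's actual implementation may differ):

```c
#include <math.h>

/* Compensated square-accumulation sketch: capture the FMA residual of the
 * product and a two-sum residual of the addition, so the caller can fold the
 * compensation term back in once at the end. */
static void accumulate_square(double *sum, double *compensation, double value) {
    double product = value * value;
    double product_error = fma(value, value, -product); /* exact product residual */
    double new_sum = *sum + product;
    double sum_error = fabs(*sum) >= fabs(product) ? (*sum - new_sum) + product
                                                   : (product - new_sum) + *sum;
    *sum = new_sum;
    *compensation += product_error + sum_error;
}
```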
@@ -653,6 +527,8 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
  covariance_y_z += covariance_y_z_compensation;
  covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
  covariance_z_z += covariance_z_z_compensation;
+ norm_squared_a_sum += norm_squared_a_compensation;
+ norm_squared_b_sum += norm_squared_b_compensation;

  // Compute centroids
  nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
@@ -677,29 +553,59 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
  covariance_z_y -= (nk_f64_t)n * centroid_a_z * centroid_b_y;
  covariance_z_z -= (nk_f64_t)n * centroid_a_z * centroid_b_z;

- // Compute SVD and optimal rotation using f64 precision (svd_s is 9-element diagonal matrix)
  nk_f64_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
  covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);

- nk_f64_t r[9];
- nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
-
- // Handle reflection: if det(R) < 0, negate third column of V and recompute R
- if (nk_det3x3_f64_(r) < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
+ nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
+
+ // Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals, R = I.
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f64_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f64_t optimal_rotation[9];
+ nk_f64_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
+ optimal_rotation[0] = 1.0, optimal_rotation[1] = 0.0, optimal_rotation[2] = 0.0;
+ optimal_rotation[3] = 0.0, optimal_rotation[4] = 1.0, optimal_rotation[5] = 0.0;
+ optimal_rotation[6] = 0.0, optimal_rotation[7] = 0.0, optimal_rotation[8] = 1.0;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ }
+ else {
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+ if (nk_det3x3_f64_(optimal_rotation) < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+ }
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
  }

  // Output rotation matrix and scale=1.0
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
  if (scale) *scale = 1.0;

- // Compute RMSD after optimal rotation
- nk_f64_t sum_squared = nk_transformed_ssd_f64_haswell_(a, b, n, r, 1.0, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
+ if (sum_squared < 0.0) sum_squared = 0.0;
  *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
  }

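For reference, the `nk_rotation_from_svd_f64_serial_` plus determinant check above is the standard Kabsch reflection guard: with H = U·S·Vᵀ the candidate rotation is R = V·Uᵀ, and a negative determinant means the SVD landed on a reflection, which is repaired by negating V's third column. A compact standalone rendering of that role (illustrative, following the row-major index pattern used in this file):

```c
/* Build R = V * U^T (row-major 3x3) and repair a reflection if det(R) < 0. */
static void rotation_from_svd(double const u[9], double v[9], double r[9]) {
    for (int pass = 0; pass < 2; ++pass) {
        for (int i = 0; i < 3; ++i)
            for (int j = 0; j < 3; ++j)
                r[3 * i + j] =
                    v[3 * i + 0] * u[3 * j + 0] + v[3 * i + 1] * u[3 * j + 1] + v[3 * i + 2] * u[3 * j + 2];
        double det = r[0] * (r[4] * r[8] - r[5] * r[7]) - r[1] * (r[3] * r[8] - r[5] * r[6]) +
                     r[2] * (r[3] * r[7] - r[4] * r[6]);
        if (det >= 0) break;
        v[2] = -v[2], v[5] = -v[5], v[8] = -v[8]; /* negate third column of V */
    }
}
```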
@@ -712,7 +618,8 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
  __m256d covariance_02_f64x4 = _mm256_setzero_pd(), covariance_10_f64x4 = _mm256_setzero_pd();
  __m256d covariance_11_f64x4 = _mm256_setzero_pd(), covariance_12_f64x4 = _mm256_setzero_pd();
  __m256d covariance_20_f64x4 = _mm256_setzero_pd(), covariance_21_f64x4 = _mm256_setzero_pd();
- __m256d covariance_22_f64x4 = _mm256_setzero_pd(), variance_a_f64x4 = _mm256_setzero_pd();
+ __m256d covariance_22_f64x4 = _mm256_setzero_pd();
+ __m256d norm_squared_a_f64x4 = _mm256_setzero_pd(), norm_squared_b_f64x4 = _mm256_setzero_pd();
  __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
  nk_size_t index = 0;

@@ -765,14 +672,22 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
  covariance_22_f64x4 = _mm256_add_pd(
  covariance_22_f64x4,
  _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_z_high_f64x4)));
- variance_a_f64x4 = _mm256_add_pd(
- variance_a_f64x4,
+ norm_squared_a_f64x4 = _mm256_add_pd(
+ norm_squared_a_f64x4,
  _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, a_x_low_f64x4),
  _mm256_mul_pd(a_x_high_f64x4, a_x_high_f64x4)),
  _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, a_y_low_f64x4),
  _mm256_mul_pd(a_y_high_f64x4, a_y_high_f64x4)),
  _mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, a_z_low_f64x4),
  _mm256_mul_pd(a_z_high_f64x4, a_z_high_f64x4)))));
+ norm_squared_b_f64x4 = _mm256_add_pd(
+ norm_squared_b_f64x4,
+ _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(b_x_low_f64x4, b_x_low_f64x4),
+ _mm256_mul_pd(b_x_high_f64x4, b_x_high_f64x4)),
+ _mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(b_y_low_f64x4, b_y_low_f64x4),
+ _mm256_mul_pd(b_y_high_f64x4, b_y_high_f64x4)),
+ _mm256_add_pd(_mm256_mul_pd(b_z_low_f64x4, b_z_low_f64x4),
+ _mm256_mul_pd(b_z_high_f64x4, b_z_high_f64x4)))));
  }

  nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
@@ -781,23 +696,25 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
  nk_f64_t sum_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
  nk_f64_t sum_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
  nk_f64_t sum_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
- nk_f64_t h[9] = {
+ nk_f64_t cross_covariance[9] = {
  nk_reduce_add_f64x4_haswell_(covariance_00_f64x4), nk_reduce_add_f64x4_haswell_(covariance_01_f64x4),
  nk_reduce_add_f64x4_haswell_(covariance_02_f64x4), nk_reduce_add_f64x4_haswell_(covariance_10_f64x4),
  nk_reduce_add_f64x4_haswell_(covariance_11_f64x4), nk_reduce_add_f64x4_haswell_(covariance_12_f64x4),
  nk_reduce_add_f64x4_haswell_(covariance_20_f64x4), nk_reduce_add_f64x4_haswell_(covariance_21_f64x4),
  nk_reduce_add_f64x4_haswell_(covariance_22_f64x4)};
- nk_f64_t variance_a = nk_reduce_add_f64x4_haswell_(variance_a_f64x4);
+ nk_f64_t norm_squared_a_sum = nk_reduce_add_f64x4_haswell_(norm_squared_a_f64x4);
+ nk_f64_t norm_squared_b_sum = nk_reduce_add_f64x4_haswell_(norm_squared_b_f64x4);

  for (; index < n; ++index) {
  nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
  nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
  sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
  sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
- h[0] += a_x * b_x, h[1] += a_x * b_y, h[2] += a_x * b_z;
- h[3] += a_y * b_x, h[4] += a_y * b_y, h[5] += a_y * b_z;
- h[6] += a_z * b_x, h[7] += a_z * b_y, h[8] += a_z * b_z;
- variance_a += a_x * a_x + a_y * a_y + a_z * a_z;
+ cross_covariance[0] += a_x * b_x, cross_covariance[1] += a_x * b_y, cross_covariance[2] += a_x * b_z;
+ cross_covariance[3] += a_y * b_x, cross_covariance[4] += a_y * b_y, cross_covariance[5] += a_y * b_z;
+ cross_covariance[6] += a_z * b_x, cross_covariance[7] += a_z * b_y, cross_covariance[8] += a_z * b_z;
+ norm_squared_a_sum += a_x * a_x + a_y * a_y + a_z * a_z;
+ norm_squared_b_sum += b_x * b_x + b_y * b_y + b_z * b_z;
  }

  nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
@@ -810,49 +727,89 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
810
727
  b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
811
728
  b_centroid[2] = (nk_f32_t)centroid_b_z;
812
729
 
813
- variance_a = variance_a * inv_n -
814
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
815
- h[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x, h[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y,
816
- h[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z, h[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x,
817
- h[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y, h[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z,
818
- h[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x, h[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y,
819
- h[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
820
-
821
- nk_f64_t cross_covariance[9] = {h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8]};
822
- nk_f64_t svd_u[9], svd_s[9], svd_v[9], r[9];
823
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
824
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
825
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
826
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
827
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
828
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
829
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
830
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
831
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
832
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
833
-
834
- nk_f64_t det = nk_det3x3_f64_(r), sign_correction = det < 0 ? -1.0 : 1.0;
835
- if (det < 0) {
836
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
837
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
838
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
839
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
840
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
841
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
842
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
843
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
844
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
845
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
730
+ nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
731
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
732
+ centroid_a_z * centroid_a_z);
733
+ nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
734
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
735
+ centroid_b_z * centroid_b_z);
736
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
737
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
738
+ cross_covariance[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x,
739
+ cross_covariance[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y,
740
+ cross_covariance[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z,
741
+ cross_covariance[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x,
742
+ cross_covariance[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y,
743
+ cross_covariance[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z,
744
+ cross_covariance[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x,
745
+ cross_covariance[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y,
746
+ cross_covariance[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
747
+
748
+ // Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals,
749
+ // R = I and trace(DS) reduces to trace(H) directly.
750
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
751
+ cross_covariance[4] * cross_covariance[4] +
752
+ cross_covariance[8] * cross_covariance[8];
753
+ nk_f64_t covariance_offdiagonal_norm_squared =
754
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
755
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
756
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
757
+ nk_f64_t optimal_rotation[9];
758
+ nk_f64_t applied_scale;
759
+ nk_f64_t trace_rotation_covariance;
760
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
761
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
762
+ optimal_rotation[0] = 1.0, optimal_rotation[1] = 0.0, optimal_rotation[2] = 0.0;
763
+ optimal_rotation[3] = 0.0, optimal_rotation[4] = 1.0, optimal_rotation[5] = 0.0;
+ optimal_rotation[6] = 0.0, optimal_rotation[7] = 0.0, optimal_rotation[8] = 1.0;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ applied_scale = centered_norm_squared_a > 0.0 ? trace_rotation_covariance / centered_norm_squared_a : 0.0;
+ }
+ else {
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+
+ nk_f64_t det = nk_det3x3_f64_(optimal_rotation), sign_correction = det < 0 ? -1.0 : 1.0;
+ if (det < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ }
+ nk_f64_t trace_ds = svd_diagonal[0] + svd_diagonal[4] + sign_correction * svd_diagonal[8];
+ applied_scale = centered_norm_squared_a > 0.0 ? trace_ds / centered_norm_squared_a : 0.0;
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
  }

- nk_f64_t applied_scale = (svd_s[0] + svd_s[4] + sign_correction * svd_s[8]) / ((nk_f64_t)n * variance_a);
  if (rotation)
- for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)r[j];
+ for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
  if (scale) *scale = (nk_f32_t)applied_scale;
- *result = nk_f64_sqrt_haswell(nk_transformed_ssd_f32_haswell_(a, b, n, r, applied_scale, centroid_a_x, centroid_a_y,
- centroid_a_z, centroid_b_x, centroid_b_y,
- centroid_b_z) /
- (nk_f64_t)n);
+
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
+ nk_f64_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
+ 2.0 * applied_scale * trace_rotation_covariance;
+ if (sum_squared < 0.0) sum_squared = 0.0;
+ *result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
  }
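The folded form above deserves one derivation, since every alignment kernel in this file now leans on it. Writing ã_i = a_i − ā and b̃_i = b_i − b̄ for the centered points, and using ‖R·ã_i‖ = ‖ã_i‖ for any rotation R, the scaled sum of squared distances expands as

\[
\sum_i \lVert c\,R\,\tilde a_i - \tilde b_i \rVert^2
  = c^2 \sum_i \lVert \tilde a_i \rVert^2 + \sum_i \lVert \tilde b_i \rVert^2
  - 2c \sum_i \tilde b_i^{\top} R\,\tilde a_i
  = c^2 \lVert \tilde a \rVert^2 + \lVert \tilde b \rVert^2 - 2c\,\operatorname{tr}(R H),
\]

with H = Σ_i ã_i b̃_iᵀ the centered cross-covariance. That is exactly the three-term expression computed above, and it is what removes the second streaming pass over the points; the clamp at zero guards the case where rounding lets the terms cancel slightly past zero.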

  NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
@@ -862,10 +819,10 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s

  __m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
  __m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
- __m256d cov_xx_f64x4 = zeros_f64x4, cov_xy_f64x4 = zeros_f64x4, cov_xz_f64x4 = zeros_f64x4;
- __m256d cov_yx_f64x4 = zeros_f64x4, cov_yy_f64x4 = zeros_f64x4, cov_yz_f64x4 = zeros_f64x4;
- __m256d cov_zx_f64x4 = zeros_f64x4, cov_zy_f64x4 = zeros_f64x4, cov_zz_f64x4 = zeros_f64x4;
- __m256d variance_a_f64x4 = zeros_f64x4;
+ __m256d covariance_xx_f64x4 = zeros_f64x4, covariance_xy_f64x4 = zeros_f64x4, covariance_xz_f64x4 = zeros_f64x4;
+ __m256d covariance_yx_f64x4 = zeros_f64x4, covariance_yy_f64x4 = zeros_f64x4, covariance_yz_f64x4 = zeros_f64x4;
+ __m256d covariance_zx_f64x4 = zeros_f64x4, covariance_zy_f64x4 = zeros_f64x4, covariance_zz_f64x4 = zeros_f64x4;
+ __m256d norm_squared_a_f64x4 = zeros_f64x4, norm_squared_b_f64x4 = zeros_f64x4;

  nk_size_t i = 0;
  __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
@@ -881,18 +838,21 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
  sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
  sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);

- cov_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, cov_xx_f64x4),
- cov_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, cov_xy_f64x4);
- cov_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, cov_xz_f64x4);
- cov_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, cov_yx_f64x4),
- cov_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, cov_yy_f64x4);
- cov_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, cov_yz_f64x4);
- cov_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, cov_zx_f64x4),
- cov_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, cov_zy_f64x4);
- cov_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, cov_zz_f64x4);
- variance_a_f64x4 = _mm256_fmadd_pd(a_x_f64x4, a_x_f64x4, variance_a_f64x4);
- variance_a_f64x4 = _mm256_fmadd_pd(a_y_f64x4, a_y_f64x4, variance_a_f64x4);
- variance_a_f64x4 = _mm256_fmadd_pd(a_z_f64x4, a_z_f64x4, variance_a_f64x4);
+ covariance_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, covariance_xx_f64x4),
+ covariance_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, covariance_xy_f64x4);
+ covariance_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, covariance_xz_f64x4);
+ covariance_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, covariance_yx_f64x4),
+ covariance_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, covariance_yy_f64x4);
+ covariance_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, covariance_yz_f64x4);
+ covariance_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, covariance_zx_f64x4),
+ covariance_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, covariance_zy_f64x4);
+ covariance_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, covariance_zz_f64x4);
+ norm_squared_a_f64x4 = _mm256_fmadd_pd(a_x_f64x4, a_x_f64x4, norm_squared_a_f64x4);
+ norm_squared_a_f64x4 = _mm256_fmadd_pd(a_y_f64x4, a_y_f64x4, norm_squared_a_f64x4);
+ norm_squared_a_f64x4 = _mm256_fmadd_pd(a_z_f64x4, a_z_f64x4, norm_squared_a_f64x4);
+ norm_squared_b_f64x4 = _mm256_fmadd_pd(b_x_f64x4, b_x_f64x4, norm_squared_b_f64x4);
+ norm_squared_b_f64x4 = _mm256_fmadd_pd(b_y_f64x4, b_y_f64x4, norm_squared_b_f64x4);
+ norm_squared_b_f64x4 = _mm256_fmadd_pd(b_z_f64x4, b_z_f64x4, norm_squared_b_f64x4);
  }

  // Reduce vector accumulators
@@ -902,16 +862,19 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
  nk_f64_t sum_b_x = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), sum_b_x_compensation = 0.0;
  nk_f64_t sum_b_y = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), sum_b_y_compensation = 0.0;
  nk_f64_t sum_b_z = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), sum_b_z_compensation = 0.0;
- nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(cov_xx_f64x4), covariance_x_x_compensation = 0.0;
- nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(cov_xy_f64x4), covariance_x_y_compensation = 0.0;
- nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(cov_xz_f64x4), covariance_x_z_compensation = 0.0;
- nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(cov_yx_f64x4), covariance_y_x_compensation = 0.0;
- nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(cov_yy_f64x4), covariance_y_y_compensation = 0.0;
- nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(cov_yz_f64x4), covariance_y_z_compensation = 0.0;
- nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(cov_zx_f64x4), covariance_z_x_compensation = 0.0;
- nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(cov_zy_f64x4), covariance_z_y_compensation = 0.0;
- nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(cov_zz_f64x4), covariance_z_z_compensation = 0.0;
- nk_f64_t variance_a_sum = nk_reduce_stable_f64x4_haswell_(variance_a_f64x4), variance_a_compensation = 0.0;
+ nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(covariance_xx_f64x4), covariance_x_x_compensation = 0.0;
+ nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(covariance_xy_f64x4), covariance_x_y_compensation = 0.0;
+ nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(covariance_xz_f64x4), covariance_x_z_compensation = 0.0;
+ nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(covariance_yx_f64x4), covariance_y_x_compensation = 0.0;
+ nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(covariance_yy_f64x4), covariance_y_y_compensation = 0.0;
+ nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(covariance_yz_f64x4), covariance_y_z_compensation = 0.0;
+ nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(covariance_zx_f64x4), covariance_z_x_compensation = 0.0;
+ nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(covariance_zy_f64x4), covariance_z_y_compensation = 0.0;
+ nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(covariance_zz_f64x4), covariance_z_z_compensation = 0.0;
+ nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x4_haswell_(norm_squared_a_f64x4),
+ norm_squared_a_compensation = 0.0;
+ nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x4_haswell_(norm_squared_b_f64x4),
+ norm_squared_b_compensation = 0.0;

  // Scalar tail loop for remaining points
  for (; i < n; i++) {
@@ -932,9 +895,12 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
  nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
  nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
  nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
- nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, ax);
- nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, ay);
- nk_accumulate_square_f64_(&variance_a_sum, &variance_a_compensation, az);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
+ nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
+ nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
  }

  sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
@@ -945,7 +911,8 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
  covariance_y_z += covariance_y_z_compensation;
  covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
  covariance_z_z += covariance_z_z_compensation;
- variance_a_sum += variance_a_compensation;
+ norm_squared_a_sum += norm_squared_a_compensation;
+ norm_squared_b_sum += norm_squared_b_compensation;

  // Compute centroids
  nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
@@ -956,9 +923,15 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
  if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
  if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

- // Compute centered covariance and variance
- nk_f64_t variance_a = variance_a_sum * inv_n -
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
+ // Centered norm-squared via parallel-axis identity; clamped at zero for numeric safety.
+ nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
+ (nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
+ (nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
+ if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;

  nk_f64_t cross_covariance[9];
  cross_covariance[0] = covariance_x_x - sum_a_x * sum_b_x * inv_n;
@@ -971,34 +944,56 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
  cross_covariance[7] = covariance_z_y - sum_a_z * sum_b_y * inv_n;
  cross_covariance[8] = covariance_z_z - sum_a_z * sum_b_z * inv_n;

- // SVD using f64 for full precision (svd_s is 9-element diagonal matrix)
- nk_f64_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
-
- nk_f64_t r[9];
- nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
-
- // Scale factor: c = trace(D × S) / (n × variance(a))
- // svd_s diagonal: [0], [4], [8]
- nk_f64_t det = nk_det3x3_f64_(r);
- nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
- nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_s[0], 1.0, svd_s[4], 1.0, svd_s[8], d3);
- nk_f64_t c = trace_ds / ((nk_f64_t)n * variance_a);
- if (scale) *scale = c;
-
- // Handle reflection
- if (det < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- nk_rotation_from_svd_f64_haswell_(svd_u, svd_v, r);
+ // Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals,
+ // R = I and trace(DS) reduces to trace(H) directly.
+ nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f64_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f64_t optimal_rotation[9];
+ nk_f64_t c;
+ nk_f64_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
+ cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
+ optimal_rotation[0] = 1.0, optimal_rotation[1] = 0.0, optimal_rotation[2] = 0.0;
+ optimal_rotation[3] = 0.0, optimal_rotation[4] = 1.0, optimal_rotation[5] = 0.0;
+ optimal_rotation[6] = 0.0, optimal_rotation[7] = 0.0, optimal_rotation[8] = 1.0;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ c = centered_norm_squared_a > 0.0 ? trace_rotation_covariance / centered_norm_squared_a : 0.0;
+ }
+ else {
+ nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+
+ nk_f64_t det = nk_det3x3_f64_(optimal_rotation);
+ nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
+ nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_diagonal[0], 1.0, svd_diagonal[4], 1.0, svd_diagonal[8], d3);
+ c = centered_norm_squared_a > 0.0 ? trace_ds / centered_norm_squared_a : 0.0;
+
+ if (det < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
+ }
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
  }

- // Output rotation matrix
+ if (scale) *scale = c;
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];

- // Compute RMSD with scaling
- nk_f64_t sum_squared = nk_transformed_ssd_f64_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
+ nk_f64_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
+ 2.0 * c * trace_rotation_covariance;
+ if (sum_squared < 0.0) sum_squared = 0.0;
  *result = nk_f64_sqrt_haswell(sum_squared * inv_n);
  }
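The 1e-20 gate above is a squared-Frobenius-norm comparison between the off-diagonal and diagonal parts of H. A self-contained restatement of the predicate, for reference; the helper name nk_is_identity_dominant_ and the local typedef are illustrative, not part of the public header:

#include <stdbool.h>

typedef double nk_f64_t;

/* True when the row-major 3x3 cross-covariance h is so nearly a positive
 * diagonal matrix that the optimal rotation is the identity, in which case
 * trace(R * H) collapses to h[0] + h[4] + h[8] and the SVD can be skipped. */
static bool nk_is_identity_dominant_(nk_f64_t const h[9]) {
    nk_f64_t diagonal = h[0] * h[0] + h[4] * h[4] + h[8] * h[8];
    nk_f64_t offdiagonal = h[1] * h[1] + h[2] * h[2] + h[3] * h[3] +
                           h[5] * h[5] + h[6] * h[6] + h[7] * h[7];
    return offdiagonal < 1e-20 * diagonal && h[0] > 0.0 && h[4] > 0.0 && h[8] > 0.0;
}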
 
@@ -1046,237 +1041,34 @@ NK_INTERNAL void nk_deinterleave_bf16x8_to_f32x8_haswell_(nk_bf16_t const *ptr,
1046
1041
  *z_out = nk_bf16x8_to_f32x8_haswell_(z_vec.xmms[0]);
1047
1042
  }
1048
1043
 
1049
- /* Compute sum of squared distances for f16 data after applying rotation (and optional scale).
1050
- * Loads f16 data, converts to f32 during processing.
1051
- * Note: rotation matrix r is f32 (from SVD), scale and computation done in f32.
1052
- */
1053
- NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_haswell_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
1054
- nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
1055
- nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
1056
- nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
1057
- nk_f32_t centroid_b_z) {
1058
- // Broadcast scaled rotation matrix elements
1059
- __m256 scaled_rotation_x_x_f32x8 = _mm256_set1_ps(scale * r[0]);
1060
- __m256 scaled_rotation_x_y_f32x8 = _mm256_set1_ps(scale * r[1]);
1061
- __m256 scaled_rotation_x_z_f32x8 = _mm256_set1_ps(scale * r[2]);
1062
- __m256 scaled_rotation_y_x_f32x8 = _mm256_set1_ps(scale * r[3]);
1063
- __m256 scaled_rotation_y_y_f32x8 = _mm256_set1_ps(scale * r[4]);
1064
- __m256 scaled_rotation_y_z_f32x8 = _mm256_set1_ps(scale * r[5]);
1065
- __m256 scaled_rotation_z_x_f32x8 = _mm256_set1_ps(scale * r[6]);
1066
- __m256 scaled_rotation_z_y_f32x8 = _mm256_set1_ps(scale * r[7]);
1067
- __m256 scaled_rotation_z_z_f32x8 = _mm256_set1_ps(scale * r[8]);
1068
-
1069
- // Broadcast centroids
1070
- __m256 centroid_a_x_f32x8 = _mm256_set1_ps(centroid_a_x);
1071
- __m256 centroid_a_y_f32x8 = _mm256_set1_ps(centroid_a_y);
1072
- __m256 centroid_a_z_f32x8 = _mm256_set1_ps(centroid_a_z);
1073
- __m256 centroid_b_x_f32x8 = _mm256_set1_ps(centroid_b_x);
1074
- __m256 centroid_b_y_f32x8 = _mm256_set1_ps(centroid_b_y);
1075
- __m256 centroid_b_z_f32x8 = _mm256_set1_ps(centroid_b_z);
1076
-
1077
- __m256 sum_squared_f32x8 = _mm256_setzero_ps();
1078
- __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1079
- nk_size_t j = 0;
1080
-
1081
- for (; j + 8 <= n; j += 8) {
1082
- nk_deinterleave_f16x8_to_f32x8_haswell_(a + j * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1083
- nk_deinterleave_f16x8_to_f32x8_haswell_(b + j * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1084
-
1085
- // Center points
1086
- __m256 pa_x_f32x8 = _mm256_sub_ps(a_x_f32x8, centroid_a_x_f32x8);
1087
- __m256 pa_y_f32x8 = _mm256_sub_ps(a_y_f32x8, centroid_a_y_f32x8);
1088
- __m256 pa_z_f32x8 = _mm256_sub_ps(a_z_f32x8, centroid_a_z_f32x8);
1089
- __m256 pb_x_f32x8 = _mm256_sub_ps(b_x_f32x8, centroid_b_x_f32x8);
1090
- __m256 pb_y_f32x8 = _mm256_sub_ps(b_y_f32x8, centroid_b_y_f32x8);
1091
- __m256 pb_z_f32x8 = _mm256_sub_ps(b_z_f32x8, centroid_b_z_f32x8);
1092
-
1093
- // Rotate and scale: ra = scale * R * pa
1094
- __m256 ra_x_f32x8 = _mm256_fmadd_ps(scaled_rotation_x_z_f32x8, pa_z_f32x8,
1095
- _mm256_fmadd_ps(scaled_rotation_x_y_f32x8, pa_y_f32x8,
1096
- _mm256_mul_ps(scaled_rotation_x_x_f32x8, pa_x_f32x8)));
1097
- __m256 ra_y_f32x8 = _mm256_fmadd_ps(scaled_rotation_y_z_f32x8, pa_z_f32x8,
1098
- _mm256_fmadd_ps(scaled_rotation_y_y_f32x8, pa_y_f32x8,
1099
- _mm256_mul_ps(scaled_rotation_y_x_f32x8, pa_x_f32x8)));
1100
- __m256 ra_z_f32x8 = _mm256_fmadd_ps(scaled_rotation_z_z_f32x8, pa_z_f32x8,
1101
- _mm256_fmadd_ps(scaled_rotation_z_y_f32x8, pa_y_f32x8,
1102
- _mm256_mul_ps(scaled_rotation_z_x_f32x8, pa_x_f32x8)));
1103
-
1104
- // Delta and accumulate
1105
- __m256 delta_x_f32x8 = _mm256_sub_ps(ra_x_f32x8, pb_x_f32x8);
1106
- __m256 delta_y_f32x8 = _mm256_sub_ps(ra_y_f32x8, pb_y_f32x8);
1107
- __m256 delta_z_f32x8 = _mm256_sub_ps(ra_z_f32x8, pb_z_f32x8);
1108
-
1109
- sum_squared_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_f32x8);
1110
- sum_squared_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_f32x8);
1111
- sum_squared_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_f32x8);
1112
- }
1113
-
1114
- nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_f32x8);
1115
-
1116
- // Scalar tail
1117
- for (; j < n; ++j) {
1118
- nk_f32_t a_x_f32, a_y_f32, a_z_f32, b_x_f32, b_y_f32, b_z_f32;
1119
- nk_f16_to_f32_haswell(&a[j * 3 + 0], &a_x_f32);
1120
- nk_f16_to_f32_haswell(&a[j * 3 + 1], &a_y_f32);
1121
- nk_f16_to_f32_haswell(&a[j * 3 + 2], &a_z_f32);
1122
- nk_f16_to_f32_haswell(&b[j * 3 + 0], &b_x_f32);
1123
- nk_f16_to_f32_haswell(&b[j * 3 + 1], &b_y_f32);
1124
- nk_f16_to_f32_haswell(&b[j * 3 + 2], &b_z_f32);
1125
-
1126
- nk_f32_t pa_x = a_x_f32 - centroid_a_x, pa_y = a_y_f32 - centroid_a_y, pa_z = a_z_f32 - centroid_a_z;
1127
- nk_f32_t pb_x = b_x_f32 - centroid_b_x, pb_y = b_y_f32 - centroid_b_y, pb_z = b_z_f32 - centroid_b_z;
1128
- nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
1129
- ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
1130
- ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
1131
-
1132
- nk_f32_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
1133
- sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
1134
- }
1135
-
1136
- return sum_squared;
1137
- }
1138
-
1139
- /* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
1140
- * Loads bf16 data, converts to f32 during processing.
1141
- * Note: rotation matrix r is f32 (from SVD), scale and computation done in f32.
1142
- */
1143
- NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_haswell_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
1144
- nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
1145
- nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
1146
- nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
1147
- nk_f32_t centroid_b_z) {
1148
- // Broadcast scaled rotation matrix elements
1149
- __m256 scaled_rotation_x_x_f32x8 = _mm256_set1_ps(scale * r[0]);
1150
- __m256 scaled_rotation_x_y_f32x8 = _mm256_set1_ps(scale * r[1]);
1151
- __m256 scaled_rotation_x_z_f32x8 = _mm256_set1_ps(scale * r[2]);
1152
- __m256 scaled_rotation_y_x_f32x8 = _mm256_set1_ps(scale * r[3]);
1153
- __m256 scaled_rotation_y_y_f32x8 = _mm256_set1_ps(scale * r[4]);
1154
- __m256 scaled_rotation_y_z_f32x8 = _mm256_set1_ps(scale * r[5]);
1155
- __m256 scaled_rotation_z_x_f32x8 = _mm256_set1_ps(scale * r[6]);
1156
- __m256 scaled_rotation_z_y_f32x8 = _mm256_set1_ps(scale * r[7]);
1157
- __m256 scaled_rotation_z_z_f32x8 = _mm256_set1_ps(scale * r[8]);
1158
-
1159
- // Broadcast centroids
1160
- __m256 centroid_a_x_f32x8 = _mm256_set1_ps(centroid_a_x);
1161
- __m256 centroid_a_y_f32x8 = _mm256_set1_ps(centroid_a_y);
1162
- __m256 centroid_a_z_f32x8 = _mm256_set1_ps(centroid_a_z);
1163
- __m256 centroid_b_x_f32x8 = _mm256_set1_ps(centroid_b_x);
1164
- __m256 centroid_b_y_f32x8 = _mm256_set1_ps(centroid_b_y);
1165
- __m256 centroid_b_z_f32x8 = _mm256_set1_ps(centroid_b_z);
1166
-
1167
- __m256 sum_squared_f32x8 = _mm256_setzero_ps();
1168
- __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1169
- nk_size_t j = 0;
1170
-
1171
- for (; j + 8 <= n; j += 8) {
1172
- nk_deinterleave_bf16x8_to_f32x8_haswell_(a + j * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1173
- nk_deinterleave_bf16x8_to_f32x8_haswell_(b + j * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1174
-
1175
- // Center points
1176
- __m256 pa_x_f32x8 = _mm256_sub_ps(a_x_f32x8, centroid_a_x_f32x8);
1177
- __m256 pa_y_f32x8 = _mm256_sub_ps(a_y_f32x8, centroid_a_y_f32x8);
1178
- __m256 pa_z_f32x8 = _mm256_sub_ps(a_z_f32x8, centroid_a_z_f32x8);
1179
- __m256 pb_x_f32x8 = _mm256_sub_ps(b_x_f32x8, centroid_b_x_f32x8);
1180
- __m256 pb_y_f32x8 = _mm256_sub_ps(b_y_f32x8, centroid_b_y_f32x8);
1181
- __m256 pb_z_f32x8 = _mm256_sub_ps(b_z_f32x8, centroid_b_z_f32x8);
1182
-
1183
- // Rotate and scale: ra = scale * R * pa
1184
- __m256 ra_x_f32x8 = _mm256_fmadd_ps(scaled_rotation_x_z_f32x8, pa_z_f32x8,
1185
- _mm256_fmadd_ps(scaled_rotation_x_y_f32x8, pa_y_f32x8,
1186
- _mm256_mul_ps(scaled_rotation_x_x_f32x8, pa_x_f32x8)));
1187
- __m256 ra_y_f32x8 = _mm256_fmadd_ps(scaled_rotation_y_z_f32x8, pa_z_f32x8,
1188
- _mm256_fmadd_ps(scaled_rotation_y_y_f32x8, pa_y_f32x8,
1189
- _mm256_mul_ps(scaled_rotation_y_x_f32x8, pa_x_f32x8)));
1190
- __m256 ra_z_f32x8 = _mm256_fmadd_ps(scaled_rotation_z_z_f32x8, pa_z_f32x8,
1191
- _mm256_fmadd_ps(scaled_rotation_z_y_f32x8, pa_y_f32x8,
1192
- _mm256_mul_ps(scaled_rotation_z_x_f32x8, pa_x_f32x8)));
1193
-
1194
- // Delta and accumulate
1195
- __m256 delta_x_f32x8 = _mm256_sub_ps(ra_x_f32x8, pb_x_f32x8);
1196
- __m256 delta_y_f32x8 = _mm256_sub_ps(ra_y_f32x8, pb_y_f32x8);
1197
- __m256 delta_z_f32x8 = _mm256_sub_ps(ra_z_f32x8, pb_z_f32x8);
1198
-
1199
- sum_squared_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_f32x8);
1200
- sum_squared_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_f32x8);
1201
- sum_squared_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_f32x8);
1202
- }
1203
-
1204
- nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_f32x8);
1205
-
1206
- // Scalar tail
1207
- for (; j < n; ++j) {
1208
- nk_f32_t a_x_f32, a_y_f32, a_z_f32, b_x_f32, b_y_f32, b_z_f32;
1209
- nk_bf16_to_f32_serial(&a[j * 3 + 0], &a_x_f32);
1210
- nk_bf16_to_f32_serial(&a[j * 3 + 1], &a_y_f32);
1211
- nk_bf16_to_f32_serial(&a[j * 3 + 2], &a_z_f32);
1212
- nk_bf16_to_f32_serial(&b[j * 3 + 0], &b_x_f32);
1213
- nk_bf16_to_f32_serial(&b[j * 3 + 1], &b_y_f32);
1214
- nk_bf16_to_f32_serial(&b[j * 3 + 2], &b_z_f32);
1215
-
1216
- nk_f32_t pa_x = a_x_f32 - centroid_a_x, pa_y = a_y_f32 - centroid_a_y, pa_z = a_z_f32 - centroid_a_z;
1217
- nk_f32_t pb_x = b_x_f32 - centroid_b_x, pb_y = b_y_f32 - centroid_b_y, pb_z = b_z_f32 - centroid_b_z;
1218
- nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
1219
- ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
1220
- ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
1221
-
1222
- nk_f32_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
1223
- sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
1224
- }
1225
-
1226
- return sum_squared;
1227
- }
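Both helpers deleted above existed only to make a second streaming pass for the SSD; the folded trace form makes them redundant. A minimal standalone check that the two formulations agree on a toy input (plain f32 buffers stand in for the library's f16/bf16 types; R = I, unit scale, and origin centroids keep the arithmetic visible):

#include <math.h>
#include <stdio.h>

int main(void) {
    float a[6] = {0.f, 0.f, 0.f, 1.f, 0.f, 0.f}; /* two interleaved xyz points */
    float b[6] = {0.f, 0.f, 1.f, 1.f, 0.f, 1.f}; /* same points shifted in z */
    float explicit_ssd = 0.f, norm_a = 0.f, norm_b = 0.f, trace_rh = 0.f;
    for (int i = 0; i < 2; ++i)
        for (int d = 0; d < 3; ++d) {
            float delta = a[i * 3 + d] - b[i * 3 + d];
            explicit_ssd += delta * delta;           /* per-point second pass */
            norm_a += a[i * 3 + d] * a[i * 3 + d];   /* accumulates ||a||^2 */
            norm_b += b[i * 3 + d] * b[i * 3 + d];   /* accumulates ||b||^2 */
            trace_rh += a[i * 3 + d] * b[i * 3 + d]; /* R = I: trace(R*H) = sum of a.b */
        }
    float folded_ssd = norm_a + norm_b - 2.f * trace_rh;
    printf("explicit %.6f folded %.6f\n", explicit_ssd, folded_ssd); /* both print 2.0 */
    return fabsf(explicit_ssd - folded_ssd) > 1e-5f;
}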
1228
-
1229
1044
  NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1230
1045
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1231
- // RMSD uses identity rotation and scale=1.0
1232
1046
  if (rotation)
1233
1047
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
1234
1048
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1235
1049
  if (scale) *scale = 1.0f;
1050
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
1051
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
1236
1052
 
1237
1053
  __m256 const zeros_f32x8 = _mm256_setzero_ps();
1238
-
1239
- // Accumulators for centroids and squared differences (all in f32)
1240
- __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
1241
- __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
1242
1054
  __m256 sum_squared_x_f32x8 = zeros_f32x8, sum_squared_y_f32x8 = zeros_f32x8, sum_squared_z_f32x8 = zeros_f32x8;
1243
-
1244
1055
  __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1245
1056
  nk_size_t i = 0;
1246
1057
 
1247
- // Main loop processing 8 points at a time
1248
1058
  for (; i + 8 <= n; i += 8) {
1249
1059
  nk_deinterleave_f16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1250
1060
  nk_deinterleave_f16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1251
-
1252
- sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
1253
- sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
1254
- sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
1255
- sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
1256
- sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
1257
- sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
1258
-
1259
1061
  __m256 delta_x_f32x8 = _mm256_sub_ps(a_x_f32x8, b_x_f32x8);
1260
1062
  __m256 delta_y_f32x8 = _mm256_sub_ps(a_y_f32x8, b_y_f32x8);
1261
1063
  __m256 delta_z_f32x8 = _mm256_sub_ps(a_z_f32x8, b_z_f32x8);
1262
-
1263
1064
  sum_squared_x_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_x_f32x8);
1264
1065
  sum_squared_y_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_y_f32x8);
1265
1066
  sum_squared_z_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_z_f32x8);
1266
1067
  }
1267
1068
 
1268
- // Reduce vectors to scalars
1269
- nk_f32_t total_ax = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
1270
- nk_f32_t total_ay = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
1271
- nk_f32_t total_az = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
1272
- nk_f32_t total_bx = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
1273
- nk_f32_t total_by = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
1274
- nk_f32_t total_bz = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
1275
- nk_f32_t total_sq_x = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8);
1276
- nk_f32_t total_sq_y = nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8);
1277
- nk_f32_t total_sq_z = nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
1278
-
1279
- // Scalar tail
1069
+ nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8) +
1070
+ nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8) +
1071
+ nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
1280
1072
  for (; i < n; ++i) {
1281
1073
  nk_f32_t ax, ay, az, bx, by, bz;
1282
1074
  nk_f16_to_f32_haswell(&a[i * 3 + 0], &ax);
@@ -1285,91 +1077,41 @@ NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size
1285
1077
  nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
1286
1078
  nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
1287
1079
  nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
1288
- total_ax += ax;
1289
- total_ay += ay;
1290
- total_az += az;
1291
- total_bx += bx;
1292
- total_by += by;
1293
- total_bz += bz;
1294
1080
  nk_f32_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
1295
- total_sq_x += delta_x * delta_x;
1296
- total_sq_y += delta_y * delta_y;
1297
- total_sq_z += delta_z * delta_z;
1081
+ sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
1298
1082
  }
1299
1083
 
1300
- // Compute centroids
1301
- nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1302
- nk_f32_t centroid_a_x = total_ax * inv_n;
1303
- nk_f32_t centroid_a_y = total_ay * inv_n;
1304
- nk_f32_t centroid_a_z = total_az * inv_n;
1305
- nk_f32_t centroid_b_x = total_bx * inv_n;
1306
- nk_f32_t centroid_b_y = total_by * inv_n;
1307
- nk_f32_t centroid_b_z = total_bz * inv_n;
1308
-
1309
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1310
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1311
-
1312
- // Compute RMSD
1313
- nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
1314
- nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
1315
- nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
1316
- nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
1317
- nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
1318
-
1319
- *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
1084
+ *result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
1320
1085
  }
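With the centroid bookkeeping gone, the kernel above reduces to the plain root-mean-square deviation of paired points. A serial reference of the new semantics, sketched with plain float arrays rather than nk_f16_t:

#include <math.h>
#include <stddef.h>

/* Uncentered RMSD over n interleaved xyz points: sqrt(sum_i ||a_i - b_i||^2 / n). */
static float rmsd_xyz_serial(float const *a, float const *b, size_t n) {
    double sum_squared = 0.0; /* wider accumulator, purely for the reference */
    for (size_t i = 0; i < n; ++i)
        for (size_t d = 0; d < 3; ++d) {
            double delta = (double)a[i * 3 + d] - (double)b[i * 3 + d];
            sum_squared += delta * delta;
        }
    return (float)sqrt(sum_squared / (double)n);
}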
 
1322
1087
  NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1323
1088
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1324
- // RMSD uses identity rotation and scale=1.0
1325
1089
  if (rotation)
1326
1090
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
1327
1091
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
1328
1092
  if (scale) *scale = 1.0f;
1093
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
1094
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
1329
1095
 
1330
1096
  __m256 const zeros_f32x8 = _mm256_setzero_ps();
1331
-
1332
- // Accumulators for centroids and squared differences (all in f32)
1333
- __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
1334
- __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
1335
1097
  __m256 sum_squared_x_f32x8 = zeros_f32x8, sum_squared_y_f32x8 = zeros_f32x8, sum_squared_z_f32x8 = zeros_f32x8;
1336
-
1337
1098
  __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
1338
1099
  nk_size_t i = 0;
1339
1100
 
1340
- // Main loop processing 8 points at a time
1341
1101
  for (; i + 8 <= n; i += 8) {
1342
1102
  nk_deinterleave_bf16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
1343
1103
  nk_deinterleave_bf16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
1344
-
1345
- sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
1346
- sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
1347
- sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
1348
- sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
1349
- sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
1350
- sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
1351
-
1352
1104
  __m256 delta_x_f32x8 = _mm256_sub_ps(a_x_f32x8, b_x_f32x8);
1353
1105
  __m256 delta_y_f32x8 = _mm256_sub_ps(a_y_f32x8, b_y_f32x8);
1354
1106
  __m256 delta_z_f32x8 = _mm256_sub_ps(a_z_f32x8, b_z_f32x8);
1355
-
1356
1107
  sum_squared_x_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_x_f32x8);
1357
1108
  sum_squared_y_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_y_f32x8);
1358
1109
  sum_squared_z_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_z_f32x8);
1359
1110
  }
1360
1111
 
1361
- // Reduce vectors to scalars
1362
- nk_f32_t total_ax = nk_reduce_add_f32x8_haswell_(sum_a_x_f32x8);
1363
- nk_f32_t total_ay = nk_reduce_add_f32x8_haswell_(sum_a_y_f32x8);
1364
- nk_f32_t total_az = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
1365
- nk_f32_t total_bx = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
1366
- nk_f32_t total_by = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
1367
- nk_f32_t total_bz = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
1368
- nk_f32_t total_sq_x = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8);
1369
- nk_f32_t total_sq_y = nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8);
1370
- nk_f32_t total_sq_z = nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
1371
-
1372
- // Scalar tail
1112
+ nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8) +
1113
+ nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8) +
1114
+ nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
1373
1115
  for (; i < n; ++i) {
1374
1116
  nk_f32_t ax, ay, az, bx, by, bz;
1375
1117
  nk_bf16_to_f32_serial(&a[i * 3 + 0], &ax);
@@ -1378,43 +1120,16 @@ NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_s
1378
1120
  nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
1379
1121
  nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
1380
1122
  nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
1381
- total_ax += ax;
1382
- total_ay += ay;
1383
- total_az += az;
1384
- total_bx += bx;
1385
- total_by += by;
1386
- total_bz += bz;
1387
1123
  nk_f32_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
1388
- total_sq_x += delta_x * delta_x;
1389
- total_sq_y += delta_y * delta_y;
1390
- total_sq_z += delta_z * delta_z;
1124
+ sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
1391
1125
  }
1392
1126
 
1393
- // Compute centroids
1394
- nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
1395
- nk_f32_t centroid_a_x = total_ax * inv_n;
1396
- nk_f32_t centroid_a_y = total_ay * inv_n;
1397
- nk_f32_t centroid_a_z = total_az * inv_n;
1398
- nk_f32_t centroid_b_x = total_bx * inv_n;
1399
- nk_f32_t centroid_b_y = total_by * inv_n;
1400
- nk_f32_t centroid_b_z = total_bz * inv_n;
1401
-
1402
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
1403
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
1404
-
1405
- // Compute RMSD
1406
- nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
1407
- nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
1408
- nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
1409
- nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
1410
- nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
1411
-
1412
- *result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
1127
+ *result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
  }
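The lines deleted from both RMSD kernels implemented a mean-corrected variant, so this is a semantic change, not just a cleanup; the two quantities coincide only when the centroids agree:

\[
\text{before: } \sqrt{\tfrac{1}{n}\textstyle\sum_i \lVert a_i - b_i \rVert^2 - \lVert \bar a - \bar b \rVert^2},
\qquad
\text{after: } \sqrt{\tfrac{1}{n}\textstyle\sum_i \lVert a_i - b_i \rVert^2}.
\]

Zeroing the reported centroids is consistent with the new behavior: no translation is estimated or removed.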
1414
1129
 
1415
1130
  NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1416
1131
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1417
- // Fused single-pass: load f16, convert to f32, compute centroids and covariance
1132
+ // Fused single-pass: load f16, convert to f32, compute centroids, covariance, and ‖a‖²/‖b‖²
1418
1133
  __m256 const zeros_f32x8 = _mm256_setzero_ps();
1419
1134
 
1420
1135
  // Accumulators for centroids (f32)
@@ -1422,9 +1137,10 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1422
1137
  __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
1423
1138
 
1424
1139
  // Accumulators for covariance matrix (sum of outer products)
1425
- __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
1426
- __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
1427
- __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
1140
+ __m256 covariance_xx_f32x8 = zeros_f32x8, covariance_xy_f32x8 = zeros_f32x8, covariance_xz_f32x8 = zeros_f32x8;
1141
+ __m256 covariance_yx_f32x8 = zeros_f32x8, covariance_yy_f32x8 = zeros_f32x8, covariance_yz_f32x8 = zeros_f32x8;
1142
+ __m256 covariance_zx_f32x8 = zeros_f32x8, covariance_zy_f32x8 = zeros_f32x8, covariance_zz_f32x8 = zeros_f32x8;
1143
+ __m256 norm_squared_a_f32x8 = zeros_f32x8, norm_squared_b_f32x8 = zeros_f32x8;
1428
1144
 
1429
1145
  nk_size_t i = 0;
1430
1146
  __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
@@ -1442,15 +1158,23 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1442
1158
  sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
1443
1159
 
1444
1160
  // Accumulate outer products
1445
- cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
1446
- cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
1447
- cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
1448
- cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
1449
- cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
1450
- cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
1451
- cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
1452
- cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
1453
- cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
1161
+ covariance_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, covariance_xx_f32x8);
1162
+ covariance_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, covariance_xy_f32x8);
1163
+ covariance_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, covariance_xz_f32x8);
1164
+ covariance_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, covariance_yx_f32x8);
1165
+ covariance_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, covariance_yy_f32x8);
1166
+ covariance_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, covariance_yz_f32x8);
1167
+ covariance_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, covariance_zx_f32x8);
1168
+ covariance_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, covariance_zy_f32x8);
1169
+ covariance_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, covariance_zz_f32x8);
1170
+
1171
+ // Accumulate ‖a‖² and ‖b‖² for folded SSD
1172
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, norm_squared_a_f32x8);
1173
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, norm_squared_a_f32x8);
1174
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, norm_squared_a_f32x8);
1175
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_x_f32x8, b_x_f32x8, norm_squared_b_f32x8);
1176
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_y_f32x8, b_y_f32x8, norm_squared_b_f32x8);
1177
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_z_f32x8, b_z_f32x8, norm_squared_b_f32x8);
1454
1178
  }
1455
1179
 
1456
1180
  // Reduce vector accumulators
@@ -1461,15 +1185,17 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1461
1185
  nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
1462
1186
  nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
1463
1187
 
1464
- nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
1465
- nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
1466
- nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
1467
- nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
1468
- nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
1469
- nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
1470
- nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
1471
- nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
1472
- nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
1188
+ nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(covariance_xx_f32x8);
1189
+ nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(covariance_xy_f32x8);
1190
+ nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(covariance_xz_f32x8);
1191
+ nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(covariance_yx_f32x8);
1192
+ nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(covariance_yy_f32x8);
1193
+ nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(covariance_yz_f32x8);
1194
+ nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(covariance_zx_f32x8);
1195
+ nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(covariance_zy_f32x8);
1196
+ nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(covariance_zz_f32x8);
1197
+ nk_f32_t norm_squared_a_sum = nk_reduce_add_f32x8_haswell_(norm_squared_a_f32x8);
1198
+ nk_f32_t norm_squared_b_sum = nk_reduce_add_f32x8_haswell_(norm_squared_b_f32x8);
1473
1199
 
1474
1200
  // Scalar tail
1475
1201
  for (; i < n; ++i) {
@@ -1485,6 +1211,8 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1485
1211
  covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
1486
1212
  covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
1487
1213
  covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
1214
+ norm_squared_a_sum += ax * ax + ay * ay + az * az;
1215
+ norm_squared_b_sum += bx * bx + by * by + bz * bz;
1488
1216
  }
1489
1217
 
1490
1218
  // Compute centroids
@@ -1510,52 +1238,84 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
1510
1238
  covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
1511
1239
  covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
1512
1240
 
1513
- // Compute SVD and optimal rotation
1514
1241
  nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
1515
1242
  covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
1516
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
1517
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
1518
-
1519
- // R = V * Uᵀ
1520
- nk_f32_t r[9];
1521
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1522
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1523
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
1524
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
1525
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
1526
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
1527
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
1528
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
1529
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1530
-
1531
- // Handle reflection: if det(R) < 0, negate third column of V and recompute R
1532
- if (nk_det3x3_f32_(r) < 0) {
1533
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
1534
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
1535
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
1536
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
1537
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
1538
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
1539
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
1540
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
1541
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
1542
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
1243
+
1244
+ // Centered ‖a-ā‖², ‖b-b̄‖² via the parallel-axis identity
1245
+ nk_f32_t centered_norm_squared_a = norm_squared_a_sum -
1246
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
1247
+ centroid_a_z * centroid_a_z);
1248
+ nk_f32_t centered_norm_squared_b = norm_squared_b_sum -
1249
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
1250
+ centroid_b_z * centroid_b_z);
1251
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
1252
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
1253
+
1254
+ // Identity-dominant short-circuit: R = I, trace(R · H) = H[0]+H[4]+H[8]. Skips SVD + two
1255
+ // rotation reconstructions when the inputs are already aligned.
1256
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
1257
+ cross_covariance[4] * cross_covariance[4] +
1258
+ cross_covariance[8] * cross_covariance[8];
1259
+ nk_f32_t covariance_offdiagonal_norm_squared =
1260
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
1261
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
1262
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
1263
+ nk_f32_t optimal_rotation[9];
1264
+ nk_f32_t trace_rotation_covariance;
1265
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
1266
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
1267
+ optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
1268
+ optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
1269
+ optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
1270
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
1271
+ }
1272
+ else {
1273
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
1274
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
1275
+ // R = V * Uᵀ
1276
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
1277
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
1278
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
1279
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
1280
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
1281
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
1282
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
1283
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
1284
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
1285
+ if (nk_det3x3_f32_(optimal_rotation) < 0) {
1286
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
1287
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
1288
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
1289
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
1290
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
1291
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
1292
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
1293
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
1294
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
1295
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
1296
+ }
1297
+ trace_rotation_covariance =
1298
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
1299
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
1300
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
1301
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
1302
+ optimal_rotation[8] * cross_covariance[8];
1543
1303
  }

  // Output rotation matrix and scale=1.0
  if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
  if (scale) *scale = 1.0f;

- // Compute RMSD after optimal rotation
- nk_f32_t sum_squared = nk_transformed_ssd_f16_haswell_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
+ nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
  *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
  }
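The nine-term products above are the row-major expansion of R = V·Uᵀ, written out twice so the reflection fix can rebuild R after negating V's third column (indices 2, 5, 8). An equivalent loop form, for reference only; the helper name is illustrative:

/* Row-major 3x3 product r = v * transpose(u), matching the expansion above. */
static void rotation_from_svd(float const u[9], float const v[9], float r[9]) {
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col)
            r[row * 3 + col] = v[row * 3 + 0] * u[col * 3 + 0] +
                               v[row * 3 + 1] * u[col * 3 + 1] +
                               v[row * 3 + 2] * u[col * 3 + 2];
}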
1555
1315
 
1556
1316
  NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
1557
1317
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
1558
- // Fused single-pass: load bf16, convert to f32, compute centroids and covariance
1318
+ // Fused single-pass: load bf16, convert to f32, compute centroids, covariance, and ‖a‖²/‖b‖²
1559
1319
  __m256 const zeros_f32x8 = _mm256_setzero_ps();
1560
1320
 
1561
1321
  // Accumulators for centroids (f32)
@@ -1563,9 +1323,10 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
1563
1323
  __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
1564
1324
 
1565
1325
  // Accumulators for covariance matrix (sum of outer products)
1566
- __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
1567
- __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
1568
- __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
1326
+ __m256 covariance_xx_f32x8 = zeros_f32x8, covariance_xy_f32x8 = zeros_f32x8, covariance_xz_f32x8 = zeros_f32x8;
1327
+ __m256 covariance_yx_f32x8 = zeros_f32x8, covariance_yy_f32x8 = zeros_f32x8, covariance_yz_f32x8 = zeros_f32x8;
1328
+ __m256 covariance_zx_f32x8 = zeros_f32x8, covariance_zy_f32x8 = zeros_f32x8, covariance_zz_f32x8 = zeros_f32x8;
1329
+ __m256 norm_squared_a_f32x8 = zeros_f32x8, norm_squared_b_f32x8 = zeros_f32x8;
1569
1330
 
1570
1331
  nk_size_t i = 0;
1571
1332
  __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
@@ -1583,15 +1344,23 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
1583
1344
  sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
1584
1345
 
1585
1346
  // Accumulate outer products
1586
- cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
1587
- cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
1588
- cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
1589
- cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
1590
- cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
1591
- cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
1592
- cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
1593
- cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
1594
- cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
1347
+ covariance_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, covariance_xx_f32x8);
1348
+ covariance_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, covariance_xy_f32x8);
1349
+ covariance_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, covariance_xz_f32x8);
1350
+ covariance_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, covariance_yx_f32x8);
1351
+ covariance_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, covariance_yy_f32x8);
1352
+ covariance_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, covariance_yz_f32x8);
1353
+ covariance_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, covariance_zx_f32x8);
1354
+ covariance_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, covariance_zy_f32x8);
1355
+ covariance_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, covariance_zz_f32x8);
1356
+
1357
+ // Accumulate ‖a‖² and ‖b‖² for folded SSD
1358
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, norm_squared_a_f32x8);
1359
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, norm_squared_a_f32x8);
1360
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, norm_squared_a_f32x8);
1361
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_x_f32x8, b_x_f32x8, norm_squared_b_f32x8);
1362
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_y_f32x8, b_y_f32x8, norm_squared_b_f32x8);
1363
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_z_f32x8, b_z_f32x8, norm_squared_b_f32x8);
1595
1364
  }
1596
1365
 
1597
1366
  // Reduce vector accumulators
@@ -1602,15 +1371,17 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
 nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
 nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);

- nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
- nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
- nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
- nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
- nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
- nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
- nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
- nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
- nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
+ nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(covariance_xx_f32x8);
+ nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(covariance_xy_f32x8);
+ nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(covariance_xz_f32x8);
+ nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(covariance_yx_f32x8);
+ nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(covariance_yy_f32x8);
+ nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(covariance_yz_f32x8);
+ nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(covariance_zx_f32x8);
+ nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(covariance_zy_f32x8);
+ nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(covariance_zz_f32x8);
+ nk_f32_t norm_squared_a_sum = nk_reduce_add_f32x8_haswell_(norm_squared_a_f32x8);
+ nk_f32_t norm_squared_b_sum = nk_reduce_add_f32x8_haswell_(norm_squared_b_f32x8);

 // Scalar tail
 for (; i < n; ++i) {
@@ -1626,6 +1397,8 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
 covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
 covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
 covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
+ norm_squared_a_sum += ax * ax + ay * ay + az * az;
+ norm_squared_b_sum += bx * bx + by * by + bz * bz;
 }

 // Compute centroids
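The hunks that follow undo the centroid contribution after the fact instead of centering every point up front. Both corrections are instances of the same shift identities; a short statement for reference (my summary, consistent with the comments in the diff), with ā = (1/n)Σᵢaᵢ and b̄ defined likewise:

\[
H_{\mathrm{centered}} \;=\; \sum_i (a_i - \bar a)(b_i - \bar b)^{\top}
 \;=\; \sum_i a_i b_i^{\top} \;-\; n\,\bar a\,\bar b^{\top},
\qquad
\sum_i \lVert a_i - \bar a \rVert^2 \;=\; \sum_i \lVert a_i \rVert^2 \;-\; n\,\lVert \bar a \rVert^2 .
\]

The second equality is the parallel-axis identity named in the next hunk's comment; the clamps to 0.0f there guard against tiny negative residues when the subtraction cancels in f32.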
@@ -1651,60 +1424,92 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
 covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
 covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;

- // Compute SVD and optimal rotation
 nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
 covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
-
- // R = V * Uᵀ
- nk_f32_t r[9];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
-
- // Handle reflection: if det(R) < 0, negate third column of V and recompute R
- if (nk_det3x3_f32_(r) < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+
+ // Centered ‖a-ā‖², ‖b-b̄‖² via the parallel-axis identity
+ nk_f32_t centered_norm_squared_a = norm_squared_a_sum -
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f32_t centered_norm_squared_b = norm_squared_b_sum -
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
+
+ // Identity-dominant short-circuit: R = I, trace(R · H) = H[0]+H[4]+H[8]. Skips SVD + two
+ // rotation reconstructions when the inputs are already aligned.
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f32_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f32_t optimal_rotation[9];
+ nk_f32_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
+ optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
+ optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
+ optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ }
+ else {
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ // R = V * Uᵀ
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ if (nk_det3x3_f32_(optimal_rotation) < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ }
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
 }

 // Output rotation matrix and scale=1.0
 if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
 if (scale) *scale = 1.0f;

- // Compute RMSD after optimal rotation
- nk_f32_t sum_squared = nk_transformed_ssd_bf16_haswell_(a, b, n, r, 1.0f, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
+ nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
 *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
 }

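That removes the entire second pass over the points that the old `nk_transformed_ssd_bf16_haswell_` call used to make. The fold rests on a standard trace identity (my derivation, consistent with the comment in the hunk): writing ãᵢ = aᵢ − ā, b̃ᵢ = bᵢ − b̄ and H = Σᵢ ãᵢb̃ᵢᵀ,

\[
\sum_i \lVert R\,\tilde a_i - \tilde b_i \rVert^2
 \;=\; \sum_i \lVert \tilde a_i \rVert^2 + \sum_i \lVert \tilde b_i \rVert^2
 - 2\sum_i \tilde b_i^{\top} R\,\tilde a_i
 \;=\; \sum_i \lVert \tilde a_i \rVert^2 + \sum_i \lVert \tilde b_i \rVert^2
 - 2\,\operatorname{tr}(R\,H),
\]

since rotations preserve norms (‖Rãᵢ‖ = ‖ãᵢ‖) and b̃ᵀRã = tr(Rãb̃ᵀ). The `trace_rotation_covariance` sum in the hunk is exactly tr(R·H) written out entry by entry with H's transposed indexing, and the final clamp keeps f32 cancellation from producing a negative SSD before the square root.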
 NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
 nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
- // Fused single-pass: load f16, convert to f32, compute centroids, covariance, and variance
+ // Fused single-pass: load f16, convert to f32, compute centroids, covariance, and ‖a‖²/‖b‖²
 __m256 const zeros_f32x8 = _mm256_setzero_ps();

 __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
 __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
- __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
- __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
- __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
- __m256 variance_a_f32x8 = zeros_f32x8;
+ __m256 covariance_xx_f32x8 = zeros_f32x8, covariance_xy_f32x8 = zeros_f32x8, covariance_xz_f32x8 = zeros_f32x8;
+ __m256 covariance_yx_f32x8 = zeros_f32x8, covariance_yy_f32x8 = zeros_f32x8, covariance_yz_f32x8 = zeros_f32x8;
+ __m256 covariance_zx_f32x8 = zeros_f32x8, covariance_zy_f32x8 = zeros_f32x8, covariance_zz_f32x8 = zeros_f32x8;
+ __m256 norm_squared_a_f32x8 = zeros_f32x8, norm_squared_b_f32x8 = zeros_f32x8;

 nk_size_t i = 0;
 __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
@@ -1722,20 +1527,23 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
 sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);

 // Accumulate outer products
- cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
- cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
- cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
- cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
- cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
- cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
- cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
- cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
- cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
-
- // Accumulate variance of A
- variance_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, variance_a_f32x8);
- variance_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, variance_a_f32x8);
- variance_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, variance_a_f32x8);
+ covariance_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, covariance_xx_f32x8);
+ covariance_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, covariance_xy_f32x8);
+ covariance_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, covariance_xz_f32x8);
+ covariance_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, covariance_yx_f32x8);
+ covariance_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, covariance_yy_f32x8);
+ covariance_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, covariance_yz_f32x8);
+ covariance_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, covariance_zx_f32x8);
+ covariance_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, covariance_zy_f32x8);
+ covariance_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, covariance_zz_f32x8);
+
+ // Accumulate ‖a‖² and ‖b‖² for folded SSD
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, norm_squared_a_f32x8);
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, norm_squared_a_f32x8);
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, norm_squared_a_f32x8);
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_x_f32x8, b_x_f32x8, norm_squared_b_f32x8);
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_y_f32x8, b_y_f32x8, norm_squared_b_f32x8);
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_z_f32x8, b_z_f32x8, norm_squared_b_f32x8);
 }

 // Reduce vector accumulators
@@ -1745,16 +1553,17 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
 nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
 nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
 nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
- nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
- nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
- nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
- nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
- nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
- nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
- nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
- nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
- nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
- nk_f32_t variance_a_sum = nk_reduce_add_f32x8_haswell_(variance_a_f32x8);
+ nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(covariance_xx_f32x8);
+ nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(covariance_xy_f32x8);
+ nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(covariance_xz_f32x8);
+ nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(covariance_yx_f32x8);
+ nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(covariance_yy_f32x8);
+ nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(covariance_yz_f32x8);
+ nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(covariance_zx_f32x8);
+ nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(covariance_zy_f32x8);
+ nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(covariance_zz_f32x8);
+ nk_f32_t norm_squared_a_sum = nk_reduce_add_f32x8_haswell_(norm_squared_a_f32x8);
+ nk_f32_t norm_squared_b_sum = nk_reduce_add_f32x8_haswell_(norm_squared_b_f32x8);

 // Scalar tail
 for (; i < n; ++i) {
@@ -1770,7 +1579,8 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
 covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
 covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
 covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
- variance_a_sum += ax * ax + ay * ay + az * az;
+ norm_squared_a_sum += ax * ax + ay * ay + az * az;
+ norm_squared_b_sum += bx * bx + by * by + bz * bz;
 }

 // Compute centroids
@@ -1781,10 +1591,6 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
 if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
 if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

- // Compute centered covariance and variance
- nk_f32_t variance_a = variance_a_sum * inv_n -
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
-
 // Apply centering correction to covariance matrix
 covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
 covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
@@ -1799,64 +1605,97 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
 nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
 covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};

- // SVD
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
-
- // R = V * Uᵀ
- nk_f32_t r[9];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
-
- // Scale factor: c = trace(D × S) / (n × variance(a))
- nk_f32_t det = nk_det3x3_f32_(r);
- nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
- nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
- nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
- if (scale) *scale = c;
-
- // Handle reflection
- if (det < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ // Centered ‖a-ā‖², ‖b-b̄‖² via the parallel-axis identity
+ nk_f32_t centered_norm_squared_a = norm_squared_a_sum -
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f32_t centered_norm_squared_b = norm_squared_b_sum -
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
+
+ // Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals,
+ // R = I and trace(DS) = trace(H).
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f32_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f32_t optimal_rotation[9];
+ nk_f32_t applied_scale;
+ nk_f32_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
+ optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
+ optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
+ optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ applied_scale = centered_norm_squared_a > 0.0f ? trace_rotation_covariance / centered_norm_squared_a : 0.0f;
+ }
+ else {
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ // R = V * Uᵀ
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+
+ nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
+ nk_f32_t sign_correction = det < 0 ? -1.0f : 1.0f;
+ if (det < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ }
+ nk_f32_t trace_ds = svd_diagonal[0] + svd_diagonal[4] + sign_correction * svd_diagonal[8];
+ applied_scale = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
 }

- // Output rotation matrix
+ // Output rotation matrix and scale
 if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
+ if (scale) *scale = applied_scale;

- // Compute RMSD with scaling
- nk_f32_t sum_squared = nk_transformed_ssd_f16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
+ nk_f32_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
+ 2.0f * applied_scale * trace_rotation_covariance;
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
 *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
 }

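One behavioral detail of the `nk_umeyama_f16_haswell` rewrite worth spelling out (my reading of the hunk, not text from the package): the old scale divided `trace_ds` by n·σₐ², the new one by `centered_norm_squared_a`, and the two denominators agree because

\[
c = \frac{\operatorname{tr}(D\,S)}{n\,\sigma_a^2},
\qquad
n\,\sigma_a^2 = n\left(\tfrac{1}{n}\sum_i \lVert a_i\rVert^2 - \lVert\bar a\rVert^2\right)
 = \sum_i \lVert a_i - \bar a\rVert^2 ,
\]

where D holds the singular values of the centered covariance and S = diag(1, 1, ±1) carries the reflection sign. What does change is the degenerate case: the new ternary yields c = 0 instead of dividing by zero when all points of `a` coincide.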
 NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
 nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
- // Fused single-pass: load bf16, convert to f32, compute centroids, covariance, and variance
+ // Fused single-pass: load bf16, convert to f32, compute centroids, covariance, and ‖a‖²/‖b‖²
 __m256 const zeros_f32x8 = _mm256_setzero_ps();

 __m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
 __m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
- __m256 cov_xx_f32x8 = zeros_f32x8, cov_xy_f32x8 = zeros_f32x8, cov_xz_f32x8 = zeros_f32x8;
- __m256 cov_yx_f32x8 = zeros_f32x8, cov_yy_f32x8 = zeros_f32x8, cov_yz_f32x8 = zeros_f32x8;
- __m256 cov_zx_f32x8 = zeros_f32x8, cov_zy_f32x8 = zeros_f32x8, cov_zz_f32x8 = zeros_f32x8;
- __m256 variance_a_f32x8 = zeros_f32x8;
+ __m256 covariance_xx_f32x8 = zeros_f32x8, covariance_xy_f32x8 = zeros_f32x8, covariance_xz_f32x8 = zeros_f32x8;
+ __m256 covariance_yx_f32x8 = zeros_f32x8, covariance_yy_f32x8 = zeros_f32x8, covariance_yz_f32x8 = zeros_f32x8;
+ __m256 covariance_zx_f32x8 = zeros_f32x8, covariance_zy_f32x8 = zeros_f32x8, covariance_zz_f32x8 = zeros_f32x8;
+ __m256 norm_squared_a_f32x8 = zeros_f32x8, norm_squared_b_f32x8 = zeros_f32x8;

 nk_size_t i = 0;
 __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
@@ -1874,20 +1713,23 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
 sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);

 // Accumulate outer products
- cov_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, cov_xx_f32x8);
- cov_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, cov_xy_f32x8);
- cov_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, cov_xz_f32x8);
- cov_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, cov_yx_f32x8);
- cov_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, cov_yy_f32x8);
- cov_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, cov_yz_f32x8);
- cov_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, cov_zx_f32x8);
- cov_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, cov_zy_f32x8);
- cov_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, cov_zz_f32x8);
-
- // Accumulate variance of A
- variance_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, variance_a_f32x8);
- variance_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, variance_a_f32x8);
- variance_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, variance_a_f32x8);
+ covariance_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, covariance_xx_f32x8);
+ covariance_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, covariance_xy_f32x8);
+ covariance_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, covariance_xz_f32x8);
+ covariance_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, covariance_yx_f32x8);
+ covariance_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, covariance_yy_f32x8);
+ covariance_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, covariance_yz_f32x8);
+ covariance_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, covariance_zx_f32x8);
+ covariance_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, covariance_zy_f32x8);
+ covariance_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, covariance_zz_f32x8);
+
+ // Accumulate ‖a‖² and ‖b‖² for folded SSD
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, norm_squared_a_f32x8);
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, norm_squared_a_f32x8);
+ norm_squared_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, norm_squared_a_f32x8);
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_x_f32x8, b_x_f32x8, norm_squared_b_f32x8);
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_y_f32x8, b_y_f32x8, norm_squared_b_f32x8);
+ norm_squared_b_f32x8 = _mm256_fmadd_ps(b_z_f32x8, b_z_f32x8, norm_squared_b_f32x8);
 }

 // Reduce vector accumulators
@@ -1897,16 +1739,17 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
 nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
 nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
 nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
- nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(cov_xx_f32x8);
- nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(cov_xy_f32x8);
- nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(cov_xz_f32x8);
- nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(cov_yx_f32x8);
- nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(cov_yy_f32x8);
- nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(cov_yz_f32x8);
- nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(cov_zx_f32x8);
- nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(cov_zy_f32x8);
- nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(cov_zz_f32x8);
- nk_f32_t variance_a_sum = nk_reduce_add_f32x8_haswell_(variance_a_f32x8);
+ nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(covariance_xx_f32x8);
+ nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(covariance_xy_f32x8);
+ nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(covariance_xz_f32x8);
+ nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(covariance_yx_f32x8);
+ nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(covariance_yy_f32x8);
+ nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(covariance_yz_f32x8);
+ nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(covariance_zx_f32x8);
+ nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(covariance_zy_f32x8);
+ nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(covariance_zz_f32x8);
+ nk_f32_t norm_squared_a_sum = nk_reduce_add_f32x8_haswell_(norm_squared_a_f32x8);
+ nk_f32_t norm_squared_b_sum = nk_reduce_add_f32x8_haswell_(norm_squared_b_f32x8);

 // Scalar tail
 for (; i < n; ++i) {
@@ -1922,7 +1765,8 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
 covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
 covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
 covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
- variance_a_sum += ax * ax + ay * ay + az * az;
+ norm_squared_a_sum += ax * ax + ay * ay + az * az;
+ norm_squared_b_sum += bx * bx + by * by + bz * bz;
 }

 // Compute centroids
@@ -1933,10 +1777,6 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
 if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
 if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;

- // Compute centered covariance and variance
- nk_f32_t variance_a = variance_a_sum * inv_n -
- (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
-
 // Apply centering correction to covariance matrix
 covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
 covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
@@ -1951,50 +1791,83 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
 nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
 covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};

- // SVD
- nk_f32_t svd_u[9], svd_s[9], svd_v[9];
- nk_svd3x3_f32_(cross_covariance, svd_u, svd_s, svd_v);
-
- // R = V * Uᵀ
- nk_f32_t r[9];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
-
- // Scale factor: c = trace(D × S) / (n × variance(a))
- nk_f32_t det = nk_det3x3_f32_(r);
- nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
- nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
- nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
- if (scale) *scale = c;
-
- // Handle reflection
- if (det < 0) {
- svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
- r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
- r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
- r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
- r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
- r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
- r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
- r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
- r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
- r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
+ // Centered ‖a-ā‖², ‖b-b̄‖² via the parallel-axis identity
+ nk_f32_t centered_norm_squared_a = norm_squared_a_sum -
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
+ centroid_a_z * centroid_a_z);
+ nk_f32_t centered_norm_squared_b = norm_squared_b_sum -
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
+ centroid_b_z * centroid_b_z);
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
+
+ // Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals,
+ // R = I and trace(DS) = trace(H).
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
+ cross_covariance[4] * cross_covariance[4] +
+ cross_covariance[8] * cross_covariance[8];
+ nk_f32_t covariance_offdiagonal_norm_squared =
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
+ nk_f32_t optimal_rotation[9];
+ nk_f32_t applied_scale;
+ nk_f32_t trace_rotation_covariance;
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
+ optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
+ optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
+ optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
+ applied_scale = centered_norm_squared_a > 0.0f ? trace_rotation_covariance / centered_norm_squared_a : 0.0f;
+ }
+ else {
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
+ // R = V * Uᵀ
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+
+ nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
+ nk_f32_t sign_correction = det < 0 ? -1.0f : 1.0f;
+ if (det < 0) {
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
+ optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
+ optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
+ optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
+ optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
+ optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
+ optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
+ optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
+ optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
+ optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
+ }
+ nk_f32_t trace_ds = svd_diagonal[0] + svd_diagonal[4] + sign_correction * svd_diagonal[8];
+ applied_scale = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
+ trace_rotation_covariance =
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
+ optimal_rotation[8] * cross_covariance[8];
 }

- // Output rotation matrix
+ // Output rotation matrix and scale
 if (rotation)
- for (int j = 0; j < 9; ++j) rotation[j] = r[j];
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
+ if (scale) *scale = applied_scale;

- // Compute RMSD with scaling
- nk_f32_t sum_squared = nk_transformed_ssd_bf16_haswell_(a, b, n, r, c, centroid_a_x, centroid_a_y, centroid_a_z,
- centroid_b_x, centroid_b_y, centroid_b_z);
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
+ nk_f32_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
+ 2.0f * applied_scale * trace_rotation_covariance;
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
 *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
 }
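All three rewritten kernels gate the SVD behind the same relative test on the centered cross-covariance. A standalone restatement of that gate in scalar C — the function name is mine, but the arithmetic mirrors the diff line for line:

#include <stdbool.h>

/* True when the row-major 3x3 cross-covariance h is so dominated by a
   positive diagonal that R = I is already the optimal rotation, letting the
   caller skip nk_svd3x3_f32_ and both rotation reconstructions. Comparing
   squared off-diagonal mass against squared diagonal mass keeps the
   threshold scale-invariant. */
static bool identity_rotation_is_optimal(float const h[9]) {
    float diagonal_sq = h[0] * h[0] + h[4] * h[4] + h[8] * h[8];
    float offdiagonal_sq = h[1] * h[1] + h[2] * h[2] + h[3] * h[3] +
                           h[5] * h[5] + h[6] * h[6] + h[7] * h[7];
    return offdiagonal_sq < 1e-12f * diagonal_sq &&
           h[0] > 0.0f && h[4] > 0.0f && h[8] > 0.0f;
}

With h exactly diagonal and positive, `offdiagonal_sq` is zero and the test passes; any appreciable rotation between the point clouds pushes off-diagonal mass above the 1e-12 relative threshold and falls through to the SVD path.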