numkong 7.4.4 → 7.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. package/README.md +1 -0
  2. package/binding.gyp +81 -5
  3. package/c/dispatch_f16.c +23 -0
  4. package/c/numkong.c +0 -13
  5. package/include/numkong/attention/sme.h +34 -31
  6. package/include/numkong/capabilities.h +2 -15
  7. package/include/numkong/cast/neon.h +15 -0
  8. package/include/numkong/curved/smef64.h +82 -62
  9. package/include/numkong/dot/rvvbf16.h +1 -1
  10. package/include/numkong/dot/rvvhalf.h +1 -1
  11. package/include/numkong/dot/sve.h +6 -5
  12. package/include/numkong/dot/svebfdot.h +2 -1
  13. package/include/numkong/dot/svehalf.h +6 -5
  14. package/include/numkong/dot/svesdot.h +3 -2
  15. package/include/numkong/dots/graniteamx.h +733 -0
  16. package/include/numkong/dots/serial.h +11 -4
  17. package/include/numkong/dots/sme.h +172 -140
  18. package/include/numkong/dots/smebi32.h +14 -11
  19. package/include/numkong/dots/smef64.h +31 -26
  20. package/include/numkong/dots.h +29 -3
  21. package/include/numkong/each/serial.h +22 -0
  22. package/include/numkong/geospatial/haswell.h +1 -1
  23. package/include/numkong/geospatial/neon.h +1 -1
  24. package/include/numkong/geospatial/serial.h +1 -1
  25. package/include/numkong/geospatial/skylake.h +1 -1
  26. package/include/numkong/maxsim/sme.h +94 -55
  27. package/include/numkong/mesh/README.md +13 -27
  28. package/include/numkong/mesh/haswell.h +25 -122
  29. package/include/numkong/mesh/neon.h +21 -110
  30. package/include/numkong/mesh/neonbfdot.h +4 -43
  31. package/include/numkong/mesh/rvv.h +7 -82
  32. package/include/numkong/mesh/serial.h +48 -53
  33. package/include/numkong/mesh/skylake.h +7 -123
  34. package/include/numkong/mesh/v128relaxed.h +9 -93
  35. package/include/numkong/mesh.h +2 -2
  36. package/include/numkong/mesh.hpp +35 -96
  37. package/include/numkong/reduce/neon.h +29 -0
  38. package/include/numkong/reduce/neonbfdot.h +2 -2
  39. package/include/numkong/reduce/neonfhm.h +4 -4
  40. package/include/numkong/reduce/sve.h +52 -0
  41. package/include/numkong/reduce.h +4 -0
  42. package/include/numkong/set/sve.h +6 -5
  43. package/include/numkong/sets/smebi32.h +35 -30
  44. package/include/numkong/sparse/sve2.h +3 -2
  45. package/include/numkong/spatial/sve.h +7 -6
  46. package/include/numkong/spatial/svebfdot.h +7 -4
  47. package/include/numkong/spatial/svehalf.h +5 -4
  48. package/include/numkong/spatial/svesdot.h +9 -8
  49. package/include/numkong/spatials/graniteamx.h +173 -0
  50. package/include/numkong/spatials/serial.h +22 -0
  51. package/include/numkong/spatials/sme.h +391 -350
  52. package/include/numkong/spatials/smef64.h +79 -70
  53. package/include/numkong/spatials.h +37 -4
  54. package/include/numkong/types.h +59 -0
  55. package/javascript/dist/cjs/numkong.js +13 -0
  56. package/javascript/dist/esm/numkong.js +13 -0
  57. package/javascript/numkong.c +56 -12
  58. package/javascript/numkong.ts +13 -0
  59. package/package.json +7 -7
  60. package/probes/probe.js +2 -2
  61. package/wasm/numkong.wasm +0 -0
@@ -1,37 +1,23 @@
1
1
  # Point Cloud Alignment in NumKong
2
2
 
3
- NumKong implements RMSD, Kabsch, and Umeyama algorithms for rigid-body superposition of 3D point clouds.
4
- RMSD measures alignment quality, Kabsch finds the optimal rotation minimizing RMSD, and Umeyama extends Kabsch with uniform scaling.
5
- Used in structural biology (protein alignment), robotics (point cloud registration), and computer graphics (mesh registration).
3
+ NumKong implements three algorithms for 3D point cloud comparison and alignment, used in structural biology (protein alignment), robotics (point cloud registration), and computer graphics (mesh registration).
6
4
 
7
- Centroid:
5
+ RMSD measures raw point-pair deviation without centering or alignment:
8
6
 
9
7
  $$
10
- \bar{a} = \frac{1}{n}\sum a_i
8
+ \text{RMSD} = \sqrt{\frac{1}{n}\sum \|a_i - b_i\|^2}
11
9
  $$
12
10
 
13
- Cross-covariance matrix:
11
+ Kabsch finds the optimal rotation $R$ that minimizes RMSD after centering both clouds at their centroids $\bar{a}$, $\bar{b}$, recovering $R$ from the SVD of the cross-covariance matrix $H$:
14
12
 
15
13
  $$
16
- H = \sum (a_i - \bar{a})(b_i - \bar{b})^T
14
+ H = \sum (a_i - \bar{a})(b_i - \bar{b})^T = U \Sigma V^T, \quad R = V U^T
17
15
  $$
18
16
 
19
- SVD-based rotation:
17
+ Umeyama extends Kabsch with a uniform scale factor $s$ derived from the singular values and source variance $\sigma_a^2$:
20
18
 
21
19
  $$
22
- H = U \Sigma V^T, \quad R = V U^T
23
- $$
24
-
25
- Umeyama scale factor:
26
-
27
- $$
28
- s = \frac{\text{tr}(\Sigma)}{n \cdot \sigma_a^2}
29
- $$
30
-
31
- RMSD after alignment:
32
-
33
- $$
34
- \text{RMSD} = \sqrt{\frac{1}{n}\sum \|s \cdot R(a_i - \bar{a}) - (b_i - \bar{b})\|^2}
20
+ s = \frac{\text{tr}(\Sigma)}{n \cdot \sigma_a^2}, \quad \text{RMSD} = \sqrt{\frac{1}{n}\sum \|s \cdot R(a_i - \bar{a}) - (b_i - \bar{b})\|^2}
35
21
  $$
36
22
 
37
23
  Reformulating as Python pseudocode:
@@ -189,25 +175,25 @@ Measured with Wasmtime v42 (Cranelift backend).
189
175
  | Kernel | 256 | 1024 | 4096 |
190
176
  | :-------------------------- | -----------------------: | -----------------------: | -----------------------: |
191
177
  | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
192
- | `nk_rmsd_f64_serial` | 120 mp/s, 1.4 ulp | 118 mp/s, 2.6 ulp | 121 mp/s, 5.3 ulp |
178
+ | `nk_rmsd_f64_serial` | 279 mp/s, 0.5 ulp | 267 mp/s, 0.5 ulp | 279 mp/s, 0.5 ulp |
193
179
  | `nk_kabsch_f64_serial` | 40.4 mp/s, 1.4 ulp | 47.3 mp/s, 2.6 ulp | 50.2 mp/s, 5.4 ulp |
194
180
  | `nk_umeyama_f64_serial` | 34.5 mp/s, 1.0 ulp | 39.2 mp/s, 1.9 ulp | 41.6 mp/s, 3.7 ulp |
195
- | `nk_rmsd_f64_neon` | 1,418 mp/s, 0.4 ulp | 1,338 mp/s, 0.7 ulp | 1,419 mp/s, 1.3 ulp |
181
+ | `nk_rmsd_f64_neon` | 1,776 mp/s, 0.4 ulp | 1,536 mp/s, 0.7 ulp | 2,037 mp/s, 1.3 ulp |
196
182
  | `nk_kabsch_f64_neon` | 119 mp/s, 0.8 ulp | 222 mp/s, 1.3 ulp | 304 mp/s, 2.2 ulp |
197
183
  | `nk_umeyama_f64_neon` | 115 mp/s, 0.4 ulp | 220 mp/s, 0.8 ulp | 296 mp/s, 1.6 ulp |
198
184
  | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
199
- | `nk_rmsd_f32_serial` | 122 mp/s, 1.4 ulp | 123 mp/s, 2.6 ulp | 125 mp/s, 5.2 ulp |
185
+ | `nk_rmsd_f32_serial` | 264 mp/s, 0.5 ulp | 264 mp/s, 0.5 ulp | 261 mp/s, 0.5 ulp |
200
186
  | `nk_kabsch_f32_serial` | 39.4 mp/s, 1.4 ulp | 46.0 mp/s, 2.7 ulp | 49.9 mp/s, 5.0 ulp |
201
187
  | `nk_umeyama_f32_serial` | 33.6 mp/s, 0.9 ulp | 38.8 mp/s, 1.8 ulp | 41.4 mp/s, 3.5 ulp |
202
- | `nk_rmsd_f32_neon` | 1,337 mp/s, 0.3 ulp | 1,377 mp/s, 0.4 ulp | 1,261 mp/s, 0.8 ulp |
188
+ | `nk_rmsd_f32_neon` | 1,912 mp/s, 1.5 ulp | 2,239 mp/s, 1.3 ulp | 1,966 mp/s, 4.8 ulp |
203
189
  | `nk_kabsch_f32_neon` | 135 mp/s, 0.7 ulp | 288 mp/s, 0.9 ulp | 385 mp/s, 1.4 ulp |
204
190
  | `nk_umeyama_f32_neon` | 130 mp/s, 0.3 ulp | 272 mp/s, 0.4 ulp | 367 mp/s, 0.8 ulp |
205
191
  | __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
206
- | `nk_rmsd_bf16_neonbfdot` | 2,342 mp/s, 0.5 ulp | 2,378 mp/s, 6.0 ulp | 2,416 mp/s, 10.0 ulp |
192
+ | `nk_rmsd_bf16_neonbfdot` | 3,728 mp/s, 0.4 ulp | 3,756 mp/s, 6.0 ulp | 3,769 mp/s, 10.0 ulp |
207
193
  | `nk_kabsch_bf16_neonbfdot` | 180 mp/s, 0.7 ulp | 448 mp/s, 0.9 ulp | 726 mp/s, 1.3 ulp |
208
194
  | `nk_umeyama_bf16_neonbfdot` | 176 mp/s, 0.2 ulp | 433 mp/s, 0.4 ulp | 705 mp/s, 0.8 ulp |
209
195
  | __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
210
- | `nk_rmsd_f16_neonhalf` | 2,315 mp/s, 0.4 ulp | 2,372 mp/s, 1.7 ulp | 2,423 mp/s, 4.6 ulp |
196
+ | `nk_rmsd_f16_neonhalf` | 2,998 mp/s, 0.4 ulp | 3,215 mp/s, 1.7 ulp | 3,216 mp/s, 4.6 ulp |
211
197
  | `nk_kabsch_f16_neonhalf` | 178 mp/s, 0.9 ulp | 443 mp/s, 1.3 ulp | 711 mp/s, 2.4 ulp |
212
198
  | `nk_umeyama_f16_neonhalf` | 175 mp/s, 0.4 ulp | 408 mp/s, 0.8 ulp | 620 mp/s, 1.5 ulp |
213
199
 
@@ -309,14 +309,13 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_haswell_(nk_f64_t const *a, nk_f64_t
309
309
 
310
310
  NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
311
311
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
312
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
313
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
312
314
  if (rotation)
313
315
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
314
316
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
315
317
  if (scale) *scale = 1.0f;
316
318
 
317
- __m256d sum_a_x_f64x4 = _mm256_setzero_pd(), sum_a_y_f64x4 = _mm256_setzero_pd();
318
- __m256d sum_a_z_f64x4 = _mm256_setzero_pd(), sum_b_x_f64x4 = _mm256_setzero_pd();
319
- __m256d sum_b_y_f64x4 = _mm256_setzero_pd(), sum_b_z_f64x4 = _mm256_setzero_pd();
320
319
  __m256d sum_squared_f64x4 = _mm256_setzero_pd();
321
320
  __m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
322
321
  nk_size_t index = 0;
@@ -325,32 +324,19 @@ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size
325
324
  nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
326
325
  nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
327
326
 
328
- __m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
329
- __m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
330
- __m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
331
- __m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
332
- __m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
333
- __m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
334
- __m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
335
- __m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
336
- __m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
337
- __m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
338
- __m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
339
- __m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
340
-
341
- sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_low_f64x4, a_x_high_f64x4));
342
- sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_low_f64x4, a_y_high_f64x4));
343
- sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_low_f64x4, a_z_high_f64x4));
344
- sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_low_f64x4, b_x_high_f64x4));
345
- sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_low_f64x4, b_y_high_f64x4));
346
- sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_low_f64x4, b_z_high_f64x4));
327
+ __m256d delta_x_low_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8)),
328
+ _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8)));
329
+ __m256d delta_x_high_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1)),
330
+ _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1)));
331
+ __m256d delta_y_low_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8)),
332
+ _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8)));
333
+ __m256d delta_y_high_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1)),
334
+ _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1)));
335
+ __m256d delta_z_low_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8)),
336
+ _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8)));
337
+ __m256d delta_z_high_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1)),
338
+ _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1)));
347
339
 
348
- __m256d delta_x_low_f64x4 = _mm256_sub_pd(a_x_low_f64x4, b_x_low_f64x4);
349
- __m256d delta_x_high_f64x4 = _mm256_sub_pd(a_x_high_f64x4, b_x_high_f64x4);
350
- __m256d delta_y_low_f64x4 = _mm256_sub_pd(a_y_low_f64x4, b_y_low_f64x4);
351
- __m256d delta_y_high_f64x4 = _mm256_sub_pd(a_y_high_f64x4, b_y_high_f64x4);
352
- __m256d delta_z_low_f64x4 = _mm256_sub_pd(a_z_low_f64x4, b_z_low_f64x4);
353
- __m256d delta_z_high_f64x4 = _mm256_sub_pd(a_z_high_f64x4, b_z_high_f64x4);
354
340
  __m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_low_f64x4, delta_x_low_f64x4),
355
341
  _mm256_mul_pd(delta_x_high_f64x4, delta_x_high_f64x4));
356
342
  batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_low_f64x4, delta_y_low_f64x4, batch_sum_squared_f64x4);
@@ -360,70 +346,38 @@ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size
360
346
  sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
361
347
  }
362
348
 
363
- nk_f64_t total_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
364
- nk_f64_t total_a_y = nk_reduce_add_f64x4_haswell_(sum_a_y_f64x4);
365
- nk_f64_t total_a_z = nk_reduce_add_f64x4_haswell_(sum_a_z_f64x4);
366
- nk_f64_t total_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
367
- nk_f64_t total_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
368
- nk_f64_t total_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
369
349
  nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
370
350
 
371
351
  for (; index < n; ++index) {
372
- nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
373
- nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
374
- total_a_x += a_x, total_a_y += a_y, total_a_z += a_z;
375
- total_b_x += b_x, total_b_y += b_y, total_b_z += b_z;
376
- nk_f64_t delta_x = a_x - b_x, delta_y = a_y - b_y, delta_z = a_z - b_z;
352
+ nk_f64_t delta_x = (nk_f64_t)a[index * 3 + 0] - (nk_f64_t)b[index * 3 + 0];
353
+ nk_f64_t delta_y = (nk_f64_t)a[index * 3 + 1] - (nk_f64_t)b[index * 3 + 1];
354
+ nk_f64_t delta_z = (nk_f64_t)a[index * 3 + 2] - (nk_f64_t)b[index * 3 + 2];
377
355
  sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
378
356
  }
379
357
 
380
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
381
- nk_f64_t centroid_a_x = total_a_x * inv_n, centroid_a_y = total_a_y * inv_n, centroid_a_z = total_a_z * inv_n;
382
- nk_f64_t centroid_b_x = total_b_x * inv_n, centroid_b_y = total_b_y * inv_n, centroid_b_z = total_b_z * inv_n;
383
- if (a_centroid)
384
- a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
385
- a_centroid[2] = (nk_f32_t)centroid_a_z;
386
- if (b_centroid)
387
- b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
388
- b_centroid[2] = (nk_f32_t)centroid_b_z;
389
-
390
- nk_f64_t mean_delta_x = centroid_a_x - centroid_b_x, mean_delta_y = centroid_a_y - centroid_b_y,
391
- mean_delta_z = centroid_a_z - centroid_b_z;
392
- nk_f64_t mean_delta_squared = mean_delta_x * mean_delta_x + mean_delta_y * mean_delta_y +
393
- mean_delta_z * mean_delta_z;
394
- *result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_delta_squared);
358
+ *result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
395
359
  }
396
360
 
397
361
  NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
398
362
  nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
399
- // RMSD uses identity rotation and scale=1.0
363
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
364
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
400
365
  if (rotation)
401
366
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
402
367
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
403
368
  if (scale) *scale = 1.0;
404
369
  __m256d const zeros_f64x4 = _mm256_setzero_pd();
405
370
 
406
- // Accumulators for centroids and squared differences
407
- __m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
408
- __m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
409
371
  __m256d sum_squared_x_f64x4 = zeros_f64x4, sum_squared_y_f64x4 = zeros_f64x4, sum_squared_z_f64x4 = zeros_f64x4;
410
372
 
411
373
  __m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
412
374
  nk_size_t i = 0;
413
375
 
414
- // Main loop with 2x unrolling
376
+ // Main loop with unrolling
415
377
  for (; i + 8 <= n; i += 8) {
416
- // Iteration 0
417
378
  nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
418
379
  nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
419
380
 
420
- sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
421
- sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
422
- sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
423
- sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
424
- sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
425
- sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
426
-
427
381
  __m256d delta_x_f64x4 = _mm256_sub_pd(a_x_f64x4, b_x_f64x4);
428
382
  __m256d delta_y_f64x4 = _mm256_sub_pd(a_y_f64x4, b_y_f64x4);
429
383
  __m256d delta_z_f64x4 = _mm256_sub_pd(a_z_f64x4, b_z_f64x4);
@@ -432,18 +386,10 @@ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size
432
386
  sum_squared_y_f64x4 = _mm256_fmadd_pd(delta_y_f64x4, delta_y_f64x4, sum_squared_y_f64x4);
433
387
  sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z_f64x4, delta_z_f64x4, sum_squared_z_f64x4);
434
388
 
435
- // Iteration 1
436
389
  __m256d a_x1_f64x4, a_y1_f64x4, a_z1_f64x4, b_x1_f64x4, b_y1_f64x4, b_z1_f64x4;
437
390
  nk_deinterleave_f64x4_haswell_(a + (i + 4) * 3, &a_x1_f64x4, &a_y1_f64x4, &a_z1_f64x4);
438
391
  nk_deinterleave_f64x4_haswell_(b + (i + 4) * 3, &b_x1_f64x4, &b_y1_f64x4, &b_z1_f64x4);
439
392
 
440
- sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x1_f64x4);
441
- sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y1_f64x4);
442
- sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z1_f64x4);
443
- sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x1_f64x4);
444
- sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y1_f64x4);
445
- sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z1_f64x4);
446
-
447
393
  __m256d delta_x1_f64x4 = _mm256_sub_pd(a_x1_f64x4, b_x1_f64x4);
448
394
  __m256d delta_y1_f64x4 = _mm256_sub_pd(a_y1_f64x4, b_y1_f64x4);
449
395
  __m256d delta_z1_f64x4 = _mm256_sub_pd(a_z1_f64x4, b_z1_f64x4);
@@ -453,18 +399,10 @@ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size
453
399
  sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z1_f64x4, delta_z1_f64x4, sum_squared_z_f64x4);
454
400
  }
455
401
 
456
- // Handle 4-point remainder
457
402
  for (; i + 4 <= n; i += 4) {
458
403
  nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
459
404
  nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
460
405
 
461
- sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
462
- sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
463
- sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
464
- sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
465
- sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
466
- sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
467
-
468
406
  __m256d delta_x_f64x4 = _mm256_sub_pd(a_x_f64x4, b_x_f64x4);
469
407
  __m256d delta_y_f64x4 = _mm256_sub_pd(a_y_f64x4, b_y_f64x4);
470
408
  __m256d delta_z_f64x4 = _mm256_sub_pd(a_z_f64x4, b_z_f64x4);
@@ -474,57 +412,22 @@ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size
474
412
  sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z_f64x4, delta_z_f64x4, sum_squared_z_f64x4);
475
413
  }
476
414
 
477
- // Reduce vectors to scalars
478
- nk_f64_t total_ax = nk_reduce_stable_f64x4_haswell_(sum_a_x_f64x4), total_ax_compensation = 0.0;
479
- nk_f64_t total_ay = nk_reduce_stable_f64x4_haswell_(sum_a_y_f64x4), total_ay_compensation = 0.0;
480
- nk_f64_t total_az = nk_reduce_stable_f64x4_haswell_(sum_a_z_f64x4), total_az_compensation = 0.0;
481
- nk_f64_t total_bx = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), total_bx_compensation = 0.0;
482
- nk_f64_t total_by = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), total_by_compensation = 0.0;
483
- nk_f64_t total_bz = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), total_bz_compensation = 0.0;
484
415
  nk_f64_t total_sq_x = nk_reduce_stable_f64x4_haswell_(sum_squared_x_f64x4), total_sq_x_compensation = 0.0;
485
416
  nk_f64_t total_sq_y = nk_reduce_stable_f64x4_haswell_(sum_squared_y_f64x4), total_sq_y_compensation = 0.0;
486
417
  nk_f64_t total_sq_z = nk_reduce_stable_f64x4_haswell_(sum_squared_z_f64x4), total_sq_z_compensation = 0.0;
487
418
 
488
- // Scalar tail
489
419
  for (; i < n; ++i) {
490
- nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
491
- nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
492
- nk_accumulate_sum_f64_(&total_ax, &total_ax_compensation, ax);
493
- nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
494
- nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
495
- nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
496
- nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
497
- nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
498
- nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
420
+ nk_f64_t delta_x = a[i * 3 + 0] - b[i * 3 + 0];
421
+ nk_f64_t delta_y = a[i * 3 + 1] - b[i * 3 + 1];
422
+ nk_f64_t delta_z = a[i * 3 + 2] - b[i * 3 + 2];
499
423
  nk_accumulate_square_f64_(&total_sq_x, &total_sq_x_compensation, delta_x);
500
424
  nk_accumulate_square_f64_(&total_sq_y, &total_sq_y_compensation, delta_y);
501
425
  nk_accumulate_square_f64_(&total_sq_z, &total_sq_z_compensation, delta_z);
502
426
  }
503
427
 
504
- total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
505
- total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
506
428
  total_sq_x += total_sq_x_compensation, total_sq_y += total_sq_y_compensation, total_sq_z += total_sq_z_compensation;
507
429
 
508
- // Compute centroids
509
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
510
- nk_f64_t centroid_a_x = total_ax * inv_n;
511
- nk_f64_t centroid_a_y = total_ay * inv_n;
512
- nk_f64_t centroid_a_z = total_az * inv_n;
513
- nk_f64_t centroid_b_x = total_bx * inv_n;
514
- nk_f64_t centroid_b_y = total_by * inv_n;
515
- nk_f64_t centroid_b_z = total_bz * inv_n;
516
-
517
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
518
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
519
-
520
- // Compute RMSD
521
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
522
- nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
523
- nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
524
- nk_f64_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
525
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
526
-
527
- *result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
430
+ *result = nk_f64_sqrt_haswell((total_sq_x + total_sq_y + total_sq_z) / (nk_f64_t)n);
528
431
  }
529
432
 
530
433
  NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
@@ -353,18 +353,14 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_neon_(nk_f64_t const *a, nk_f64_t co
353
353
 
354
354
  NK_PUBLIC void nk_rmsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
355
355
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
356
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
357
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
356
358
  if (rotation)
357
359
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
358
360
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
359
361
  if (scale) *scale = 1.0f;
360
362
 
361
363
  float64x2_t zero_f64x2 = vdupq_n_f64(0.0);
362
- float64x2_t sum_a_x_low_f64x2 = zero_f64x2, sum_a_x_high_f64x2 = zero_f64x2;
363
- float64x2_t sum_a_y_low_f64x2 = zero_f64x2, sum_a_y_high_f64x2 = zero_f64x2;
364
- float64x2_t sum_a_z_low_f64x2 = zero_f64x2, sum_a_z_high_f64x2 = zero_f64x2;
365
- float64x2_t sum_b_x_low_f64x2 = zero_f64x2, sum_b_x_high_f64x2 = zero_f64x2;
366
- float64x2_t sum_b_y_low_f64x2 = zero_f64x2, sum_b_y_high_f64x2 = zero_f64x2;
367
- float64x2_t sum_b_z_low_f64x2 = zero_f64x2, sum_b_z_high_f64x2 = zero_f64x2;
368
364
  float64x2_t sum_squared_x_low_f64x2 = zero_f64x2, sum_squared_x_high_f64x2 = zero_f64x2;
369
365
  float64x2_t sum_squared_y_low_f64x2 = zero_f64x2, sum_squared_y_high_f64x2 = zero_f64x2;
370
366
  float64x2_t sum_squared_z_low_f64x2 = zero_f64x2, sum_squared_z_high_f64x2 = zero_f64x2;
@@ -375,38 +371,15 @@ NK_PUBLIC void nk_rmsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t
375
371
  nk_deinterleave_f32x4_neon_(a + index * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4),
376
372
  nk_deinterleave_f32x4_neon_(b + index * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
377
373
 
378
- float64x2_t a_x_low_f64x2 = vcvt_f64_f32(vget_low_f32(a_x_f32x4));
379
- float64x2_t a_x_high_f64x2 = vcvt_high_f64_f32(a_x_f32x4);
380
- float64x2_t a_y_low_f64x2 = vcvt_f64_f32(vget_low_f32(a_y_f32x4));
381
- float64x2_t a_y_high_f64x2 = vcvt_high_f64_f32(a_y_f32x4);
382
- float64x2_t a_z_low_f64x2 = vcvt_f64_f32(vget_low_f32(a_z_f32x4));
383
- float64x2_t a_z_high_f64x2 = vcvt_high_f64_f32(a_z_f32x4);
384
- float64x2_t b_x_low_f64x2 = vcvt_f64_f32(vget_low_f32(b_x_f32x4));
385
- float64x2_t b_x_high_f64x2 = vcvt_high_f64_f32(b_x_f32x4);
386
- float64x2_t b_y_low_f64x2 = vcvt_f64_f32(vget_low_f32(b_y_f32x4));
387
- float64x2_t b_y_high_f64x2 = vcvt_high_f64_f32(b_y_f32x4);
388
- float64x2_t b_z_low_f64x2 = vcvt_f64_f32(vget_low_f32(b_z_f32x4));
389
- float64x2_t b_z_high_f64x2 = vcvt_high_f64_f32(b_z_f32x4);
390
-
391
- sum_a_x_low_f64x2 = vaddq_f64(sum_a_x_low_f64x2, a_x_low_f64x2),
392
- sum_a_x_high_f64x2 = vaddq_f64(sum_a_x_high_f64x2, a_x_high_f64x2);
393
- sum_a_y_low_f64x2 = vaddq_f64(sum_a_y_low_f64x2, a_y_low_f64x2),
394
- sum_a_y_high_f64x2 = vaddq_f64(sum_a_y_high_f64x2, a_y_high_f64x2);
395
- sum_a_z_low_f64x2 = vaddq_f64(sum_a_z_low_f64x2, a_z_low_f64x2),
396
- sum_a_z_high_f64x2 = vaddq_f64(sum_a_z_high_f64x2, a_z_high_f64x2);
397
- sum_b_x_low_f64x2 = vaddq_f64(sum_b_x_low_f64x2, b_x_low_f64x2),
398
- sum_b_x_high_f64x2 = vaddq_f64(sum_b_x_high_f64x2, b_x_high_f64x2);
399
- sum_b_y_low_f64x2 = vaddq_f64(sum_b_y_low_f64x2, b_y_low_f64x2),
400
- sum_b_y_high_f64x2 = vaddq_f64(sum_b_y_high_f64x2, b_y_high_f64x2);
401
- sum_b_z_low_f64x2 = vaddq_f64(sum_b_z_low_f64x2, b_z_low_f64x2),
402
- sum_b_z_high_f64x2 = vaddq_f64(sum_b_z_high_f64x2, b_z_high_f64x2);
403
-
404
- float64x2_t delta_x_low_f64x2 = vsubq_f64(a_x_low_f64x2, b_x_low_f64x2);
405
- float64x2_t delta_x_high_f64x2 = vsubq_f64(a_x_high_f64x2, b_x_high_f64x2);
406
- float64x2_t delta_y_low_f64x2 = vsubq_f64(a_y_low_f64x2, b_y_low_f64x2);
407
- float64x2_t delta_y_high_f64x2 = vsubq_f64(a_y_high_f64x2, b_y_high_f64x2);
408
- float64x2_t delta_z_low_f64x2 = vsubq_f64(a_z_low_f64x2, b_z_low_f64x2);
409
- float64x2_t delta_z_high_f64x2 = vsubq_f64(a_z_high_f64x2, b_z_high_f64x2);
374
+ float64x2_t delta_x_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_x_f32x4)),
375
+ vcvt_f64_f32(vget_low_f32(b_x_f32x4)));
376
+ float64x2_t delta_x_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_x_f32x4), vcvt_high_f64_f32(b_x_f32x4));
377
+ float64x2_t delta_y_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_y_f32x4)),
378
+ vcvt_f64_f32(vget_low_f32(b_y_f32x4)));
379
+ float64x2_t delta_y_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_y_f32x4), vcvt_high_f64_f32(b_y_f32x4));
380
+ float64x2_t delta_z_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_z_f32x4)),
381
+ vcvt_f64_f32(vget_low_f32(b_z_f32x4)));
382
+ float64x2_t delta_z_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_z_f32x4), vcvt_high_f64_f32(b_z_f32x4));
410
383
 
411
384
  sum_squared_x_low_f64x2 = vfmaq_f64(sum_squared_x_low_f64x2, delta_x_low_f64x2, delta_x_low_f64x2),
412
385
  sum_squared_x_high_f64x2 = vfmaq_f64(sum_squared_x_high_f64x2, delta_x_high_f64x2, delta_x_high_f64x2);
@@ -416,71 +389,39 @@ NK_PUBLIC void nk_rmsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t
416
389
  sum_squared_z_high_f64x2 = vfmaq_f64(sum_squared_z_high_f64x2, delta_z_high_f64x2, delta_z_high_f64x2);
417
390
  }
418
391
 
419
- nk_f64_t sum_a_x = vaddvq_f64(vaddq_f64(sum_a_x_low_f64x2, sum_a_x_high_f64x2));
420
- nk_f64_t sum_a_y = vaddvq_f64(vaddq_f64(sum_a_y_low_f64x2, sum_a_y_high_f64x2));
421
- nk_f64_t sum_a_z = vaddvq_f64(vaddq_f64(sum_a_z_low_f64x2, sum_a_z_high_f64x2));
422
- nk_f64_t sum_b_x = vaddvq_f64(vaddq_f64(sum_b_x_low_f64x2, sum_b_x_high_f64x2));
423
- nk_f64_t sum_b_y = vaddvq_f64(vaddq_f64(sum_b_y_low_f64x2, sum_b_y_high_f64x2));
424
- nk_f64_t sum_b_z = vaddvq_f64(vaddq_f64(sum_b_z_low_f64x2, sum_b_z_high_f64x2));
425
392
  nk_f64_t sum_squared_x = vaddvq_f64(vaddq_f64(sum_squared_x_low_f64x2, sum_squared_x_high_f64x2));
426
393
  nk_f64_t sum_squared_y = vaddvq_f64(vaddq_f64(sum_squared_y_low_f64x2, sum_squared_y_high_f64x2));
427
394
  nk_f64_t sum_squared_z = vaddvq_f64(vaddq_f64(sum_squared_z_low_f64x2, sum_squared_z_high_f64x2));
428
395
 
429
396
  for (; index < n; ++index) {
430
- nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
431
- nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
432
- sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
433
- sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
434
- nk_f64_t delta_x = a_x - b_x, delta_y = a_y - b_y, delta_z = a_z - b_z;
397
+ nk_f64_t delta_x = (nk_f64_t)a[index * 3 + 0] - (nk_f64_t)b[index * 3 + 0];
398
+ nk_f64_t delta_y = (nk_f64_t)a[index * 3 + 1] - (nk_f64_t)b[index * 3 + 1];
399
+ nk_f64_t delta_z = (nk_f64_t)a[index * 3 + 2] - (nk_f64_t)b[index * 3 + 2];
435
400
  sum_squared_x += delta_x * delta_x, sum_squared_y += delta_y * delta_y, sum_squared_z += delta_z * delta_z;
436
401
  }
437
402
 
438
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
439
- nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
440
- nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
441
- if (a_centroid)
442
- a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
443
- a_centroid[2] = (nk_f32_t)centroid_a_z;
444
- if (b_centroid)
445
- b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
446
- b_centroid[2] = (nk_f32_t)centroid_b_z;
447
-
448
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x, mean_diff_y = centroid_a_y - centroid_b_y,
449
- mean_diff_z = centroid_a_z - centroid_b_z;
450
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
451
- *result = nk_f64_sqrt_neon((sum_squared_x + sum_squared_y + sum_squared_z) * inv_n - mean_diff_sq);
403
+ *result = nk_f64_sqrt_neon((sum_squared_x + sum_squared_y + sum_squared_z) / (nk_f64_t)n);
452
404
  }
453
405
 
454
406
  NK_PUBLIC void nk_rmsd_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
455
407
  nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
456
- // RMSD uses identity rotation and scale=1.0.
408
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
409
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
457
410
  if (rotation)
458
411
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
459
412
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
460
413
  if (scale) *scale = 1.0;
461
414
 
462
415
  float64x2_t const zeros_f64x2 = vdupq_n_f64(0);
463
-
464
- // Accumulators for centroids and squared differences
465
- float64x2_t sum_a_x_f64x2 = zeros_f64x2, sum_a_y_f64x2 = zeros_f64x2, sum_a_z_f64x2 = zeros_f64x2;
466
- float64x2_t sum_b_x_f64x2 = zeros_f64x2, sum_b_y_f64x2 = zeros_f64x2, sum_b_z_f64x2 = zeros_f64x2;
467
416
  float64x2_t sum_squared_x_f64x2 = zeros_f64x2, sum_squared_y_f64x2 = zeros_f64x2, sum_squared_z_f64x2 = zeros_f64x2;
468
417
 
469
418
  float64x2_t a_x_f64x2, a_y_f64x2, a_z_f64x2, b_x_f64x2, b_y_f64x2, b_z_f64x2;
470
419
  nk_size_t i = 0;
471
420
 
472
- // Main loop processing 2 points at a time
473
421
  for (; i + 2 <= n; i += 2) {
474
422
  nk_deinterleave_f64x2_neon_(a + i * 3, &a_x_f64x2, &a_y_f64x2, &a_z_f64x2);
475
423
  nk_deinterleave_f64x2_neon_(b + i * 3, &b_x_f64x2, &b_y_f64x2, &b_z_f64x2);
476
424
 
477
- sum_a_x_f64x2 = vaddq_f64(sum_a_x_f64x2, a_x_f64x2);
478
- sum_a_y_f64x2 = vaddq_f64(sum_a_y_f64x2, a_y_f64x2);
479
- sum_a_z_f64x2 = vaddq_f64(sum_a_z_f64x2, a_z_f64x2);
480
- sum_b_x_f64x2 = vaddq_f64(sum_b_x_f64x2, b_x_f64x2);
481
- sum_b_y_f64x2 = vaddq_f64(sum_b_y_f64x2, b_y_f64x2);
482
- sum_b_z_f64x2 = vaddq_f64(sum_b_z_f64x2, b_z_f64x2);
483
-
484
425
  float64x2_t delta_x_f64x2 = vsubq_f64(a_x_f64x2, b_x_f64x2);
485
426
  float64x2_t delta_y_f64x2 = vsubq_f64(a_y_f64x2, b_y_f64x2);
486
427
  float64x2_t delta_z_f64x2 = vsubq_f64(a_z_f64x2, b_z_f64x2);
@@ -490,53 +431,23 @@ NK_PUBLIC void nk_rmsd_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t
490
431
  sum_squared_z_f64x2 = vfmaq_f64(sum_squared_z_f64x2, delta_z_f64x2, delta_z_f64x2);
491
432
  }
492
433
 
493
- // Reduce vectors to scalars.
494
- nk_f64_t total_ax = nk_reduce_stable_f64x2_neon_(sum_a_x_f64x2), total_ax_compensation = 0.0;
495
- nk_f64_t total_ay = nk_reduce_stable_f64x2_neon_(sum_a_y_f64x2), total_ay_compensation = 0.0;
496
- nk_f64_t total_az = nk_reduce_stable_f64x2_neon_(sum_a_z_f64x2), total_az_compensation = 0.0;
497
- nk_f64_t total_bx = nk_reduce_stable_f64x2_neon_(sum_b_x_f64x2), total_bx_compensation = 0.0;
498
- nk_f64_t total_by = nk_reduce_stable_f64x2_neon_(sum_b_y_f64x2), total_by_compensation = 0.0;
499
- nk_f64_t total_bz = nk_reduce_stable_f64x2_neon_(sum_b_z_f64x2), total_bz_compensation = 0.0;
500
434
  nk_f64_t total_squared_x = nk_reduce_stable_f64x2_neon_(sum_squared_x_f64x2), total_squared_x_compensation = 0.0;
501
435
  nk_f64_t total_squared_y = nk_reduce_stable_f64x2_neon_(sum_squared_y_f64x2), total_squared_y_compensation = 0.0;
502
436
  nk_f64_t total_squared_z = nk_reduce_stable_f64x2_neon_(sum_squared_z_f64x2), total_squared_z_compensation = 0.0;
503
437
 
504
- // Scalar tail
505
438
  for (; i < n; ++i) {
506
- nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
507
- nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
508
- nk_accumulate_sum_f64_(&total_ax, &total_ax_compensation, ax);
509
- nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
510
- nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
511
- nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
512
- nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
513
- nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
514
- nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
439
+ nk_f64_t delta_x = a[i * 3 + 0] - b[i * 3 + 0];
440
+ nk_f64_t delta_y = a[i * 3 + 1] - b[i * 3 + 1];
441
+ nk_f64_t delta_z = a[i * 3 + 2] - b[i * 3 + 2];
515
442
  nk_accumulate_square_f64_(&total_squared_x, &total_squared_x_compensation, delta_x);
516
443
  nk_accumulate_square_f64_(&total_squared_y, &total_squared_y_compensation, delta_y);
517
444
  nk_accumulate_square_f64_(&total_squared_z, &total_squared_z_compensation, delta_z);
518
445
  }
519
446
 
520
- total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
521
- total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
522
447
  total_squared_x += total_squared_x_compensation, total_squared_y += total_squared_y_compensation,
523
448
  total_squared_z += total_squared_z_compensation;
524
449
 
525
- // Compute centroids
526
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
527
- nk_f64_t centroid_a_x = total_ax * inv_n, centroid_a_y = total_ay * inv_n, centroid_a_z = total_az * inv_n;
528
- nk_f64_t centroid_b_x = total_bx * inv_n, centroid_b_y = total_by * inv_n, centroid_b_z = total_bz * inv_n;
529
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
530
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
531
-
532
- // Compute RMSD
533
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
534
- nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
535
- nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
536
- nk_f64_t sum_squared = total_squared_x + total_squared_y + total_squared_z;
537
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
538
-
539
- *result = nk_f64_sqrt_neon(sum_squared * inv_n - mean_diff_sq);
450
+ *result = nk_f64_sqrt_neon((total_squared_x + total_squared_y + total_squared_z) / (nk_f64_t)n);
540
451
  }
541
452
 
542
453
  NK_PUBLIC void nk_kabsch_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
@@ -267,12 +267,12 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
267
267
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
268
268
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
269
269
  if (scale) *scale = 1.0f;
270
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
271
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
270
272
 
271
273
  float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
272
274
 
273
- // Accumulators for centroids and squared differences
274
- float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
275
- float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
275
+ // Accumulators for squared differences
276
276
  float32x4_t sum_squared_x_f32x4 = zeros_f32x4, sum_squared_y_f32x4 = zeros_f32x4, sum_squared_z_f32x4 = zeros_f32x4;
277
277
 
278
278
  float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
@@ -283,13 +283,6 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
283
283
  nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
284
284
  nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
285
285
 
286
- sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
287
- sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
288
- sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
289
- sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
290
- sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
291
- sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
292
-
293
286
  float32x4_t delta_x_f32x4 = vsubq_f32(a_x_f32x4, b_x_f32x4);
294
287
  float32x4_t delta_y_f32x4 = vsubq_f32(a_y_f32x4, b_y_f32x4);
295
288
  float32x4_t delta_z_f32x4 = vsubq_f32(a_z_f32x4, b_z_f32x4);
@@ -305,13 +298,6 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
305
298
  nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
306
299
  nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
307
300
 
308
- sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
309
- sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
310
- sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
311
- sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
312
- sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
313
- sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
314
-
315
301
  float32x4_t delta_x_f32x4 = vsubq_f32(a_x_f32x4, b_x_f32x4);
316
302
  float32x4_t delta_y_f32x4 = vsubq_f32(a_y_f32x4, b_y_f32x4);
317
303
  float32x4_t delta_z_f32x4 = vsubq_f32(a_z_f32x4, b_z_f32x4);
@@ -322,36 +308,11 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
322
308
  }
323
309
 
324
310
  // Reduce vectors to scalars
325
- nk_f32_t total_ax = vaddvq_f32(sum_a_x_f32x4);
326
- nk_f32_t total_ay = vaddvq_f32(sum_a_y_f32x4);
327
- nk_f32_t total_az = vaddvq_f32(sum_a_z_f32x4);
328
- nk_f32_t total_bx = vaddvq_f32(sum_b_x_f32x4);
329
- nk_f32_t total_by = vaddvq_f32(sum_b_y_f32x4);
330
- nk_f32_t total_bz = vaddvq_f32(sum_b_z_f32x4);
331
311
  nk_f32_t total_squared_x = vaddvq_f32(sum_squared_x_f32x4);
332
312
  nk_f32_t total_squared_y = vaddvq_f32(sum_squared_y_f32x4);
333
313
  nk_f32_t total_squared_z = vaddvq_f32(sum_squared_z_f32x4);
334
314
 
335
- // Compute centroids
336
- nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
337
- nk_f32_t centroid_a_x = total_ax * inv_n;
338
- nk_f32_t centroid_a_y = total_ay * inv_n;
339
- nk_f32_t centroid_a_z = total_az * inv_n;
340
- nk_f32_t centroid_b_x = total_bx * inv_n;
341
- nk_f32_t centroid_b_y = total_by * inv_n;
342
- nk_f32_t centroid_b_z = total_bz * inv_n;
343
-
344
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
345
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
346
-
347
- // Compute RMSD
348
- nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
349
- nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
350
- nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
351
- nk_f32_t sum_squared = total_squared_x + total_squared_y + total_squared_z;
352
- nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
353
-
354
- *result = nk_f32_sqrt_neon(sum_squared * inv_n - mean_diff_sq);
315
+ *result = nk_f32_sqrt_neon((total_squared_x + total_squared_y + total_squared_z) / (nk_f32_t)n);
355
316
  }
356
317
 
357
318
  NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,