numkong 7.4.3 → 7.4.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +49 -49
- package/binding.gyp +3 -0
- package/include/numkong/capabilities.h +1 -1
- package/include/numkong/each/haswell.h +4 -4
- package/include/numkong/maxsim/sme.h +65 -27
- package/include/numkong/mesh/README.md +13 -27
- package/include/numkong/mesh/haswell.h +25 -122
- package/include/numkong/mesh/neon.h +21 -110
- package/include/numkong/mesh/neonbfdot.h +4 -43
- package/include/numkong/mesh/rvv.h +7 -82
- package/include/numkong/mesh/serial.h +26 -53
- package/include/numkong/mesh/skylake.h +7 -123
- package/include/numkong/mesh/v128relaxed.h +9 -93
- package/include/numkong/mesh.h +2 -2
- package/include/numkong/mesh.hpp +35 -96
- package/include/numkong/types.h +15 -9
- package/numkong.gypi +3 -0
- package/package.json +7 -7
- package/wasm/numkong.wasm +0 -0
package/README.md
CHANGED
|
@@ -391,24 +391,24 @@ Float16 prioritizes __precision over range__ (10 vs 7 mantissa bits), making it
|
|
|
391
391
|
On x86, older CPUs use __F16C extensions__ (Ivy Bridge+) for fast Float16 → Float32 conversion; Sapphire Rapids+ adds native __AVX-512-FP16__ with dedicated Float16 arithmetic.
|
|
392
392
|
On Arm, ARMv8.4-A adds __FMLAL/FMLAL2__ instructions for fused Float16 → Float32 widening multiply-accumulate, reducing the total latency from 7 cycles to 4 cycles and achieving 20–48% speedup over the separate convert-then-FMA path.
|
|
393
393
|
|
|
394
|
-
| Platform
|
|
395
|
-
|
|
|
396
|
-
| __x86__
|
|
397
|
-
| Diamond
|
|
398
|
-
| Sapphire
|
|
399
|
-
| Genoa
|
|
400
|
-
| Skylake
|
|
401
|
-
| Haswell
|
|
402
|
-
| __Arm__
|
|
403
|
-
| Apple M2
|
|
404
|
-
| Graviton 3
|
|
405
|
-
| Apple M1
|
|
406
|
-
| Graviton 2
|
|
407
|
-
| Graviton 1
|
|
408
|
-
| __RISC-V__
|
|
409
|
-
| RVV
|
|
410
|
-
| RVV
|
|
411
|
-
| RVV
|
|
394
|
+
| Platform | BFloat16 Path | Step | Float16 Path | Step |
|
|
395
|
+
| :--------------- | :------------------------- | ---: | :--------------------- | ---: |
|
|
396
|
+
| __x86__ | | | | |
|
|
397
|
+
| Diamond, '26 | ↓ Genoa | 32 | `VDPPHPS` widening dot | 32 |
|
|
398
|
+
| Sapphire, '23 | ↓ Genoa | 32 | ↓ Skylake | 16 |
|
|
399
|
+
| Genoa, '22 | `VDPBF16PS` widening dot | 32 | ↓ Skylake | 16 |
|
|
400
|
+
| Skylake, '15 | `SLLI` + `VFMADD` | 16 | `VCVTPH2PS` + `VFMADD` | 16 |
|
|
401
|
+
| Haswell, '13 | `SLLI` + `VFMADD` | 8 | `VCVTPH2PS` + `VFMADD` | 8 |
|
|
402
|
+
| __Arm__ | | | | |
|
|
403
|
+
| Apple M2+, '22 | `BFDOT` widening dot | 8 | ↓ FP16FML | 8 |
|
|
404
|
+
| Graviton 3+, '21 | `SVBFDOT` widening dot | 4–32 | `SVCVT` → `SVFMLA` | 4–32 |
|
|
405
|
+
| Apple M1, '20 | ↓ NEON | 8 | `FMLAL` widening FMA | 8 |
|
|
406
|
+
| Graviton 2, '19 | ↓ NEON | 8 | `FCVTL` + `FMLA` | 4 |
|
|
407
|
+
| Graviton 1, '18 | `SHLL` + `FMLA` | 8 | bit-manip → `FMLA` | 8 |
|
|
408
|
+
| __RISC-V__ | | | | |
|
|
409
|
+
| RVV+Zvfbfwma | `VFWMACCBF16` widening FMA | 4–32 | ↓ RVV | 4–32 |
|
|
410
|
+
| RVV+Zvfh | ↓ RVV | 4–32 | `VFWMACC` widening FMA | 4–32 |
|
|
411
|
+
| RVV | shift + `VFMACC` | 4–32 | convert + `VFMACC` | 4–32 |
|
|
412
412
|
|
|
413
413
|
> BFloat16 shares Float32's 8-bit exponent, so upcasting is a 16-bit left shift (`SLLI` on x86, `SHLL` on Arm) that zero-pads the truncated mantissa — essentially free.
|
|
414
414
|
> Float16 has a different exponent width (5 vs 8 bits), requiring a dedicated convert: `VCVTPH2PS` (x86 F16C) or `FCVTL` (Arm NEON).
|
|
@@ -444,22 +444,22 @@ E4M3FN (no infinities, NaN only) is preferred for __training__ where precision n
|
|
|
444
444
|
On x86 Genoa/Sapphire Rapids, E4M3/E5M2 values upcast to BFloat16 via lookup tables, then use native __DPBF16PS__ for 2-per-lane dot products accumulating to Float32.
|
|
445
445
|
On Arm Graviton 3+, the same BFloat16 upcast happens via NEON table lookups, then __BFDOT__ instructions complete the computation.
|
|
446
446
|
|
|
447
|
-
| Platform
|
|
448
|
-
|
|
|
449
|
-
| __x86__
|
|
450
|
-
| Diamond
|
|
451
|
-
| Genoa
|
|
452
|
-
| Ice Lake
|
|
453
|
-
| Skylake
|
|
454
|
-
| Haswell
|
|
455
|
-
| __Arm__
|
|
456
|
-
| NEON
|
|
457
|
-
| NEON
|
|
458
|
-
| NEON
|
|
459
|
-
| __RISC-V__
|
|
460
|
-
| RVV
|
|
461
|
-
| RVV
|
|
462
|
-
| RVV
|
|
447
|
+
| Platform | E5M2 Path | Step | E4M3 Path | Step |
|
|
448
|
+
| :---------------- | :----------------------------- | ---: | :----------------------------- | ---: |
|
|
449
|
+
| __x86__ | | | | |
|
|
450
|
+
| Diamond, '26 | `VCVTBF82PH` → F16 + `VDPPHPS` | 32 | `VCVTHF82PH` → F16 + `VDPPHPS` | 32 |
|
|
451
|
+
| Genoa, '22 | → BF16 + `VDPBF16PS` | 32 | ↓ Ice Lake | 64 |
|
|
452
|
+
| Ice Lake, '19 | ↓ Skylake | 16 | octave LUT + `VPDPBUSD` | 64 |
|
|
453
|
+
| Skylake, '15 | rebias → F32 FMA | 16 | rebias → F32 FMA | 16 |
|
|
454
|
+
| Haswell, '13 | rebias → F32 FMA | 8 | rebias → F32 FMA | 8 |
|
|
455
|
+
| __Arm__ | | | | |
|
|
456
|
+
| NEON+FP8DOT, '26 | native `FDOT` | 16 | native `FDOT` | 16 |
|
|
457
|
+
| NEON+FP16FML, '20 | SHL → F16 + `FMLAL` | 16 | LUT → F16 + `FMLAL` | 16 |
|
|
458
|
+
| NEON, '18 | SHL + `FCVTL` + FMA | 8 | → F16 + `FCVTL` + FMA | 8 |
|
|
459
|
+
| __RISC-V__ | | | | |
|
|
460
|
+
| RVV+Zvfbfwma | rebias → BF16 + `VFWMACCBF16` | 4–32 | LUT → BF16 + `VFWMACCBF16` | 4–32 |
|
|
461
|
+
| RVV+Zvfh | SHL → F16 + `VFWMACC` | 4–32 | LUT → F16 + `VFWMACC` | 4–32 |
|
|
462
|
+
| RVV | rebias → F32 + `VFMACC` | 4–32 | LUT → F32 + `VFMACC` | 4–32 |
|
|
463
463
|
|
|
464
464
|
> E5M2 shares Float16's exponent bias (15), so E5M2 → Float16 conversion is a single left-shift by 8 bits (`SHL 8`).
|
|
465
465
|
> E4M3 on Ice Lake uses "octave decomposition": the 4-bit exponent splits into 2 octave + 2 remainder bits, yielding 7 integer accumulators post-scaled by powers of 2.
|
|
@@ -469,23 +469,23 @@ Their smaller range allows scaling to exact integers that fit in `i8`/`i16`, ena
|
|
|
469
469
|
Float16 can also serve as an accumulator, accurately representing ~50 products of E3M2FN pairs or ~20 products of E2M3FN pairs before overflow.
|
|
470
470
|
On Arm, the NEON FHM extension brings widening `FMLAL` multiply-accumulates for Float16 — both faster and more widely available than `BFDOT` for BFloat16.
|
|
471
471
|
|
|
472
|
-
| Platform
|
|
473
|
-
|
|
|
474
|
-
| __x86__
|
|
475
|
-
| Sierra Forest
|
|
476
|
-
| Alder Lake
|
|
477
|
-
| Ice Lake
|
|
478
|
-
| Skylake
|
|
479
|
-
| Haswell
|
|
480
|
-
| __Arm__
|
|
481
|
-
| NEON
|
|
482
|
-
| NEON
|
|
483
|
-
| NEON
|
|
484
|
-
| __RISC-V__
|
|
485
|
-
| RVV
|
|
472
|
+
| Platform | E3M2 Path | Step | E2M3 Path | Step |
|
|
473
|
+
| :----------------- | :------------------------ | ---: | :----------------------- | ---: |
|
|
474
|
+
| __x86__ | | | | |
|
|
475
|
+
| Sierra Forest, '24 | ↓ Haswell | 32 | `VPSHUFB` + `VPDPBSSD` | 32 |
|
|
476
|
+
| Alder Lake, '21 | ↓ Haswell | 32 | `VPSHUFB` + `VPDPBUSD` | 32 |
|
|
477
|
+
| Ice Lake, '19 | `VPERMW` + `VPMADDWD` | 32 | `VPERMB` + `VPDPBUSD` | 64 |
|
|
478
|
+
| Skylake, '15 | `VPSHUFB` + `VPMADDWD` | 64 | `VPSHUFB` + `VPMADDUBSW` | 64 |
|
|
479
|
+
| Haswell, '13 | `VPSHUFB` + `VPMADDWD` | 32 | `VPSHUFB` + `VPMADDUBSW` | 32 |
|
|
480
|
+
| __Arm__ | | | | |
|
|
481
|
+
| NEON+FP8DOT, '26 | → E5M2 + `FDOT` | 16 | → E4M3 + `FDOT` | 16 |
|
|
482
|
+
| NEON+DotProd, '19 | `VQTBL2` + `SMLAL` | 16 | `VQTBL2` + `SDOT` | 16 |
|
|
483
|
+
| NEON, '18 | → F16 + `FCVTL` + FMA | 16 | → F16 + `FCVTL` + FMA | 16 |
|
|
484
|
+
| __RISC-V__ | | | | |
|
|
485
|
+
| RVV | I16 gather LUT + `VWMACC` | 4–32 | U8 gather LUT + `VWMACC` | 4–32 |
|
|
486
486
|
|
|
487
487
|
> E3M2/E2M3 values map to exact integers via 32-entry LUTs (magnitudes up to 448 for E3M2, 120 for E2M3), enabling integer accumulation with no rounding error.
|
|
488
|
-
> On NEON
|
|
488
|
+
> On NEON+FP8DOT, E3M2 is first promoted to E5M2 and E2M3 to E4M3 before the hardware `FDOT` instruction.
|
|
489
489
|
> Sierra Forest and Alder Lake use native `VPDPBSSD` (signed×signed) and `VPDPBUSD` (unsigned×signed) respectively for E2M3.
|
|
490
490
|
|
|
491
491
|
E4M3 and E5M2 cannot use the integer path.
|
package/binding.gyp
CHANGED
|
package/include/numkong/each/haswell.h
CHANGED
|
@@ -196,7 +196,7 @@ NK_PUBLIC void nk_each_sum_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_
|
|
|
196
196
|
__m256 a_f32x8 = _mm256_cvtph_ps(a_f16x8);
|
|
197
197
|
__m256 b_f32x8 = _mm256_cvtph_ps(b_f16x8);
|
|
198
198
|
__m256 result_f32x8 = _mm256_add_ps(a_f32x8, b_f32x8);
|
|
199
|
-
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT
|
|
199
|
+
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT);
|
|
200
200
|
_mm_storeu_si128((__m128i *)(result + i), result_f16x8);
|
|
201
201
|
}
|
|
202
202
|
|
|
@@ -223,7 +223,7 @@ NK_PUBLIC void nk_each_scale_f16_haswell(nk_f16_t const *a, nk_size_t n, nk_f32_
|
|
|
223
223
|
__m128i a_f16x8 = _mm_loadu_si128((__m128i const *)(a + i));
|
|
224
224
|
__m256 a_f32x8 = _mm256_cvtph_ps(a_f16x8);
|
|
225
225
|
__m256 result_f32x8 = _mm256_fmadd_ps(a_f32x8, alpha_f32x8, beta_f32x8);
|
|
226
|
-
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT
|
|
226
|
+
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT);
|
|
227
227
|
_mm_storeu_si128((__m128i *)(result + i), result_f16x8);
|
|
228
228
|
}
|
|
229
229
|
|
|
@@ -271,7 +271,7 @@ NK_PUBLIC void nk_each_blend_f16_haswell( //
|
|
|
271
271
|
__m256 b_f32x8 = _mm256_cvtph_ps(b_f16x8);
|
|
272
272
|
__m256 a_scaled_f32x8 = _mm256_mul_ps(a_f32x8, alpha_f32x8);
|
|
273
273
|
__m256 result_f32x8 = _mm256_fmadd_ps(b_f32x8, beta_f32x8, a_scaled_f32x8);
|
|
274
|
-
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT
|
|
274
|
+
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT);
|
|
275
275
|
_mm_storeu_si128((__m128i *)(result + i), result_f16x8);
|
|
276
276
|
}
|
|
277
277
|
|
|
@@ -451,7 +451,7 @@ NK_PUBLIC void nk_each_fma_f16_haswell( //
|
|
|
451
451
|
__m256 ab_f32x8 = _mm256_mul_ps(a_f32x8, b_f32x8);
|
|
452
452
|
__m256 abc_f32x8 = _mm256_mul_ps(ab_f32x8, alpha_f32x8);
|
|
453
453
|
__m256 result_f32x8 = _mm256_fmadd_ps(c_f32x8, beta_f32x8, abc_f32x8);
|
|
454
|
-
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT
|
|
454
|
+
__m128i result_f16x8 = _mm256_cvtps_ph(result_f32x8, _MM_FROUND_TO_NEAREST_INT);
|
|
455
455
|
_mm_storeu_si128((__m128i *)(result + i), result_f16x8);
|
|
456
456
|
}
|
|
457
457
|
|
|
package/include/numkong/maxsim/sme.h
CHANGED
|
@@ -652,6 +652,46 @@ NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
|
|
|
652
652
|
return svaddv_f64(svptrue_b64(), accumulator_even_f64x) + svaddv_f64(svptrue_b64(), accumulator_odd_f64x);
|
|
653
653
|
}
|
|
654
654
|
|
|
655
|
+
/**
|

656
|
+
* Streaming-compatible (NK_STREAMING_) angular distance accumulation from pre-reduced dot products
|

657
|
+
* and contiguous f64 norm arrays, each holding `count` entries; processed svcntd() lanes per iteration under a svwhilelt tail predicate.
|

658
|
+
* Returns the sum over all pairs of max(0, 1 - dot / sqrt(||q||^2 * ||d||^2)) using a Newton-Raphson rsqrt (norm arrays presumably hold squared norms — confirm against callers); pairs whose norm product is <= 0 contribute exactly 1.0.
|

659
|
+
*/
|

660
|
+
NK_PUBLIC nk_f64_t nk_maxsim_angular_from_dots_ssve_( //
|

661
|
+
nk_f64_t const *dot_products, nk_size_t count, //
|

662
|
+
nk_f64_t const *query_norms_f64, nk_f64_t const *document_norms_f64) NK_STREAMING_ { //
|

663
|
+
|

664
|
+
nk_f64_t total_angular_distance_f64 = 0.0;
|

665
|
+
nk_size_t const vector_length = svcntd();
|

666
|
+
for (nk_size_t i = 0; i < count; i += vector_length) {
|

667
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(i, count);
|

668
|
+
svfloat64_t dot_products_f64x = svld1_f64(predicate_b64x, dot_products + i);
|

669
|
+
svfloat64_t query_norms_f64x = svld1_f64(predicate_b64x, query_norms_f64 + i);
|

670
|
+
svfloat64_t document_norms_f64x = svld1_f64(predicate_b64x, document_norms_f64 + i);
|

671
|
+
|

672
|
+
// norm_product = query_norm * document_norm
|

673
|
+
svfloat64_t norm_products_f64x = svmul_f64_x(predicate_b64x, query_norms_f64x, document_norms_f64x);
|

674
|
+
|

675
|
+
// Newton-Raphson rsqrt: estimate then two refinement steps. NOTE(review): svrsqrte is only a coarse estimate; confirm two svrsqrts steps reach the accuracy this f64 path needs (a third step or svsqrt + svdiv may be required).
|

676
|
+
svfloat64_t rsqrt_f64x = svrsqrte_f64(norm_products_f64x);
|

677
|
+
rsqrt_f64x = svmul_f64_x(predicate_b64x, rsqrt_f64x,
|

678
|
+
svrsqrts_f64(svmul_f64_x(predicate_b64x, norm_products_f64x, rsqrt_f64x), rsqrt_f64x));
|

679
|
+
rsqrt_f64x = svmul_f64_x(predicate_b64x, rsqrt_f64x,
|

680
|
+
svrsqrts_f64(svmul_f64_x(predicate_b64x, norm_products_f64x, rsqrt_f64x), rsqrt_f64x));
|

681
|
+
|

682
|
+
// cosine = dot_product * rsqrt(norm_product); lanes with norm_product <= 0 are zeroed by the _z multiply, so their distance becomes 1
|

683
|
+
svbool_t positive_b64x = svcmpgt_f64(predicate_b64x, norm_products_f64x, svdup_n_f64(0.0));
|

684
|
+
svfloat64_t cosine_f64x = svmul_f64_z(positive_b64x, dot_products_f64x, rsqrt_f64x);
|

685
|
+
|

686
|
+
// angular_distance = max(0, 1 - cosine); the clamp keeps rounding (cosine slightly above 1) from yielding a negative distance
|

687
|
+
svfloat64_t angular_distance_f64x = svsub_f64_x(predicate_b64x, svdup_f64(1.0), cosine_f64x);
|

688
|
+
angular_distance_f64x = svmax_f64_x(predicate_b64x, angular_distance_f64x, svdup_f64(0.0));
|

689
|
+
|

690
|
+
total_angular_distance_f64 += svaddv_f64(predicate_b64x, angular_distance_f64x); // predicated reduction: only in-range lanes are summed
|

691
|
+
}
|

692
|
+
return total_angular_distance_f64;
|

693
|
+
}
|
|
694
|
+
|
|
655
695
|
/**
|
|
656
696
|
* MaxSim f32 kernel: i8 SMOPA screening + f32/f64 refinement + angular distance.
|
|
657
697
|
*
|
|
@@ -895,36 +935,34 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
895
935
|
svcvtlt_f64_f32_x(predicate_odd_b64x, document_values_3_f32x));
|
|
896
936
|
}
|
|
897
937
|
|
|
898
|
-
// Reduce accumulators and compute angular
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
: 0.0;
|
|
909
|
-
nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
|
|
910
|
-
if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
|
|
911
|
-
total_angular_distance_f64 += angular_distance_f64;
|
|
938
|
+
// Reduce SVE accumulators to scalars and compute angular distances
|
|
939
|
+
nk_f64_t dot_products_f64[4];
|
|
940
|
+
dot_products_f64[0] = svaddv_f64(svptrue_b64(), accumulator_0_f64x);
|
|
941
|
+
dot_products_f64[1] = svaddv_f64(svptrue_b64(), accumulator_1_f64x);
|
|
942
|
+
dot_products_f64[2] = svaddv_f64(svptrue_b64(), accumulator_2_f64x);
|
|
943
|
+
dot_products_f64[3] = svaddv_f64(svptrue_b64(), accumulator_3_f64x);
|
|
944
|
+
nk_f64_t batch_query_norms_f64[4], batch_document_norms_f64[4];
|
|
945
|
+
for (nk_size_t i = 0; i < 4; i++) {
|
|
946
|
+
batch_query_norms_f64[i] = (nk_f64_t)query_norms[row_start + row_batch_start + i];
|
|
947
|
+
batch_document_norms_f64[i] = (nk_f64_t)document_norms[best_document_indices[row_batch_start + i]];
|
|
912
948
|
}
|
|
949
|
+
total_angular_distance_f64 += nk_maxsim_angular_from_dots_ssve_(dot_products_f64, 4, batch_query_norms_f64,
|
|
950
|
+
batch_document_norms_f64);
|
|
913
951
|
}
|
|
914
952
|
|
|
915
|
-
// Remainder:
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
nk_f64_t
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
953
|
+
// Remainder: compute dot products then batch the angular distance
|
|
954
|
+
nk_size_t remainder_count = rows_remaining - row_batch_start;
|
|
955
|
+
if (remainder_count > 0) {
|
|
956
|
+
nk_f64_t remainder_dot_products_f64[3];
|
|
957
|
+
nk_f64_t remainder_query_norms_f64[3], remainder_document_norms_f64[3];
|
|
958
|
+
for (nk_size_t i = 0; i < remainder_count; i++) {
|
|
959
|
+
remainder_dot_products_f64[i] = nk_maxsim_reduce_dot_f32_ssve_(
|
|
960
|
+
query_original_ptrs[row_batch_start + i], document_original_ptrs[row_batch_start + i], depth);
|
|
961
|
+
remainder_query_norms_f64[i] = (nk_f64_t)query_norms[row_start + row_batch_start + i];
|
|
962
|
+
remainder_document_norms_f64[i] = (nk_f64_t)document_norms[best_document_indices[row_batch_start + i]];
|
|
963
|
+
}
|
|
964
|
+
total_angular_distance_f64 += nk_maxsim_angular_from_dots_ssve_(
|
|
965
|
+
remainder_dot_products_f64, remainder_count, remainder_query_norms_f64, remainder_document_norms_f64);
|
|
928
966
|
}
|
|
929
967
|
}
|
|
930
968
|
|
|
package/include/numkong/mesh/README.md
CHANGED
|
@@ -1,37 +1,23 @@
|
|
|
1
1
|
# Point Cloud Alignment in NumKong
|
|
2
2
|
|
|
3
|
-
NumKong implements
|
|
4
|
-
RMSD measures alignment quality, Kabsch finds the optimal rotation minimizing RMSD, and Umeyama extends Kabsch with uniform scaling.
|
|
5
|
-
Used in structural biology (protein alignment), robotics (point cloud registration), and computer graphics (mesh registration).
|
|
3
|
+
NumKong implements three algorithms for 3D point cloud comparison and alignment, used in structural biology (protein alignment), robotics (point cloud registration), and computer graphics (mesh registration).
|
|
6
4
|
|
|
7
|
-
|
|
5
|
+
RMSD measures raw point-pair deviation without centering or alignment:
|
|
8
6
|
|
|
9
7
|
$$
|
|
10
|
-
\
|
|
8
|
+
\text{RMSD} = \sqrt{\frac{1}{n}\sum \|a_i - b_i\|^2}
|
|
11
9
|
$$
|
|
12
10
|
|
|
13
|
-
|
|
11
|
+
Kabsch finds the optimal rotation $R$ that minimizes RMSD after centering both clouds at their centroids $\bar{a}$, $\bar{b}$, recovering $R$ from the SVD of the cross-covariance matrix $H$:
|
|
14
12
|
|
|
15
13
|
$$
|
|
16
|
-
H = \sum (a_i - \bar{a})(b_i - \bar{b})^T
|
|
14
|
+
H = \sum (a_i - \bar{a})(b_i - \bar{b})^T = U \Sigma V^T, \quad R = V U^T
|
|
17
15
|
$$
|
|
18
16
|
|
|
19
|
-
|
|
17
|
+
Umeyama extends Kabsch with a uniform scale factor $s$ derived from the singular values and source variance $\sigma_a^2$:
|
|
20
18
|
|
|
21
19
|
$$
|
|
22
|
-
|
|
23
|
-
$$
|
|
24
|
-
|
|
25
|
-
Umeyama scale factor:
|
|
26
|
-
|
|
27
|
-
$$
|
|
28
|
-
s = \frac{\text{tr}(\Sigma)}{n \cdot \sigma_a^2}
|
|
29
|
-
$$
|
|
30
|
-
|
|
31
|
-
RMSD after alignment:
|
|
32
|
-
|
|
33
|
-
$$
|
|
34
|
-
\text{RMSD} = \sqrt{\frac{1}{n}\sum \|s \cdot R(a_i - \bar{a}) - (b_i - \bar{b})\|^2}
|
|
20
|
+
s = \frac{\text{tr}(\Sigma)}{n \cdot \sigma_a^2}, \quad \text{RMSD} = \sqrt{\frac{1}{n}\sum \|s \cdot R(a_i - \bar{a}) - (b_i - \bar{b})\|^2}
|
|
35
21
|
$$
|
|
36
22
|
|
|
37
23
|
Reformulating as Python pseudocode:
|
|
@@ -189,25 +175,25 @@ Measured with Wasmtime v42 (Cranelift backend).
|
|
|
189
175
|
| Kernel | 256 | 1024 | 4096 |
|
|
190
176
|
| :-------------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
191
177
|
| __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
192
|
-
| `nk_rmsd_f64_serial` |
|
|
178
|
+
| `nk_rmsd_f64_serial` | 279 mp/s, 0.5 ulp | 267 mp/s, 0.5 ulp | 279 mp/s, 0.5 ulp |
|
|
193
179
|
| `nk_kabsch_f64_serial` | 40.4 mp/s, 1.4 ulp | 47.3 mp/s, 2.6 ulp | 50.2 mp/s, 5.4 ulp |
|
|
194
180
|
| `nk_umeyama_f64_serial` | 34.5 mp/s, 1.0 ulp | 39.2 mp/s, 1.9 ulp | 41.6 mp/s, 3.7 ulp |
|
|
195
|
-
| `nk_rmsd_f64_neon` | 1,
|
|
181
|
+
| `nk_rmsd_f64_neon` | 1,776 mp/s, 0.4 ulp | 1,536 mp/s, 0.7 ulp | 2,037 mp/s, 1.3 ulp |
|
|
196
182
|
| `nk_kabsch_f64_neon` | 119 mp/s, 0.8 ulp | 222 mp/s, 1.3 ulp | 304 mp/s, 2.2 ulp |
|
|
197
183
|
| `nk_umeyama_f64_neon` | 115 mp/s, 0.4 ulp | 220 mp/s, 0.8 ulp | 296 mp/s, 1.6 ulp |
|
|
198
184
|
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
199
|
-
| `nk_rmsd_f32_serial` |
|
|
185
|
+
| `nk_rmsd_f32_serial` | 264 mp/s, 0.5 ulp | 264 mp/s, 0.5 ulp | 261 mp/s, 0.5 ulp |
|
|
200
186
|
| `nk_kabsch_f32_serial` | 39.4 mp/s, 1.4 ulp | 46.0 mp/s, 2.7 ulp | 49.9 mp/s, 5.0 ulp |
|
|
201
187
|
| `nk_umeyama_f32_serial` | 33.6 mp/s, 0.9 ulp | 38.8 mp/s, 1.8 ulp | 41.4 mp/s, 3.5 ulp |
|
|
202
|
-
| `nk_rmsd_f32_neon` | 1,
|
|
188
|
+
| `nk_rmsd_f32_neon` | 1,912 mp/s, 1.5 ulp | 2,239 mp/s, 1.3 ulp | 1,966 mp/s, 4.8 ulp |
|
|
203
189
|
| `nk_kabsch_f32_neon` | 135 mp/s, 0.7 ulp | 288 mp/s, 0.9 ulp | 385 mp/s, 1.4 ulp |
|
|
204
190
|
| `nk_umeyama_f32_neon` | 130 mp/s, 0.3 ulp | 272 mp/s, 0.4 ulp | 367 mp/s, 0.8 ulp |
|
|
205
191
|
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
206
|
-
| `nk_rmsd_bf16_neonbfdot` |
|
|
192
|
+
| `nk_rmsd_bf16_neonbfdot` | 3,728 mp/s, 0.4 ulp | 3,756 mp/s, 6.0 ulp | 3,769 mp/s, 10.0 ulp |
|
|
207
193
|
| `nk_kabsch_bf16_neonbfdot` | 180 mp/s, 0.7 ulp | 448 mp/s, 0.9 ulp | 726 mp/s, 1.3 ulp |
|
|
208
194
|
| `nk_umeyama_bf16_neonbfdot` | 176 mp/s, 0.2 ulp | 433 mp/s, 0.4 ulp | 705 mp/s, 0.8 ulp |
|
|
209
195
|
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
210
|
-
| `nk_rmsd_f16_neonhalf` | 2,
|
|
196
|
+
| `nk_rmsd_f16_neonhalf` | 2,998 mp/s, 0.4 ulp | 3,215 mp/s, 1.7 ulp | 3,216 mp/s, 4.6 ulp |
|
|
211
197
|
| `nk_kabsch_f16_neonhalf` | 178 mp/s, 0.9 ulp | 443 mp/s, 1.3 ulp | 711 mp/s, 2.4 ulp |
|
|
212
198
|
| `nk_umeyama_f16_neonhalf` | 175 mp/s, 0.4 ulp | 408 mp/s, 0.8 ulp | 620 mp/s, 1.5 ulp |
|
|
213
199
|
|
|
package/include/numkong/mesh/haswell.h
CHANGED
|
@@ -309,14 +309,13 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_haswell_(nk_f64_t const *a, nk_f64_t
|
|
|
309
309
|
|
|
310
310
|
NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
311
311
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
|
|
312
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
313
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
312
314
|
if (rotation)
|
|
313
315
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
314
316
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
315
317
|
if (scale) *scale = 1.0f;
|
|
316
318
|
|
|
317
|
-
__m256d sum_a_x_f64x4 = _mm256_setzero_pd(), sum_a_y_f64x4 = _mm256_setzero_pd();
|
|
318
|
-
__m256d sum_a_z_f64x4 = _mm256_setzero_pd(), sum_b_x_f64x4 = _mm256_setzero_pd();
|
|
319
|
-
__m256d sum_b_y_f64x4 = _mm256_setzero_pd(), sum_b_z_f64x4 = _mm256_setzero_pd();
|
|
320
319
|
__m256d sum_squared_f64x4 = _mm256_setzero_pd();
|
|
321
320
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
322
321
|
nk_size_t index = 0;
|
|
@@ -325,32 +324,19 @@ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
325
324
|
nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
|
|
326
325
|
nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
327
326
|
|
|
328
|
-
__m256d
|
|
329
|
-
|
|
330
|
-
__m256d
|
|
331
|
-
|
|
332
|
-
__m256d
|
|
333
|
-
|
|
334
|
-
__m256d
|
|
335
|
-
|
|
336
|
-
__m256d
|
|
337
|
-
|
|
338
|
-
__m256d
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, _mm256_add_pd(a_x_low_f64x4, a_x_high_f64x4));
|
|
342
|
-
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, _mm256_add_pd(a_y_low_f64x4, a_y_high_f64x4));
|
|
343
|
-
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, _mm256_add_pd(a_z_low_f64x4, a_z_high_f64x4));
|
|
344
|
-
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, _mm256_add_pd(b_x_low_f64x4, b_x_high_f64x4));
|
|
345
|
-
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, _mm256_add_pd(b_y_low_f64x4, b_y_high_f64x4));
|
|
346
|
-
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, _mm256_add_pd(b_z_low_f64x4, b_z_high_f64x4));
|
|
327
|
+
__m256d delta_x_low_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8)),
|
|
328
|
+
_mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8)));
|
|
329
|
+
__m256d delta_x_high_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1)),
|
|
330
|
+
_mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1)));
|
|
331
|
+
__m256d delta_y_low_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8)),
|
|
332
|
+
_mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8)));
|
|
333
|
+
__m256d delta_y_high_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1)),
|
|
334
|
+
_mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1)));
|
|
335
|
+
__m256d delta_z_low_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8)),
|
|
336
|
+
_mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8)));
|
|
337
|
+
__m256d delta_z_high_f64x4 = _mm256_sub_pd(_mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1)),
|
|
338
|
+
_mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1)));
|
|
347
339
|
|
|
348
|
-
__m256d delta_x_low_f64x4 = _mm256_sub_pd(a_x_low_f64x4, b_x_low_f64x4);
|
|
349
|
-
__m256d delta_x_high_f64x4 = _mm256_sub_pd(a_x_high_f64x4, b_x_high_f64x4);
|
|
350
|
-
__m256d delta_y_low_f64x4 = _mm256_sub_pd(a_y_low_f64x4, b_y_low_f64x4);
|
|
351
|
-
__m256d delta_y_high_f64x4 = _mm256_sub_pd(a_y_high_f64x4, b_y_high_f64x4);
|
|
352
|
-
__m256d delta_z_low_f64x4 = _mm256_sub_pd(a_z_low_f64x4, b_z_low_f64x4);
|
|
353
|
-
__m256d delta_z_high_f64x4 = _mm256_sub_pd(a_z_high_f64x4, b_z_high_f64x4);
|
|
354
340
|
__m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_low_f64x4, delta_x_low_f64x4),
|
|
355
341
|
_mm256_mul_pd(delta_x_high_f64x4, delta_x_high_f64x4));
|
|
356
342
|
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_low_f64x4, delta_y_low_f64x4, batch_sum_squared_f64x4);
|
|
@@ -360,70 +346,38 @@ NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
360
346
|
sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
|
|
361
347
|
}
|
|
362
348
|
|
|
363
|
-
nk_f64_t total_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
|
|
364
|
-
nk_f64_t total_a_y = nk_reduce_add_f64x4_haswell_(sum_a_y_f64x4);
|
|
365
|
-
nk_f64_t total_a_z = nk_reduce_add_f64x4_haswell_(sum_a_z_f64x4);
|
|
366
|
-
nk_f64_t total_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
|
|
367
|
-
nk_f64_t total_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
|
|
368
|
-
nk_f64_t total_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
|
|
369
349
|
nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
|
|
370
350
|
|
|
371
351
|
for (; index < n; ++index) {
|
|
372
|
-
nk_f64_t
|
|
373
|
-
nk_f64_t
|
|
374
|
-
|
|
375
|
-
total_b_x += b_x, total_b_y += b_y, total_b_z += b_z;
|
|
376
|
-
nk_f64_t delta_x = a_x - b_x, delta_y = a_y - b_y, delta_z = a_z - b_z;
|
|
352
|
+
nk_f64_t delta_x = (nk_f64_t)a[index * 3 + 0] - (nk_f64_t)b[index * 3 + 0];
|
|
353
|
+
nk_f64_t delta_y = (nk_f64_t)a[index * 3 + 1] - (nk_f64_t)b[index * 3 + 1];
|
|
354
|
+
nk_f64_t delta_z = (nk_f64_t)a[index * 3 + 2] - (nk_f64_t)b[index * 3 + 2];
|
|
377
355
|
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
378
356
|
}
|
|
379
357
|
|
|
380
|
-
|
|
381
|
-
nk_f64_t centroid_a_x = total_a_x * inv_n, centroid_a_y = total_a_y * inv_n, centroid_a_z = total_a_z * inv_n;
|
|
382
|
-
nk_f64_t centroid_b_x = total_b_x * inv_n, centroid_b_y = total_b_y * inv_n, centroid_b_z = total_b_z * inv_n;
|
|
383
|
-
if (a_centroid)
|
|
384
|
-
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
385
|
-
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
386
|
-
if (b_centroid)
|
|
387
|
-
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
388
|
-
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
389
|
-
|
|
390
|
-
nk_f64_t mean_delta_x = centroid_a_x - centroid_b_x, mean_delta_y = centroid_a_y - centroid_b_y,
|
|
391
|
-
mean_delta_z = centroid_a_z - centroid_b_z;
|
|
392
|
-
nk_f64_t mean_delta_squared = mean_delta_x * mean_delta_x + mean_delta_y * mean_delta_y +
|
|
393
|
-
mean_delta_z * mean_delta_z;
|
|
394
|
-
*result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_delta_squared);
|
|
358
|
+
*result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
|
|
395
359
|
}
|
|
396
360
|
|
|
397
361
|
NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
398
362
|
nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
|
|
399
|
-
|
|
363
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
364
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
400
365
|
if (rotation)
|
|
401
366
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
402
367
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
403
368
|
if (scale) *scale = 1.0;
|
|
404
369
|
__m256d const zeros_f64x4 = _mm256_setzero_pd();
|
|
405
370
|
|
|
406
|
-
// Accumulators for centroids and squared differences
|
|
407
|
-
__m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
|
|
408
|
-
__m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
|
|
409
371
|
__m256d sum_squared_x_f64x4 = zeros_f64x4, sum_squared_y_f64x4 = zeros_f64x4, sum_squared_z_f64x4 = zeros_f64x4;
|
|
410
372
|
|
|
411
373
|
__m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
|
|
412
374
|
nk_size_t i = 0;
|
|
413
375
|
|
|
414
|
-
// Main loop with
|
|
376
|
+
// Main loop with 2× unrolling
|
|
415
377
|
for (; i + 8 <= n; i += 8) {
|
|
416
|
-
// Iteration 0
|
|
417
378
|
nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
|
|
418
379
|
nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
|
|
419
380
|
|
|
420
|
-
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
|
|
421
|
-
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
|
|
422
|
-
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
|
|
423
|
-
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
|
|
424
|
-
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
|
|
425
|
-
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
|
|
426
|
-
|
|
427
381
|
__m256d delta_x_f64x4 = _mm256_sub_pd(a_x_f64x4, b_x_f64x4);
|
|
428
382
|
__m256d delta_y_f64x4 = _mm256_sub_pd(a_y_f64x4, b_y_f64x4);
|
|
429
383
|
__m256d delta_z_f64x4 = _mm256_sub_pd(a_z_f64x4, b_z_f64x4);
|
|
@@ -432,18 +386,10 @@ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
432
386
|
sum_squared_y_f64x4 = _mm256_fmadd_pd(delta_y_f64x4, delta_y_f64x4, sum_squared_y_f64x4);
|
|
433
387
|
sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z_f64x4, delta_z_f64x4, sum_squared_z_f64x4);
|
|
434
388
|
|
|
435
|
-
// Iteration 1
|
|
436
389
|
__m256d a_x1_f64x4, a_y1_f64x4, a_z1_f64x4, b_x1_f64x4, b_y1_f64x4, b_z1_f64x4;
|
|
437
390
|
nk_deinterleave_f64x4_haswell_(a + (i + 4) * 3, &a_x1_f64x4, &a_y1_f64x4, &a_z1_f64x4);
|
|
438
391
|
nk_deinterleave_f64x4_haswell_(b + (i + 4) * 3, &b_x1_f64x4, &b_y1_f64x4, &b_z1_f64x4);
|
|
439
392
|
|
|
440
|
-
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x1_f64x4);
|
|
441
|
-
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y1_f64x4);
|
|
442
|
-
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z1_f64x4);
|
|
443
|
-
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x1_f64x4);
|
|
444
|
-
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y1_f64x4);
|
|
445
|
-
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z1_f64x4);
|
|
446
|
-
|
|
447
393
|
__m256d delta_x1_f64x4 = _mm256_sub_pd(a_x1_f64x4, b_x1_f64x4);
|
|
448
394
|
__m256d delta_y1_f64x4 = _mm256_sub_pd(a_y1_f64x4, b_y1_f64x4);
|
|
449
395
|
__m256d delta_z1_f64x4 = _mm256_sub_pd(a_z1_f64x4, b_z1_f64x4);
|
|
@@ -453,18 +399,10 @@ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
453
399
|
sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z1_f64x4, delta_z1_f64x4, sum_squared_z_f64x4);
|
|
454
400
|
}
|
|
455
401
|
|
|
456
|
-
// Handle 4-point remainder
|
|
457
402
|
for (; i + 4 <= n; i += 4) {
|
|
458
403
|
nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
|
|
459
404
|
nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
|
|
460
405
|
|
|
461
|
-
sum_a_x_f64x4 = _mm256_add_pd(sum_a_x_f64x4, a_x_f64x4);
|
|
462
|
-
sum_a_y_f64x4 = _mm256_add_pd(sum_a_y_f64x4, a_y_f64x4);
|
|
463
|
-
sum_a_z_f64x4 = _mm256_add_pd(sum_a_z_f64x4, a_z_f64x4);
|
|
464
|
-
sum_b_x_f64x4 = _mm256_add_pd(sum_b_x_f64x4, b_x_f64x4);
|
|
465
|
-
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
|
|
466
|
-
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
|
|
467
|
-
|
|
468
406
|
__m256d delta_x_f64x4 = _mm256_sub_pd(a_x_f64x4, b_x_f64x4);
|
|
469
407
|
__m256d delta_y_f64x4 = _mm256_sub_pd(a_y_f64x4, b_y_f64x4);
|
|
470
408
|
__m256d delta_z_f64x4 = _mm256_sub_pd(a_z_f64x4, b_z_f64x4);
|
|
@@ -474,57 +412,22 @@ NK_PUBLIC void nk_rmsd_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size
|
|
|
474
412
|
sum_squared_z_f64x4 = _mm256_fmadd_pd(delta_z_f64x4, delta_z_f64x4, sum_squared_z_f64x4);
|
|
475
413
|
}
|
|
476
414
|
|
|
477
|
-
// Reduce vectors to scalars
|
|
478
|
-
nk_f64_t total_ax = nk_reduce_stable_f64x4_haswell_(sum_a_x_f64x4), total_ax_compensation = 0.0;
|
|
479
|
-
nk_f64_t total_ay = nk_reduce_stable_f64x4_haswell_(sum_a_y_f64x4), total_ay_compensation = 0.0;
|
|
480
|
-
nk_f64_t total_az = nk_reduce_stable_f64x4_haswell_(sum_a_z_f64x4), total_az_compensation = 0.0;
|
|
481
|
-
nk_f64_t total_bx = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), total_bx_compensation = 0.0;
|
|
482
|
-
nk_f64_t total_by = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), total_by_compensation = 0.0;
|
|
483
|
-
nk_f64_t total_bz = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), total_bz_compensation = 0.0;
|
|
484
415
|
nk_f64_t total_sq_x = nk_reduce_stable_f64x4_haswell_(sum_squared_x_f64x4), total_sq_x_compensation = 0.0;
|
|
485
416
|
nk_f64_t total_sq_y = nk_reduce_stable_f64x4_haswell_(sum_squared_y_f64x4), total_sq_y_compensation = 0.0;
|
|
486
417
|
nk_f64_t total_sq_z = nk_reduce_stable_f64x4_haswell_(sum_squared_z_f64x4), total_sq_z_compensation = 0.0;
|
|
487
418
|
|
|
488
|
-
// Scalar tail
|
|
489
419
|
for (; i < n; ++i) {
|
|
490
|
-
nk_f64_t
|
|
491
|
-
nk_f64_t
|
|
492
|
-
|
|
493
|
-
nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
|
|
494
|
-
nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
|
|
495
|
-
nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
|
|
496
|
-
nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
|
|
497
|
-
nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
|
|
498
|
-
nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
|
|
420
|
+
nk_f64_t delta_x = a[i * 3 + 0] - b[i * 3 + 0];
|
|
421
|
+
nk_f64_t delta_y = a[i * 3 + 1] - b[i * 3 + 1];
|
|
422
|
+
nk_f64_t delta_z = a[i * 3 + 2] - b[i * 3 + 2];
|
|
499
423
|
nk_accumulate_square_f64_(&total_sq_x, &total_sq_x_compensation, delta_x);
|
|
500
424
|
nk_accumulate_square_f64_(&total_sq_y, &total_sq_y_compensation, delta_y);
|
|
501
425
|
nk_accumulate_square_f64_(&total_sq_z, &total_sq_z_compensation, delta_z);
|
|
502
426
|
}
|
|
503
427
|
|
|
504
|
-
total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
|
|
505
|
-
total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
|
|
506
428
|
total_sq_x += total_sq_x_compensation, total_sq_y += total_sq_y_compensation, total_sq_z += total_sq_z_compensation;
|
|
507
429
|
|
|
508
|
-
|
|
509
|
-
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
510
|
-
nk_f64_t centroid_a_x = total_ax * inv_n;
|
|
511
|
-
nk_f64_t centroid_a_y = total_ay * inv_n;
|
|
512
|
-
nk_f64_t centroid_a_z = total_az * inv_n;
|
|
513
|
-
nk_f64_t centroid_b_x = total_bx * inv_n;
|
|
514
|
-
nk_f64_t centroid_b_y = total_by * inv_n;
|
|
515
|
-
nk_f64_t centroid_b_z = total_bz * inv_n;
|
|
516
|
-
|
|
517
|
-
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
518
|
-
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
519
|
-
|
|
520
|
-
// Compute RMSD
|
|
521
|
-
nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
522
|
-
nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
|
|
523
|
-
nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
|
|
524
|
-
nk_f64_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
|
|
525
|
-
nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
526
|
-
|
|
527
|
-
*result = nk_f64_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
430
|
+
*result = nk_f64_sqrt_haswell((total_sq_x + total_sq_y + total_sq_z) / (nk_f64_t)n);
|
|
528
431
|
}
|
|
529
432
|
|
|
530
433
|
NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|