numkong 7.4.5 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/README.md +1 -0
  2. package/binding.gyp +99 -5
  3. package/c/dispatch_e5m2.c +23 -3
  4. package/c/dispatch_f16.c +23 -0
  5. package/c/numkong.c +0 -13
  6. package/include/numkong/attention/sme.h +34 -31
  7. package/include/numkong/capabilities.h +2 -15
  8. package/include/numkong/cast/README.md +3 -0
  9. package/include/numkong/cast/haswell.h +28 -64
  10. package/include/numkong/cast/neon.h +15 -0
  11. package/include/numkong/cast/serial.h +17 -0
  12. package/include/numkong/cast/skylake.h +67 -52
  13. package/include/numkong/cast.h +1 -0
  14. package/include/numkong/curved/smef64.h +82 -62
  15. package/include/numkong/dot/README.md +1 -0
  16. package/include/numkong/dot/haswell.h +92 -13
  17. package/include/numkong/dot/rvvbf16.h +1 -1
  18. package/include/numkong/dot/rvvhalf.h +1 -1
  19. package/include/numkong/dot/serial.h +15 -0
  20. package/include/numkong/dot/skylake.h +61 -14
  21. package/include/numkong/dot/sve.h +6 -5
  22. package/include/numkong/dot/svebfdot.h +2 -1
  23. package/include/numkong/dot/svehalf.h +6 -5
  24. package/include/numkong/dot/svesdot.h +3 -2
  25. package/include/numkong/dots/README.md +2 -0
  26. package/include/numkong/dots/graniteamx.h +1167 -0
  27. package/include/numkong/dots/haswell.h +28 -28
  28. package/include/numkong/dots/sapphireamx.h +1 -1
  29. package/include/numkong/dots/serial.h +33 -11
  30. package/include/numkong/dots/skylake.h +28 -23
  31. package/include/numkong/dots/sme.h +172 -140
  32. package/include/numkong/dots/smebi32.h +14 -11
  33. package/include/numkong/dots/smef64.h +31 -26
  34. package/include/numkong/dots.h +41 -3
  35. package/include/numkong/each/serial.h +39 -0
  36. package/include/numkong/geospatial/haswell.h +1 -1
  37. package/include/numkong/geospatial/neon.h +1 -1
  38. package/include/numkong/geospatial/serial.h +15 -4
  39. package/include/numkong/geospatial/skylake.h +1 -1
  40. package/include/numkong/maxsim/serial.h +15 -0
  41. package/include/numkong/maxsim/sme.h +34 -33
  42. package/include/numkong/mesh/README.md +50 -44
  43. package/include/numkong/mesh/genoa.h +462 -0
  44. package/include/numkong/mesh/haswell.h +806 -933
  45. package/include/numkong/mesh/neon.h +871 -943
  46. package/include/numkong/mesh/neonbfdot.h +382 -522
  47. package/include/numkong/mesh/neonfhm.h +676 -0
  48. package/include/numkong/mesh/rvv.h +404 -319
  49. package/include/numkong/mesh/serial.h +225 -161
  50. package/include/numkong/mesh/skylake.h +1029 -1585
  51. package/include/numkong/mesh/v128relaxed.h +403 -377
  52. package/include/numkong/mesh.h +38 -0
  53. package/include/numkong/reduce/neon.h +29 -0
  54. package/include/numkong/reduce/neonbfdot.h +2 -2
  55. package/include/numkong/reduce/neonfhm.h +4 -4
  56. package/include/numkong/reduce/serial.h +15 -1
  57. package/include/numkong/reduce/sve.h +52 -0
  58. package/include/numkong/reduce.h +4 -0
  59. package/include/numkong/set/sve.h +6 -5
  60. package/include/numkong/sets/smebi32.h +35 -30
  61. package/include/numkong/sparse/serial.h +17 -2
  62. package/include/numkong/sparse/sve2.h +3 -2
  63. package/include/numkong/spatial/genoa.h +0 -68
  64. package/include/numkong/spatial/haswell.h +98 -56
  65. package/include/numkong/spatial/serial.h +15 -0
  66. package/include/numkong/spatial/skylake.h +114 -54
  67. package/include/numkong/spatial/sve.h +7 -6
  68. package/include/numkong/spatial/svebfdot.h +7 -4
  69. package/include/numkong/spatial/svehalf.h +5 -4
  70. package/include/numkong/spatial/svesdot.h +9 -8
  71. package/include/numkong/spatial.h +0 -12
  72. package/include/numkong/spatials/graniteamx.h +301 -0
  73. package/include/numkong/spatials/serial.h +39 -0
  74. package/include/numkong/spatials/skylake.h +2 -2
  75. package/include/numkong/spatials/sme.h +391 -350
  76. package/include/numkong/spatials/smef64.h +79 -70
  77. package/include/numkong/spatials.h +54 -4
  78. package/include/numkong/tensor.hpp +107 -23
  79. package/include/numkong/types.h +59 -0
  80. package/javascript/dist/cjs/numkong.js +13 -0
  81. package/javascript/dist/esm/numkong.js +13 -0
  82. package/javascript/numkong.c +59 -14
  83. package/javascript/numkong.ts +13 -0
  84. package/package.json +7 -7
  85. package/probes/probe.js +2 -2
  86. package/wasm/numkong.wasm +0 -0
@@ -266,6 +266,20 @@ NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, n
266
266
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
267
267
  #endif // NK_TARGET_SKYLAKE
268
268
 
269
+ /* SIMD-powered backends for AVX512-BF16 CPUs of AMD Genoa / Intel Sapphire Rapids generation and newer.
270
+ */
271
+ #if NK_TARGET_GENOA
272
+ /** @copydoc nk_rmsd_bf16 */
273
+ NK_PUBLIC void nk_rmsd_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
274
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
275
+ /** @copydoc nk_kabsch_bf16 */
276
+ NK_PUBLIC void nk_kabsch_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
277
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
278
+ /** @copydoc nk_umeyama_bf16 */
279
+ NK_PUBLIC void nk_umeyama_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
280
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
281
+ #endif // NK_TARGET_GENOA
282
+
269
283
  /* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer.
270
284
  */
271
285
  #if NK_TARGET_HASWELL
@@ -357,6 +371,20 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
357
371
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
358
372
  #endif // NK_TARGET_NEONBFDOT
359
373
 
374
+ /* SIMD-powered backends for Arm NEON FHM (FP16 widening FMA) CPUs.
375
+ */
376
+ #if NK_TARGET_NEONFHM
377
+ /** @copydoc nk_rmsd_f16 */
378
+ NK_PUBLIC void nk_rmsd_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
379
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
380
+ /** @copydoc nk_kabsch_f16 */
381
+ NK_PUBLIC void nk_kabsch_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
382
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
383
+ /** @copydoc nk_umeyama_f16 */
384
+ NK_PUBLIC void nk_umeyama_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
385
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
386
+ #endif // NK_TARGET_NEONFHM
387
+
360
388
  #if NK_TARGET_RVV
361
389
  /** @copydoc nk_rmsd_f32 */
362
390
  NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
@@ -454,8 +482,10 @@ NK_INTERNAL nk_dtype_t nk_mesh_transform_dtype(nk_dtype_t dtype) {
454
482
  #include "numkong/mesh/serial.h"
455
483
  #include "numkong/mesh/neon.h"
456
484
  #include "numkong/mesh/neonbfdot.h"
485
+ #include "numkong/mesh/neonfhm.h"
457
486
  #include "numkong/mesh/haswell.h"
458
487
  #include "numkong/mesh/skylake.h"
488
+ #include "numkong/mesh/genoa.h"
459
489
  #include "numkong/mesh/rvv.h"
460
490
  #include "numkong/mesh/v128relaxed.h"
461
491
 
@@ -505,6 +535,8 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
505
535
  nk_rmsd_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
506
536
  #elif NK_TARGET_HASWELL
507
537
  nk_rmsd_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
538
+ #elif NK_TARGET_NEONFHM
539
+ nk_rmsd_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
508
540
  #elif NK_TARGET_NEON
509
541
  nk_rmsd_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
510
542
  #elif NK_TARGET_RVV
@@ -517,6 +549,8 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
517
549
  NK_PUBLIC void nk_rmsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
518
550
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
519
551
  #if NK_TARGET_SKYLAKE
552
+ // Skylake f32-widen path wins on Intel where VDPBF16PS throughput matches FMA; on AMD Zen4+
553
+ // where VDPBF16PS is faster than FMA, users can call `nk_rmsd_bf16_genoa` directly.
520
554
  nk_rmsd_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
521
555
  #elif NK_TARGET_HASWELL
522
556
  nk_rmsd_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
@@ -569,6 +603,8 @@ NK_PUBLIC void nk_kabsch_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
569
603
  nk_kabsch_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
570
604
  #elif NK_TARGET_HASWELL
571
605
  nk_kabsch_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
606
+ #elif NK_TARGET_NEONFHM
607
+ nk_kabsch_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
572
608
  #elif NK_TARGET_NEON
573
609
  nk_kabsch_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
574
610
  #elif NK_TARGET_RVV
@@ -633,6 +669,8 @@ NK_PUBLIC void nk_umeyama_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
633
669
  nk_umeyama_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
634
670
  #elif NK_TARGET_HASWELL
635
671
  nk_umeyama_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
672
+ #elif NK_TARGET_NEONFHM
673
+ nk_umeyama_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
636
674
  #elif NK_TARGET_NEON
637
675
  nk_umeyama_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
638
676
  #elif NK_TARGET_RVV
@@ -3936,6 +3936,35 @@ NK_PUBLIC void nk_reduce_moments_f16_neon( //
3936
3936
  else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
3937
3937
  }
3938
3938
 
3939
+ NK_INTERNAL void nk_reduce_moments_u1_neon_contiguous_( //
3940
+ nk_u1x8_t const *data_ptr, nk_size_t count, //
3941
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
3942
+ nk_size_t byte_count = nk_size_divide_round_up_(count, NK_BITS_PER_BYTE);
3943
+ nk_u64_t sum = 0;
3944
+ nk_size_t idx = 0;
3945
+ // Each vcntq_u8 produces values 0-8 per lane; accumulate at u8 level
3946
+ // for up to 31 iterations (31 × 8 = 248, fits in u8) before widening.
3947
+ while (idx + 16 <= byte_count) {
3948
+ uint8x16_t popcount_u8x16 = vdupq_n_u8(0);
3949
+ for (nk_size_t cycle = 0; cycle < 31 && idx + 16 <= byte_count; ++cycle, idx += 16) {
3950
+ uint8x16_t data_u8x16 = vld1q_u8((nk_u8_t const *)data_ptr + idx);
3951
+ popcount_u8x16 = vaddq_u8(popcount_u8x16, vcntq_u8(data_u8x16));
3952
+ }
3953
+ sum += (nk_u64_t)vaddlvq_u8(popcount_u8x16);
3954
+ }
3955
+ for (; idx < byte_count; ++idx) sum += nk_u1x8_popcount_(((nk_u8_t const *)data_ptr)[idx]);
3956
+ *sum_ptr = sum, *sumsq_ptr = sum;
3957
+ }
3958
+
3959
+ NK_PUBLIC void nk_reduce_moments_u1_neon( //
3960
+ nk_u1x8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3961
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
3962
+ count = nk_size_round_up_to_multiple_(count, 8);
3963
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
3964
+ else if (stride_bytes == 1) nk_reduce_moments_u1_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
3965
+ else nk_reduce_moments_u1_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
3966
+ }
3967
+
3939
3968
  #if defined(__clang__)
3940
3969
  #pragma clang attribute pop
3941
3970
  #elif defined(__GNUC__)
@@ -33,7 +33,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_contiguous_( //
33
33
  nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
34
34
 
35
35
  // bf16 representation of 1.0 is 0x3F80 (same as upper 16 bits of f32 1.0)
36
- bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(vdupq_n_u16(0x3F80));
36
+ bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
37
37
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
38
38
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
39
39
  nk_size_t idx = 0;
@@ -61,7 +61,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_strided_( //
61
61
  nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
62
62
  nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
63
63
 
64
- bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(vdupq_n_u16(0x3F80));
64
+ bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
65
65
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
66
66
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
67
67
  nk_size_t idx = 0;
@@ -34,7 +34,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_contiguous_( //
34
34
 
35
35
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
36
36
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
37
- float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
37
+ float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
38
38
  nk_size_t idx = 0;
39
39
 
40
40
  for (; idx + 8 <= count; idx += 8) {
@@ -67,7 +67,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
67
67
 
68
68
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
69
69
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
70
- float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
70
+ float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
71
71
  nk_size_t idx = 0;
72
72
 
73
73
  if (stride_elements == 2) {
@@ -159,7 +159,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_contiguous_( //
159
159
 
160
160
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
161
161
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
162
- float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
162
+ float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
163
163
  nk_size_t idx = 0;
164
164
 
165
165
  for (; idx + 8 <= count; idx += 8) {
@@ -192,7 +192,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
192
192
 
193
193
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
194
194
  float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
195
- float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
195
+ float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
196
196
  nk_size_t idx = 0;
197
197
 
198
198
  if (stride_elements == 2) {
@@ -14,7 +14,6 @@
14
14
  #define NK_REDUCE_SERIAL_H
15
15
 
16
16
  #include "numkong/types.h"
17
- #include "numkong/scalar/serial.h"
18
17
  #include "numkong/cast/serial.h"
19
18
  #include "numkong/scalar/serial.h"
20
19
 
@@ -22,6 +21,15 @@
22
21
  extern "C" {
23
22
  #endif
24
23
 
24
+ /* Keep the serial instantiations below actually scalar, regardless of build type.
25
+ * See dots/serial.h for rationale. */
26
+ #if defined(__clang__)
27
+ #pragma clang attribute push(__attribute__((noinline)), apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
31
+ #endif
32
+
25
33
  NK_INTERNAL nk_f64_t nk_reduce_sum_f64_serial_(nk_f64_t const *values, nk_f64_t const *compensations, int count) {
26
34
  nk_f64_t running_sum = 0, accumulated_error = 0;
27
35
  for (int i = 0; i < count; i++) {
@@ -746,6 +754,12 @@ NK_PUBLIC void nk_reduce_minmax_u1_serial( //
746
754
  *max_value_ptr = max_value, *max_index_ptr = max_idx;
747
755
  }
748
756
 
757
+ #if defined(__clang__)
758
+ #pragma clang attribute pop
759
+ #elif defined(__GNUC__)
760
+ #pragma GCC pop_options
761
+ #endif
762
+
749
763
  #if defined(__cplusplus)
750
764
  } // extern "C"
751
765
  #endif
@@ -0,0 +1,52 @@
1
+ /**
2
+ * @brief SVE horizontal reduction helpers with MSan unpoisoning.
3
+ * @file include/numkong/reduce/sve.h
4
+ * @author Ash Vardanian
5
+ * @date April 12, 2026
6
+ *
7
+ * LLVM's MSan does not instrument ARM SVE intrinsics — `svaddv` moves data
8
+ * from vector to scalar registers via architecture-specific paths invisible
9
+ * to the compiler, causing false-positive uninitialized-value reports.
10
+ * These macros wrap the reduction and unpoison the scalar result.
11
+ *
12
+ * The `svaddv` intrinsic stays inside a macro so it expands in the caller's
13
+ * target context — SVE and SME streaming translation units carry incompatible
14
+ * target attributes. The unpoisoning runs on the already-reduced scalar, so it
15
+ * lives in a target-agnostic `NK_INTERNAL` helper called from the macro.
16
+ *
17
+ * @sa include/numkong/reduce.h
18
+ */
19
+ #ifndef NK_REDUCE_SVE_H
20
+ #define NK_REDUCE_SVE_H
21
+
22
+ #if NK_TARGET_ARM64_
23
+ #if NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
24
+
25
+ #include "numkong/types.h"
26
+
27
+ NK_INTERNAL nk_f64_t nk_unpoison_f64_(nk_f64_t v) NK_STREAMING_COMPATIBLE_ {
28
+ nk_unpoison_(&v, sizeof(v));
29
+ return v;
30
+ }
31
+ NK_INTERNAL nk_f32_t nk_unpoison_f32_(nk_f32_t v) NK_STREAMING_COMPATIBLE_ {
32
+ nk_unpoison_(&v, sizeof(v));
33
+ return v;
34
+ }
35
+ NK_INTERNAL nk_u64_t nk_unpoison_u64_(nk_u64_t v) NK_STREAMING_COMPATIBLE_ {
36
+ nk_unpoison_(&v, sizeof(v));
37
+ return v;
38
+ }
39
+ NK_INTERNAL nk_i64_t nk_unpoison_i64_(nk_i64_t v) NK_STREAMING_COMPATIBLE_ {
40
+ nk_unpoison_(&v, sizeof(v));
41
+ return v;
42
+ }
43
+
44
+ #define nk_svaddv_f64_(predicate, vector) nk_unpoison_f64_(svaddv_f64((predicate), (vector)))
45
+ #define nk_svaddv_f32_(predicate, vector) nk_unpoison_f32_(svaddv_f32((predicate), (vector)))
46
+ #define nk_svaddv_u32_(predicate, vector) nk_unpoison_u64_(svaddv_u32((predicate), (vector)))
47
+ #define nk_svaddv_s32_(predicate, vector) nk_unpoison_i64_(svaddv_s32((predicate), (vector)))
48
+ #define nk_svaddv_u8_(predicate, vector) nk_unpoison_u64_(svaddv_u8((predicate), (vector)))
49
+
50
+ #endif // NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
51
+ #endif // NK_TARGET_ARM64_
52
+ #endif // NK_REDUCE_SVE_H
@@ -389,6 +389,8 @@ NK_PUBLIC void nk_reduce_moments_i16_neon(nk_i16_t const *, nk_size_t, nk_size_t
389
389
  /** @copydoc nk_reduce_moments_f64 */
390
390
  NK_PUBLIC void nk_reduce_moments_u16_neon(nk_u16_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
391
391
  /** @copydoc nk_reduce_moments_f64 */
392
+ NK_PUBLIC void nk_reduce_moments_u1_neon(nk_u1x8_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
393
+ /** @copydoc nk_reduce_moments_f64 */
392
394
  NK_PUBLIC void nk_reduce_moments_i32_neon(nk_i32_t const *, nk_size_t, nk_size_t, nk_i64_t *, nk_u64_t *);
393
395
  /** @copydoc nk_reduce_moments_f64 */
394
396
  NK_PUBLIC void nk_reduce_moments_u32_neon(nk_u32_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
@@ -1559,6 +1561,8 @@ NK_PUBLIC void nk_reduce_moments_u1(nk_u1x8_t const *d, nk_size_t n, nk_size_t s
1559
1561
  nk_reduce_moments_u1_skylake(d, n, s, sum, sumsq);
1560
1562
  #elif NK_TARGET_HASWELL
1561
1563
  nk_reduce_moments_u1_haswell(d, n, s, sum, sumsq);
1564
+ #elif NK_TARGET_NEON
1565
+ nk_reduce_moments_u1_neon(d, n, s, sum, sumsq);
1562
1566
  #else
1563
1567
  nk_reduce_moments_u1_serial(d, n, s, sum, sumsq);
1564
1568
  #endif
@@ -32,8 +32,9 @@
32
32
  #if NK_TARGET_ARM64_
33
33
  #if NK_TARGET_SVE
34
34
 
35
- #include "numkong/types.h" // `nk_u1x8_t`
36
- #include "numkong/set/neon.h" // `nk_hamming_u1_neon`
35
+ #include "numkong/types.h" // `nk_u1x8_t`
36
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
37
+ #include "numkong/set/neon.h" // `nk_hamming_u1_neon`
37
38
 
38
39
  #if defined(__cplusplus)
39
40
  extern "C" {
@@ -73,7 +74,7 @@ NK_PUBLIC void nk_hamming_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
73
74
  i += words_per_register;
74
75
  ++cycle;
75
76
  } while (i < n_bytes && cycle < 31);
76
- differences += svaddv_u8(all_predicate_b8x, popcount_u8x);
77
+ differences += nk_svaddv_u8_(all_predicate_b8x, popcount_u8x);
77
78
  popcount_u8x = svdup_n_u8(0);
78
79
  cycle = 0; // Reset the cycle counter.
79
80
  }
@@ -110,9 +111,9 @@ NK_PUBLIC void nk_jaccard_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
110
111
  i += words_per_register;
111
112
  ++cycle;
112
113
  } while (i < n_bytes && cycle < 31);
113
- intersection_count += svaddv_u8(all_predicate_b8x, intersection_popcount_u8x);
114
+ intersection_count += nk_svaddv_u8_(all_predicate_b8x, intersection_popcount_u8x);
114
115
  intersection_popcount_u8x = svdup_n_u8(0);
115
- union_count += svaddv_u8(all_predicate_b8x, union_popcount_u8x);
116
+ union_count += nk_svaddv_u8_(all_predicate_b8x, union_popcount_u8x);
116
117
  union_popcount_u8x = svdup_n_u8(0);
117
118
  cycle = 0; // Reset the cycle counter.
118
119
  }
@@ -41,8 +41,9 @@
41
41
  #include "numkong/types.h"
42
42
  #include "numkong/set/serial.h"
43
43
  #include "numkong/sets/serial.h"
44
- #include "numkong/dots/sme.h" // `nk_sme_zero_za32_*` constants
45
- #include "numkong/reduce.h" // `nk_reduce_moments_u1`
44
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
45
+ #include "numkong/reduce/neon.h" // `nk_reduce_moments_u1_neon`
46
+ #include "numkong/dots/sme.h" // `nk_sme_zero_za32_*`
46
47
 
47
48
  #if defined(__cplusplus)
48
49
  extern "C" {
@@ -100,7 +101,7 @@ NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data, nk_
100
101
  svbool_t predicate_b8x = svwhilelt_b8_u64(offset, n_bytes);
101
102
  acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_b8x, svld1_u8(predicate_b8x, data + offset)), ones_u8x);
102
103
  }
103
- return (nk_u32_t)svaddv_u32(svptrue_b32(), acc_u32x);
104
+ return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), acc_u32x);
104
105
  }
105
106
 
106
107
  #pragma region Hamming Distance
@@ -187,11 +188,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
187
188
  // Compute per-row population counts
188
189
  for (nk_size_t row = 0; row < row_count; row++) {
189
190
  nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
190
- {
191
- nk_u64_t nk_local_sum_, nk_local_sumsq_;
192
- nk_reduce_moments_u1(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
193
- norms_ptr[row] = (nk_u32_t)nk_local_sum_;
194
- }
191
+ nk_u64_t nk_local_sum_, nk_local_sumsq_;
192
+ nk_reduce_moments_u1_neon(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
193
+ norms_ptr[row] = (nk_u32_t)nk_local_sum_;
195
194
  }
196
195
  }
197
196
 
@@ -203,9 +202,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
203
202
  * Each ZA0.S batch covers 16 depth u32 steps (one full depth tile).
204
203
  * BMOPA expansion=1 for u32: each u32 contributes 32 bits via XNOR+POPCNT.
205
204
  */
206
- __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_(
205
+ __arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_( //
207
206
  nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
208
- nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
207
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
209
208
 
210
209
  nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
211
210
  nk_size_t const row_tile_count_b = header->row_tile_count;
@@ -344,11 +343,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
344
343
  }
345
344
  }
346
345
 
347
- NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c,
348
- nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
349
- nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
346
+ NK_PUBLIC void nk_hammings_packed_u1_smebi32( //
347
+ nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
348
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
349
+ nk_sme_start_streaming_();
350
350
  nk_hammings_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
351
351
  c_stride_in_bytes);
352
+ nk_sme_stop_streaming_();
352
353
  }
353
354
 
354
355
  /**
@@ -357,9 +358,9 @@ NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
357
358
  * ZA1-3.S = BMOPA accumulators (3 B column tiles in fast path).
358
359
  * Mirrors the unpacked kernel nk_hammings_packed_u1_smebi32_streaming_ pattern.
359
360
  */
360
- __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_(
361
+ __arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_( //
361
362
  nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
362
- nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
363
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
363
364
 
364
365
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
365
366
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
@@ -545,12 +546,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_sme
545
546
  }
546
547
  }
547
548
 
548
- NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
549
- nk_size_t stride_in_bytes, nk_u32_t *result,
550
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
551
- nk_size_t row_count) {
549
+ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32( //
550
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
551
+ nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
552
+ nk_sme_start_streaming_();
552
553
  nk_hammings_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
553
554
  result_stride_in_bytes, row_start, row_count);
555
+ nk_sme_stop_streaming_();
554
556
  }
555
557
 
556
558
  #pragma endregion Hamming Distance
@@ -581,9 +583,9 @@ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_siz
581
583
  * union = (norm_a + norm_b + hamming) / 2
582
584
  * jaccard = 1 - intersection / union (1.0 when union == 0)
583
585
  */
584
- __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_(
586
+ __arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_( //
585
587
  nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
586
- nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
588
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
587
589
 
588
590
  nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
589
591
  nk_size_t const row_tile_count_b = header->row_tile_count;
@@ -796,11 +798,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
796
798
  }
797
799
  }
798
800
 
799
- NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c,
800
- nk_size_t row_count_a, nk_size_t row_count_b, nk_size_t depth_bits,
801
- nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
801
+ NK_PUBLIC void nk_jaccards_packed_u1_smebi32( //
802
+ nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
803
+ nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
804
+ nk_sme_start_streaming_();
802
805
  nk_jaccards_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
803
806
  c_stride_in_bytes);
807
+ nk_sme_stop_streaming_();
804
808
  }
805
809
 
806
810
  /**
@@ -808,9 +812,9 @@ NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
808
812
  * Fills upper triangle only (column_tile >= row_tile); caller sees result[i][j] for j >= i.
809
813
  * Norms computed on-the-fly using streaming SVE popcount.
810
814
  */
811
- __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_(
815
+ __arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_( //
812
816
  nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
813
- nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
817
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
814
818
 
815
819
  nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
816
820
  nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
@@ -1104,12 +1108,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
1104
1108
  }
1105
1109
  }
1106
1110
 
1107
- NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits,
1108
- nk_size_t stride_in_bytes, nk_f32_t *result,
1109
- nk_size_t result_stride_in_bytes, nk_size_t row_start,
1110
- nk_size_t row_count) {
1111
+ NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32( //
1112
+ nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
1113
+ nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
1114
+ nk_sme_start_streaming_();
1111
1115
  nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
1112
1116
  result_stride_in_bytes, row_start, row_count);
1117
+ nk_sme_stop_streaming_();
1113
1118
  }
1114
1119
 
1115
1120
  #pragma endregion Jaccard Distance
@@ -17,7 +17,7 @@ extern "C" {
17
17
  #endif
18
18
 
19
19
  #define nk_define_sparse_intersect_(input_type) \
20
- NK_PUBLIC nk_size_t nk_sparse_intersect_##input_type##_galloping_search_( \
20
+ NK_INTERNAL nk_size_t nk_sparse_intersect_##input_type##_galloping_search_( \
21
21
  nk_##input_type##_t const *array, nk_size_t start, nk_size_t length, nk_##input_type##_t val) { \
22
22
  nk_size_t low = start; \
23
23
  nk_size_t high = start + 1; \
@@ -32,7 +32,7 @@ extern "C" {
32
32
  } \
33
33
  return low; \
34
34
  } \
35
- NK_PUBLIC nk_size_t nk_sparse_intersect_##input_type##_linear_scan_( \
35
+ NK_INTERNAL nk_size_t nk_sparse_intersect_##input_type##_linear_scan_( \
36
36
  nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_size_t a_length, nk_size_t b_length, \
37
37
  nk_##input_type##_t *result) { \
38
38
  nk_size_t intersection_size = 0; \
@@ -103,6 +103,15 @@ extern "C" {
103
103
  *product = weights_product; \
104
104
  }
105
105
 
106
+ /* Keep the serial instantiations below actually scalar, regardless of build type.
107
+ * See dots/serial.h for rationale. */
108
+ #if defined(__clang__)
109
+ #pragma clang attribute push(__attribute__((noinline)), apply_to = function)
110
+ #elif defined(__GNUC__)
111
+ #pragma GCC push_options
112
+ #pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
113
+ #endif
114
+
106
115
  nk_define_sparse_intersect_(u16) // nk_sparse_intersect_u16_serial
107
116
  nk_define_sparse_intersect_(u32) // nk_sparse_intersect_u32_serial
108
117
  nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
@@ -110,6 +119,12 @@ nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
110
119
  nk_define_sparse_dot_(u16, bf16, f32, nk_bf16_to_f32_serial) // nk_sparse_dot_u16bf16_serial
111
120
  nk_define_sparse_dot_(u32, f32, f64, nk_assign_from_to_) // nk_sparse_dot_u32f32_serial
112
121
 
122
+ #if defined(__clang__)
123
+ #pragma clang attribute pop
124
+ #elif defined(__GNUC__)
125
+ #pragma GCC pop_options
126
+ #endif
127
+
113
128
  #if defined(__cplusplus)
114
129
  } // extern "C"
115
130
  #endif
@@ -12,6 +12,7 @@
12
12
  #if NK_TARGET_ARM64_
13
13
 
14
14
  #include "numkong/types.h"
15
+ #include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
15
16
 
16
17
  #if defined(__cplusplus)
17
18
  extern "C" {
@@ -395,7 +396,7 @@ NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
395
396
  a_idx += a_step;
396
397
  b_idx += b_step;
397
398
  }
398
- *product = svaddv_f64(predicate_all_b64x, product_f64x);
399
+ *product = nk_svaddv_f64_(predicate_all_b64x, product_f64x);
399
400
  }
400
401
 
401
402
  #if defined(__clang__)
@@ -485,7 +486,7 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
485
486
  a_idx += a_step;
486
487
  b_idx += b_step;
487
488
  }
488
- *product = svaddv_f32(svptrue_b32(), product_f32x);
489
+ *product = nk_svaddv_f32_(svptrue_b32(), product_f32x);
489
490
  }
490
491
 
491
492
  #if defined(__clang__)
@@ -139,74 +139,6 @@ nk_angular_bf16_genoa_cycle:
139
139
  *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
140
140
  }
141
141
 
142
- NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
143
- __m512 a_sq_f32x16 = _mm512_setzero_ps();
144
- __m512 b_sq_f32x16 = _mm512_setzero_ps();
145
- __m512 ab_f32x16 = _mm512_setzero_ps();
146
- __m256i a_e5m2x32, b_e5m2x32;
147
-
148
- nk_sqeuclidean_e5m2_genoa_cycle:
149
- if (n < 32) {
150
- __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
151
- a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
152
- b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
153
- n = 0;
154
- }
155
- else {
156
- a_e5m2x32 = _mm256_loadu_epi8(a);
157
- b_e5m2x32 = _mm256_loadu_epi8(b);
158
- a += 32, b += 32, n -= 32;
159
- }
160
- __m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
161
- __m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
162
- a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
163
- b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
164
- ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
165
- if (n) goto nk_sqeuclidean_e5m2_genoa_cycle;
166
-
167
- // (a-b)² = a² + b² - 2ab
168
- __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
169
- *result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
170
- }
171
-
172
- NK_PUBLIC void nk_euclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
173
- nk_sqeuclidean_e5m2_genoa(a, b, n, result);
174
- *result = nk_f32_sqrt_haswell(*result);
175
- }
176
-
177
- NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
178
- __m512 dot_f32x16 = _mm512_setzero_ps();
179
- __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
180
- __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
181
- __m256i a_e5m2x32, b_e5m2x32;
182
-
183
- nk_angular_e5m2_genoa_cycle:
184
- if (n < 32) {
185
- __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
186
- a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
187
- b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
188
- n = 0;
189
- }
190
- else {
191
- a_e5m2x32 = _mm256_loadu_epi8(a);
192
- b_e5m2x32 = _mm256_loadu_epi8(b);
193
- a += 32, b += 32, n -= 32;
194
- }
195
- __m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
196
- __m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
197
- dot_f32x16 = _mm512_dpbf16_ps(dot_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
198
- a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
199
- nk_m512bh_from_m512i_(a_bf16x32));
200
- b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
201
- nk_m512bh_from_m512i_(b_bf16x32));
202
- if (n) goto nk_angular_e5m2_genoa_cycle;
203
-
204
- nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
205
- nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
206
- nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
207
- *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
208
- }
209
-
210
142
  #if defined(__clang__)
211
143
  #pragma clang attribute pop
212
144
  #elif defined(__GNUC__)