numkong 7.5.0 → 7.6.0

Files changed (48)
  1. package/binding.gyp +18 -0
  2. package/c/dispatch_e5m2.c +23 -3
  3. package/include/numkong/capabilities.h +1 -1
  4. package/include/numkong/cast/README.md +3 -0
  5. package/include/numkong/cast/haswell.h +28 -64
  6. package/include/numkong/cast/serial.h +17 -0
  7. package/include/numkong/cast/skylake.h +67 -52
  8. package/include/numkong/cast.h +1 -0
  9. package/include/numkong/dot/README.md +1 -0
  10. package/include/numkong/dot/haswell.h +92 -13
  11. package/include/numkong/dot/serial.h +15 -0
  12. package/include/numkong/dot/skylake.h +61 -14
  13. package/include/numkong/dots/README.md +2 -0
  14. package/include/numkong/dots/graniteamx.h +434 -0
  15. package/include/numkong/dots/haswell.h +28 -28
  16. package/include/numkong/dots/sapphireamx.h +1 -1
  17. package/include/numkong/dots/serial.h +23 -8
  18. package/include/numkong/dots/skylake.h +28 -23
  19. package/include/numkong/dots.h +12 -0
  20. package/include/numkong/each/serial.h +18 -1
  21. package/include/numkong/geospatial/serial.h +14 -3
  22. package/include/numkong/maxsim/serial.h +15 -0
  23. package/include/numkong/mesh/README.md +50 -44
  24. package/include/numkong/mesh/genoa.h +462 -0
  25. package/include/numkong/mesh/haswell.h +806 -933
  26. package/include/numkong/mesh/neon.h +871 -943
  27. package/include/numkong/mesh/neonbfdot.h +382 -522
  28. package/include/numkong/mesh/neonfhm.h +676 -0
  29. package/include/numkong/mesh/rvv.h +404 -319
  30. package/include/numkong/mesh/serial.h +204 -162
  31. package/include/numkong/mesh/skylake.h +1029 -1585
  32. package/include/numkong/mesh/v128relaxed.h +403 -377
  33. package/include/numkong/mesh.h +38 -0
  34. package/include/numkong/reduce/serial.h +15 -1
  35. package/include/numkong/sparse/serial.h +17 -2
  36. package/include/numkong/spatial/genoa.h +0 -68
  37. package/include/numkong/spatial/haswell.h +98 -56
  38. package/include/numkong/spatial/serial.h +15 -0
  39. package/include/numkong/spatial/skylake.h +114 -54
  40. package/include/numkong/spatial.h +0 -12
  41. package/include/numkong/spatials/graniteamx.h +128 -0
  42. package/include/numkong/spatials/serial.h +18 -1
  43. package/include/numkong/spatials/skylake.h +2 -2
  44. package/include/numkong/spatials.h +17 -0
  45. package/include/numkong/tensor.hpp +107 -23
  46. package/javascript/numkong.c +3 -2
  47. package/package.json +7 -7
  48. package/wasm/numkong.wasm +0 -0
package/include/numkong/mesh.h
@@ -266,6 +266,20 @@ NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, n
                                        nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
 #endif // NK_TARGET_SKYLAKE
 
+/* SIMD-powered backends for AVX512-BF16 CPUs of AMD Genoa / Intel Sapphire Rapids generation and newer.
+ */
+#if NK_TARGET_GENOA
+/** @copydoc nk_rmsd_bf16 */
+NK_PUBLIC void nk_rmsd_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+                                  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
+/** @copydoc nk_kabsch_bf16 */
+NK_PUBLIC void nk_kabsch_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+                                    nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
+/** @copydoc nk_umeyama_bf16 */
+NK_PUBLIC void nk_umeyama_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+                                     nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
+#endif // NK_TARGET_GENOA
+
 /* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer.
  */
 #if NK_TARGET_HASWELL
@@ -357,6 +371,20 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
                                          nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
 #endif // NK_TARGET_NEONBFDOT
 
+/* SIMD-powered backends for Arm NEON FHM (FP16 widening FMA) CPUs.
+ */
+#if NK_TARGET_NEONFHM
+/** @copydoc nk_rmsd_f16 */
+NK_PUBLIC void nk_rmsd_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+                                   nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
+/** @copydoc nk_kabsch_f16 */
+NK_PUBLIC void nk_kabsch_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+                                     nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
+/** @copydoc nk_umeyama_f16 */
+NK_PUBLIC void nk_umeyama_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
+                                      nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
+#endif // NK_TARGET_NEONFHM
+
 #if NK_TARGET_RVV
 /** @copydoc nk_rmsd_f32 */
 NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
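
For readers unfamiliar with the new NEONFHM target: "FHM" refers to Arm's FMLAL/FMLAL2 instructions, which multiply f16 lanes and accumulate straight into f32 lanes in a single step. A minimal illustrative sketch using standard ACLE intrinsics (compile with -march=armv8.2-a+fp16fml; the helper name is ours, not part of the package):

#include <arm_neon.h>

// Widening f16 dot-product step: products are formed and accumulated in f32,
// so intermediate results never round to f16 precision.
static inline float32x4_t dot_step_f16x8(float32x4_t acc, float16x8_t a, float16x8_t b) {
    acc = vfmlalq_low_f16(acc, a, b);  // acc += (f32)a[0..3] * (f32)b[0..3]
    acc = vfmlalq_high_f16(acc, a, b); // acc += (f32)a[4..7] * (f32)b[4..7]
    return acc;
}
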
@@ -454,8 +482,10 @@ NK_INTERNAL nk_dtype_t nk_mesh_transform_dtype(nk_dtype_t dtype) {
 #include "numkong/mesh/serial.h"
 #include "numkong/mesh/neon.h"
 #include "numkong/mesh/neonbfdot.h"
+#include "numkong/mesh/neonfhm.h"
 #include "numkong/mesh/haswell.h"
 #include "numkong/mesh/skylake.h"
+#include "numkong/mesh/genoa.h"
 #include "numkong/mesh/rvv.h"
 #include "numkong/mesh/v128relaxed.h"
 
@@ -505,6 +535,8 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
     nk_rmsd_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_HASWELL
     nk_rmsd_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
+#elif NK_TARGET_NEONFHM
+    nk_rmsd_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_NEON
     nk_rmsd_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_RVV
@@ -517,6 +549,8 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
 NK_PUBLIC void nk_rmsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                             nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
 #if NK_TARGET_SKYLAKE
+    // Skylake f32-widen path wins on Intel where VDPBF16PS throughput matches FMA; on AMD Zen4+
+    // where VDPBF16PS is faster than FMA, users can call `nk_rmsd_bf16_genoa` directly.
     nk_rmsd_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_HASWELL
     nk_rmsd_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
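
Since the compile-time dispatch above never prefers the Genoa kernel, callers targeting AMD Zen4+ can wrap it themselves, as the new comment suggests. A hedged sketch: only the nk_* symbols come from the headers in this diff, and the CPU-check helper is hypothetical:

#if NK_TARGET_GENOA
extern int my_cpu_has_fast_vdpbf16ps(void); // hypothetical runtime check, e.g. CPUID vendor/family
#endif

static void my_rmsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                         nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
#if NK_TARGET_GENOA
    // Prefer VDPBF16PS where it out-runs FMA, i.e. on AMD Zen4 and newer
    if (my_cpu_has_fast_vdpbf16ps()) {
        nk_rmsd_bf16_genoa(a, b, n, a_centroid, b_centroid, rotation, scale, result);
        return;
    }
#endif
    nk_rmsd_bf16(a, b, n, a_centroid, b_centroid, rotation, scale, result); // library default
}
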
@@ -569,6 +603,8 @@ NK_PUBLIC void nk_kabsch_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
     nk_kabsch_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_HASWELL
     nk_kabsch_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
+#elif NK_TARGET_NEONFHM
+    nk_kabsch_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_NEON
     nk_kabsch_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_RVV
@@ -633,6 +669,8 @@ NK_PUBLIC void nk_umeyama_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
     nk_umeyama_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_HASWELL
     nk_umeyama_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
+#elif NK_TARGET_NEONFHM
+    nk_umeyama_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_NEON
     nk_umeyama_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
 #elif NK_TARGET_RVV
package/include/numkong/reduce/serial.h
@@ -14,7 +14,6 @@
 #define NK_REDUCE_SERIAL_H
 
 #include "numkong/types.h"
-#include "numkong/scalar/serial.h"
 #include "numkong/cast/serial.h"
 #include "numkong/scalar/serial.h"
 
@@ -22,6 +21,15 @@
 extern "C" {
 #endif
 
+/* Keep the serial instantiations below actually scalar, regardless of build type.
+ * See dots/serial.h for rationale. */
+#if defined(__clang__)
+#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
+#elif defined(__GNUC__)
+#pragma GCC push_options
+#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
+#endif
+
 NK_INTERNAL nk_f64_t nk_reduce_sum_f64_serial_(nk_f64_t const *values, nk_f64_t const *compensations, int count) {
     nk_f64_t running_sum = 0, accumulated_error = 0;
     for (int i = 0; i < count; i++) {
@@ -746,6 +754,12 @@ NK_PUBLIC void nk_reduce_minmax_u1_serial( //
     *max_value_ptr = max_value, *max_index_ptr = max_idx;
 }
 
+#if defined(__clang__)
+#pragma clang attribute pop
+#elif defined(__GNUC__)
+#pragma GCC pop_options
+#endif
+
 #if defined(__cplusplus)
 } // extern "C"
 #endif
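
The push/pop pair added above (and repeated in sparse/serial.h and spatial/serial.h below) is a generic fence: GCC is told not to auto-vectorize or clone the reference loops, and Clang is told not to inline them into callers where they could be vectorized. A standalone sketch of the same fence around a toy kernel, assuming nothing beyond GCC/Clang:

#if defined(__clang__)
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize")
#endif

// Stays a genuine scalar loop, so SIMD backends get compared to a true serial baseline.
static float sum_f32_scalar(float const *x, int n) {
    float total = 0.0f;
    for (int i = 0; i < n; i++) total += x[i];
    return total;
}

#if defined(__clang__)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif
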
package/include/numkong/sparse/serial.h
@@ -17,7 +17,7 @@ extern "C" {
 #endif
 
 #define nk_define_sparse_intersect_(input_type) \
-    NK_PUBLIC nk_size_t nk_sparse_intersect_##input_type##_galloping_search_( \
+    NK_INTERNAL nk_size_t nk_sparse_intersect_##input_type##_galloping_search_( \
         nk_##input_type##_t const *array, nk_size_t start, nk_size_t length, nk_##input_type##_t val) { \
         nk_size_t low = start; \
         nk_size_t high = start + 1; \
@@ -32,7 +32,7 @@ extern "C" {
         } \
         return low; \
     } \
-    NK_PUBLIC nk_size_t nk_sparse_intersect_##input_type##_linear_scan_( \
+    NK_INTERNAL nk_size_t nk_sparse_intersect_##input_type##_linear_scan_( \
         nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_size_t a_length, nk_size_t b_length, \
         nk_##input_type##_t *result) { \
         nk_size_t intersection_size = 0; \
@@ -103,6 +103,15 @@ extern "C" {
         *product = weights_product; \
     }
 
+/* Keep the serial instantiations below actually scalar, regardless of build type.
+ * See dots/serial.h for rationale. */
+#if defined(__clang__)
+#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
+#elif defined(__GNUC__)
+#pragma GCC push_options
+#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
+#endif
+
 nk_define_sparse_intersect_(u16) // nk_sparse_intersect_u16_serial
 nk_define_sparse_intersect_(u32) // nk_sparse_intersect_u32_serial
 nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
@@ -110,6 +119,12 @@ nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
 nk_define_sparse_dot_(u16, bf16, f32, nk_bf16_to_f32_serial) // nk_sparse_dot_u16bf16_serial
 nk_define_sparse_dot_(u32, f32, f64, nk_assign_from_to_)     // nk_sparse_dot_u32f32_serial
 
+#if defined(__clang__)
+#pragma clang attribute pop
+#elif defined(__GNUC__)
+#pragma GCC pop_options
+#endif
+
 #if defined(__cplusplus)
 } // extern "C"
 #endif
package/include/numkong/spatial/genoa.h
@@ -139,74 +139,6 @@ nk_angular_bf16_genoa_cycle:
     *result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
 }
 
-NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
-    __m512 a_sq_f32x16 = _mm512_setzero_ps();
-    __m512 b_sq_f32x16 = _mm512_setzero_ps();
-    __m512 ab_f32x16 = _mm512_setzero_ps();
-    __m256i a_e5m2x32, b_e5m2x32;
-
-nk_sqeuclidean_e5m2_genoa_cycle:
-    if (n < 32) {
-        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
-        a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
-        b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
-        n = 0;
-    }
-    else {
-        a_e5m2x32 = _mm256_loadu_epi8(a);
-        b_e5m2x32 = _mm256_loadu_epi8(b);
-        a += 32, b += 32, n -= 32;
-    }
-    __m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
-    __m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
-    a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
-    b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
-    ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
-    if (n) goto nk_sqeuclidean_e5m2_genoa_cycle;
-
-    // (a-b)² = a² + b² - 2ab
-    __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
-    *result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
-}
-
-NK_PUBLIC void nk_euclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
-    nk_sqeuclidean_e5m2_genoa(a, b, n, result);
-    *result = nk_f32_sqrt_haswell(*result);
-}
-
-NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
-    __m512 dot_f32x16 = _mm512_setzero_ps();
-    __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
-    __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
-    __m256i a_e5m2x32, b_e5m2x32;
-
-nk_angular_e5m2_genoa_cycle:
-    if (n < 32) {
-        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
-        a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
-        b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
-        n = 0;
-    }
-    else {
-        a_e5m2x32 = _mm256_loadu_epi8(a);
-        b_e5m2x32 = _mm256_loadu_epi8(b);
-        a += 32, b += 32, n -= 32;
-    }
-    __m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
-    __m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
-    dot_f32x16 = _mm512_dpbf16_ps(dot_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
-    a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
-                                        nk_m512bh_from_m512i_(a_bf16x32));
-    b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
-                                        nk_m512bh_from_m512i_(b_bf16x32));
-    if (n) goto nk_angular_e5m2_genoa_cycle;
-
-    nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
-    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
-    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
-    *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
-}
-
 #if defined(__clang__)
 #pragma clang attribute pop
 #elif defined(__GNUC__)
package/include/numkong/spatial/haswell.h
@@ -840,28 +840,37 @@ nk_angular_e3m2_haswell_cycle:
 }
 
 NK_PUBLIC void nk_sqeuclidean_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
-    __m256 distance_sq_f32x8 = _mm256_setzero_ps();
+    // E4M3 has no free widen shift, so we call the Giesen-based 8-lane cast helper
+    // twice per 16-lane iter and run with two F32 accumulators to break the FMA chain.
+    __m256 first_acc_f32x8 = _mm256_setzero_ps();
+    __m256 second_acc_f32x8 = _mm256_setzero_ps();
+    __m128i a_u8x16, b_u8x16;
 
 nk_sqeuclidean_e4m3_haswell_cycle:
-    if (n < 8) {
+    if (n < 16) {
         nk_b128_vec_t a_vec, b_vec;
         nk_partial_load_b8x16_serial_(a, &a_vec, n);
         nk_partial_load_b8x16_serial_(b, &b_vec, n);
-        __m256 a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_vec.xmm);
-        __m256 b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_vec.xmm);
-        __m256 diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
-        distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
+        a_u8x16 = a_vec.xmm;
+        b_u8x16 = b_vec.xmm;
+        n = 0;
     }
     else {
-        __m256 a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)a));
-        __m256 b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)b));
-        __m256 diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
-        distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
-        n -= 8, a += 8, b += 8;
-        goto nk_sqeuclidean_e4m3_haswell_cycle;
+        a_u8x16 = _mm_loadu_si128((__m128i const *)a);
+        b_u8x16 = _mm_loadu_si128((__m128i const *)b);
+        a += 16, b += 16, n -= 16;
     }
+    __m256 a_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_u8x16);
+    __m256 a_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_u8x16, a_u8x16));
+    __m256 b_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_u8x16);
+    __m256 b_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_u8x16, b_u8x16));
+    __m256 diff_low_f32x8 = _mm256_sub_ps(a_low_f32x8, b_low_f32x8);
+    __m256 diff_high_f32x8 = _mm256_sub_ps(a_high_f32x8, b_high_f32x8);
+    first_acc_f32x8 = _mm256_fmadd_ps(diff_low_f32x8, diff_low_f32x8, first_acc_f32x8);
+    second_acc_f32x8 = _mm256_fmadd_ps(diff_high_f32x8, diff_high_f32x8, second_acc_f32x8);
+    if (n) goto nk_sqeuclidean_e4m3_haswell_cycle;
 
-    *result = nk_reduce_add_f32x8_haswell_(distance_sq_f32x8);
+    *result = nk_reduce_add_f32x8_haswell_(_mm256_add_ps(first_acc_f32x8, second_acc_f32x8));
 }
 
 NK_PUBLIC void nk_euclidean_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
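
The two-accumulator idea in the hunk above is independent of AVX2; a scalar illustration of why it helps (latency figures are approximate, and this toy function is not part of the package):

// One accumulator: every FMA waits on the previous value of its accumulator, so
// the loop runs at FMA *latency* (~5 cycles on Haswell). Two independent chains
// let consecutive steps' FMAs overlap, approaching FMA *throughput* instead.
static float sqeuclidean_two_chains(float const *a, float const *b, int n) { // n even, for brevity
    float acc0 = 0.0f, acc1 = 0.0f;
    for (int i = 0; i + 1 < n; i += 2) {
        float d0 = a[i] - b[i], d1 = a[i + 1] - b[i + 1];
        acc0 += d0 * d0; // chain 0
        acc1 += d1 * d1; // chain 1, no data dependency on chain 0
    }
    return acc0 + acc1;
}
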
@@ -873,27 +882,33 @@ NK_PUBLIC void nk_angular_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, n
     __m256 dot_product_f32x8 = _mm256_setzero_ps();
     __m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
     __m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
+    __m128i a_u8x16, b_u8x16;
 
 nk_angular_e4m3_haswell_cycle:
-    if (n < 8) {
+    if (n < 16) {
         nk_b128_vec_t a_vec, b_vec;
         nk_partial_load_b8x16_serial_(a, &a_vec, n);
         nk_partial_load_b8x16_serial_(b, &b_vec, n);
-        __m256 a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_vec.xmm);
-        __m256 b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_vec.xmm);
-        dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
-        a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
-        b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
+        a_u8x16 = a_vec.xmm;
+        b_u8x16 = b_vec.xmm;
+        n = 0;
     }
     else {
-        __m256 a_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)a));
-        __m256 b_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)b));
-        dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
-        a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
-        b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
-        n -= 8, a += 8, b += 8;
-        goto nk_angular_e4m3_haswell_cycle;
-    }
+        a_u8x16 = _mm_loadu_si128((__m128i const *)a);
+        b_u8x16 = _mm_loadu_si128((__m128i const *)b);
+        a += 16, b += 16, n -= 16;
+    }
+    __m256 a_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_u8x16);
+    __m256 a_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_u8x16, a_u8x16));
+    __m256 b_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_u8x16);
+    __m256 b_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_u8x16, b_u8x16));
+    dot_product_f32x8 = _mm256_fmadd_ps(a_low_f32x8, b_low_f32x8, dot_product_f32x8);
+    dot_product_f32x8 = _mm256_fmadd_ps(a_high_f32x8, b_high_f32x8, dot_product_f32x8);
+    a_norm_sq_f32x8 = _mm256_fmadd_ps(a_low_f32x8, a_low_f32x8, a_norm_sq_f32x8);
+    a_norm_sq_f32x8 = _mm256_fmadd_ps(a_high_f32x8, a_high_f32x8, a_norm_sq_f32x8);
+    b_norm_sq_f32x8 = _mm256_fmadd_ps(b_low_f32x8, b_low_f32x8, b_norm_sq_f32x8);
+    b_norm_sq_f32x8 = _mm256_fmadd_ps(b_high_f32x8, b_high_f32x8, b_norm_sq_f32x8);
+    if (n) goto nk_angular_e4m3_haswell_cycle;
 
     nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
     nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
904
919
  NK_PUBLIC void nk_sqeuclidean_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
905
- __m256 distance_sq_f32x8 = _mm256_setzero_ps();
920
+ // E5M2 shares F16's exponent bias (15): `byte << 8` equals the matching F16 encoding.
921
+ // `vpunpck*bw` against zero is the free widen+shift: zero byte in low half of each
922
+ // 16-bit lane, E5M2 byte in high half. Per-128-bit-lane scrambled; commutative sum
923
+ // reduction is invariant under that.
924
+ __m256 first_acc_f32x8 = _mm256_setzero_ps();
925
+ __m256 second_acc_f32x8 = _mm256_setzero_ps();
926
+ __m128i const zero_u8x16 = _mm_setzero_si128();
927
+ __m128i a_u8x16, b_u8x16;
906
928
 
907
929
  nk_sqeuclidean_e5m2_haswell_cycle:
908
- if (n < 8) {
930
+ if (n < 16) {
909
931
  nk_b128_vec_t a_vec, b_vec;
910
932
  nk_partial_load_b8x16_serial_(a, &a_vec, n);
911
933
  nk_partial_load_b8x16_serial_(b, &b_vec, n);
912
- __m256 a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(a_vec.xmm);
913
- __m256 b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(b_vec.xmm);
914
- __m256 diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
915
- distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
934
+ a_u8x16 = a_vec.xmm;
935
+ b_u8x16 = b_vec.xmm;
936
+ n = 0;
916
937
  }
917
938
  else {
918
- __m256 a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)a));
919
- __m256 b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)b));
920
- __m256 diff_f32x8 = _mm256_sub_ps(a_f32x8, b_f32x8);
921
- distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
922
- n -= 8, a += 8, b += 8;
923
- goto nk_sqeuclidean_e5m2_haswell_cycle;
924
- }
925
-
926
- *result = nk_reduce_add_f32x8_haswell_(distance_sq_f32x8);
939
+ a_u8x16 = _mm_loadu_si128((__m128i const *)a);
940
+ b_u8x16 = _mm_loadu_si128((__m128i const *)b);
941
+ a += 16, b += 16, n -= 16;
942
+ }
943
+ __m128i a_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_u8x16);
944
+ __m128i a_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_u8x16);
945
+ __m128i b_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_u8x16);
946
+ __m128i b_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_u8x16);
947
+ __m256 a_first_f32x8 = _mm256_cvtph_ps(a_even_f16x8);
948
+ __m256 a_second_f32x8 = _mm256_cvtph_ps(a_odd_f16x8);
949
+ __m256 b_first_f32x8 = _mm256_cvtph_ps(b_even_f16x8);
950
+ __m256 b_second_f32x8 = _mm256_cvtph_ps(b_odd_f16x8);
951
+ __m256 diff_first_f32x8 = _mm256_sub_ps(a_first_f32x8, b_first_f32x8);
952
+ __m256 diff_second_f32x8 = _mm256_sub_ps(a_second_f32x8, b_second_f32x8);
953
+ first_acc_f32x8 = _mm256_fmadd_ps(diff_first_f32x8, diff_first_f32x8, first_acc_f32x8);
954
+ second_acc_f32x8 = _mm256_fmadd_ps(diff_second_f32x8, diff_second_f32x8, second_acc_f32x8);
955
+ if (n) goto nk_sqeuclidean_e5m2_haswell_cycle;
956
+
957
+ *result = nk_reduce_add_f32x8_haswell_(_mm256_add_ps(first_acc_f32x8, second_acc_f32x8));
927
958
  }
928
959
 
929
960
  NK_PUBLIC void nk_euclidean_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
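
The bit trick the new comment relies on can be verified in scalar code: an E5M2 byte is exactly the high byte of the IEEE binary16 with the same value (same sign, 5-bit exponent at bias 15, top two mantissa bits). A sketch, assuming `_Float16` support as in recent GCC/Clang; the function is ours, for illustration only:

#include <stdint.h>
#include <string.h>

// Shifting the E5M2 byte into the high half of a 16-bit lane *is* the f16 cast;
// _mm_unpacklo_epi8 / _mm_unpackhi_epi8 against zero perform it for 8 lanes at once.
static float e5m2_to_f32(uint8_t e5m2_bits) {
    uint16_t f16_bits = (uint16_t)e5m2_bits << 8; // zero byte below, E5M2 byte above
    _Float16 h;
    memcpy(&h, &f16_bits, sizeof h);
    return (float)h;
}
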
@@ -935,27 +966,38 @@ NK_PUBLIC void nk_angular_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, n
     __m256 dot_product_f32x8 = _mm256_setzero_ps();
     __m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
     __m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
+    __m128i const zero_u8x16 = _mm_setzero_si128();
+    __m128i a_u8x16, b_u8x16;
 
 nk_angular_e5m2_haswell_cycle:
-    if (n < 8) {
+    if (n < 16) {
         nk_b128_vec_t a_vec, b_vec;
         nk_partial_load_b8x16_serial_(a, &a_vec, n);
         nk_partial_load_b8x16_serial_(b, &b_vec, n);
-        __m256 a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(a_vec.xmm);
-        __m256 b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(b_vec.xmm);
-        dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
-        a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
-        b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
+        a_u8x16 = a_vec.xmm;
+        b_u8x16 = b_vec.xmm;
+        n = 0;
     }
     else {
-        __m256 a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)a));
-        __m256 b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)b));
-        dot_product_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, dot_product_f32x8);
-        a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
-        b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
-        n -= 8, a += 8, b += 8;
-        goto nk_angular_e5m2_haswell_cycle;
-    }
+        a_u8x16 = _mm_loadu_si128((__m128i const *)a);
+        b_u8x16 = _mm_loadu_si128((__m128i const *)b);
+        a += 16, b += 16, n -= 16;
+    }
+    __m128i a_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_u8x16);
+    __m128i a_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_u8x16);
+    __m128i b_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_u8x16);
+    __m128i b_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_u8x16);
+    __m256 a_first_f32x8 = _mm256_cvtph_ps(a_even_f16x8);
+    __m256 a_second_f32x8 = _mm256_cvtph_ps(a_odd_f16x8);
+    __m256 b_first_f32x8 = _mm256_cvtph_ps(b_even_f16x8);
+    __m256 b_second_f32x8 = _mm256_cvtph_ps(b_odd_f16x8);
+    dot_product_f32x8 = _mm256_fmadd_ps(a_first_f32x8, b_first_f32x8, dot_product_f32x8);
+    dot_product_f32x8 = _mm256_fmadd_ps(a_second_f32x8, b_second_f32x8, dot_product_f32x8);
+    a_norm_sq_f32x8 = _mm256_fmadd_ps(a_first_f32x8, a_first_f32x8, a_norm_sq_f32x8);
+    a_norm_sq_f32x8 = _mm256_fmadd_ps(a_second_f32x8, a_second_f32x8, a_norm_sq_f32x8);
+    b_norm_sq_f32x8 = _mm256_fmadd_ps(b_first_f32x8, b_first_f32x8, b_norm_sq_f32x8);
+    b_norm_sq_f32x8 = _mm256_fmadd_ps(b_second_f32x8, b_second_f32x8, b_norm_sq_f32x8);
+    if (n) goto nk_angular_e5m2_haswell_cycle;
 
     nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
     nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
package/include/numkong/spatial/serial.h
@@ -108,6 +108,15 @@ extern "C" {
         } \
     }
 
+/* Keep the serial instantiations below actually scalar, regardless of build type.
+ * See dots/serial.h for rationale. */
+#if defined(__clang__)
+#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
+#elif defined(__GNUC__)
+#pragma GCC push_options
+#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
+#endif
+
 nk_define_angular_(f64, f64, f64, nk_assign_from_to_, nk_f64_rsqrt_serial)       // nk_angular_f64_serial
 nk_define_sqeuclidean_(f64, f64, f64, nk_assign_from_to_)                        // nk_sqeuclidean_f64_serial
 nk_define_euclidean_(f64, f64, f64, f64, nk_assign_from_to_, nk_f64_sqrt_serial) // nk_euclidean_f64_serial
@@ -340,6 +349,12 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_serial_(nk_b128_vec_t dots, n
     }
 }
 
+#if defined(__clang__)
+#pragma clang attribute pop
+#elif defined(__GNUC__)
+#pragma GCC pop_options
+#endif
+
 #if defined(__cplusplus)
 } // extern "C"
 #endif