numkong 7.4.5 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +99 -5
- package/c/dispatch_e5m2.c +23 -3
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +1167 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +33 -11
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +41 -3
- package/include/numkong/each/serial.h +39 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +15 -4
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +225 -161
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +301 -0
- package/include/numkong/spatials/serial.h +39 -0
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +54 -4
- package/include/numkong/tensor.hpp +107 -23
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +59 -14
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
package/include/numkong/mesh.h
CHANGED
|
@@ -266,6 +266,20 @@ NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
266
266
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
267
267
|
#endif // NK_TARGET_SKYLAKE
|
|
268
268
|
|
|
269
|
+
/* SIMD-powered backends for AVX512-BF16 CPUs of AMD Genoa / Intel Sapphire Rapids generation and newer.
|
|
270
|
+
*/
|
|
271
|
+
#if NK_TARGET_GENOA
|
|
272
|
+
/** @copydoc nk_rmsd_bf16 */
|
|
273
|
+
NK_PUBLIC void nk_rmsd_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
274
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
275
|
+
/** @copydoc nk_kabsch_bf16 */
|
|
276
|
+
NK_PUBLIC void nk_kabsch_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
277
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
278
|
+
/** @copydoc nk_umeyama_bf16 */
|
|
279
|
+
NK_PUBLIC void nk_umeyama_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
280
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
281
|
+
#endif // NK_TARGET_GENOA
|
|
282
|
+
|
|
269
283
|
/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer.
|
|
270
284
|
*/
|
|
271
285
|
#if NK_TARGET_HASWELL
|
|
@@ -357,6 +371,20 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
357
371
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
358
372
|
#endif // NK_TARGET_NEONBFDOT
|
|
359
373
|
|
|
374
|
+
/* SIMD-powered backends for Arm NEON FHM (FP16 widening FMA) CPUs.
|
|
375
|
+
*/
|
|
376
|
+
#if NK_TARGET_NEONFHM
|
|
377
|
+
/** @copydoc nk_rmsd_f16 */
|
|
378
|
+
NK_PUBLIC void nk_rmsd_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
379
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
380
|
+
/** @copydoc nk_kabsch_f16 */
|
|
381
|
+
NK_PUBLIC void nk_kabsch_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
382
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
383
|
+
/** @copydoc nk_umeyama_f16 */
|
|
384
|
+
NK_PUBLIC void nk_umeyama_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
385
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
386
|
+
#endif // NK_TARGET_NEONFHM
|
|
387
|
+
|
|
360
388
|
#if NK_TARGET_RVV
|
|
361
389
|
/** @copydoc nk_rmsd_f32 */
|
|
362
390
|
NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
@@ -454,8 +482,10 @@ NK_INTERNAL nk_dtype_t nk_mesh_transform_dtype(nk_dtype_t dtype) {
|
|
|
454
482
|
#include "numkong/mesh/serial.h"
|
|
455
483
|
#include "numkong/mesh/neon.h"
|
|
456
484
|
#include "numkong/mesh/neonbfdot.h"
|
|
485
|
+
#include "numkong/mesh/neonfhm.h"
|
|
457
486
|
#include "numkong/mesh/haswell.h"
|
|
458
487
|
#include "numkong/mesh/skylake.h"
|
|
488
|
+
#include "numkong/mesh/genoa.h"
|
|
459
489
|
#include "numkong/mesh/rvv.h"
|
|
460
490
|
#include "numkong/mesh/v128relaxed.h"
|
|
461
491
|
|
|
@@ -505,6 +535,8 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
|
|
|
505
535
|
nk_rmsd_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
506
536
|
#elif NK_TARGET_HASWELL
|
|
507
537
|
nk_rmsd_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
538
|
+
#elif NK_TARGET_NEONFHM
|
|
539
|
+
nk_rmsd_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
508
540
|
#elif NK_TARGET_NEON
|
|
509
541
|
nk_rmsd_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
510
542
|
#elif NK_TARGET_RVV
|
|
@@ -517,6 +549,8 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
|
|
|
517
549
|
NK_PUBLIC void nk_rmsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
518
550
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
519
551
|
#if NK_TARGET_SKYLAKE
|
|
552
|
+
// Skylake f32-widen path wins on Intel where VDPBF16PS throughput matches FMA; on AMD Zen4+
|
|
553
|
+
// where VDPBF16PS is faster than FMA, users can call `nk_rmsd_bf16_genoa` directly.
|
|
520
554
|
nk_rmsd_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
521
555
|
#elif NK_TARGET_HASWELL
|
|
522
556
|
nk_rmsd_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
@@ -569,6 +603,8 @@ NK_PUBLIC void nk_kabsch_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
|
569
603
|
nk_kabsch_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
570
604
|
#elif NK_TARGET_HASWELL
|
|
571
605
|
nk_kabsch_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
606
|
+
#elif NK_TARGET_NEONFHM
|
|
607
|
+
nk_kabsch_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
572
608
|
#elif NK_TARGET_NEON
|
|
573
609
|
nk_kabsch_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
574
610
|
#elif NK_TARGET_RVV
|
|
@@ -633,6 +669,8 @@ NK_PUBLIC void nk_umeyama_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
|
633
669
|
nk_umeyama_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
634
670
|
#elif NK_TARGET_HASWELL
|
|
635
671
|
nk_umeyama_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
672
|
+
#elif NK_TARGET_NEONFHM
|
|
673
|
+
nk_umeyama_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
636
674
|
#elif NK_TARGET_NEON
|
|
637
675
|
nk_umeyama_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
638
676
|
#elif NK_TARGET_RVV
|
|
@@ -3936,6 +3936,35 @@ NK_PUBLIC void nk_reduce_moments_f16_neon( //
|
|
|
3936
3936
|
else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3937
3937
|
}
|
|
3938
3938
|
|
|
3939
|
+
NK_INTERNAL void nk_reduce_moments_u1_neon_contiguous_( //
|
|
3940
|
+
nk_u1x8_t const *data_ptr, nk_size_t count, //
|
|
3941
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3942
|
+
nk_size_t byte_count = nk_size_divide_round_up_(count, NK_BITS_PER_BYTE);
|
|
3943
|
+
nk_u64_t sum = 0;
|
|
3944
|
+
nk_size_t idx = 0;
|
|
3945
|
+
// Each vcntq_u8 produces values 0-8 per lane; accumulate at u8 level
|
|
3946
|
+
// for up to 31 iterations (31 × 8 = 248, fits in u8) before widening.
|
|
3947
|
+
while (idx + 16 <= byte_count) {
|
|
3948
|
+
uint8x16_t popcount_u8x16 = vdupq_n_u8(0);
|
|
3949
|
+
for (nk_size_t cycle = 0; cycle < 31 && idx + 16 <= byte_count; ++cycle, idx += 16) {
|
|
3950
|
+
uint8x16_t data_u8x16 = vld1q_u8((nk_u8_t const *)data_ptr + idx);
|
|
3951
|
+
popcount_u8x16 = vaddq_u8(popcount_u8x16, vcntq_u8(data_u8x16));
|
|
3952
|
+
}
|
|
3953
|
+
sum += (nk_u64_t)vaddlvq_u8(popcount_u8x16);
|
|
3954
|
+
}
|
|
3955
|
+
for (; idx < byte_count; ++idx) sum += nk_u1x8_popcount_(((nk_u8_t const *)data_ptr)[idx]);
|
|
3956
|
+
*sum_ptr = sum, *sumsq_ptr = sum;
|
|
3957
|
+
}
|
|
3958
|
+
|
|
3959
|
+
NK_PUBLIC void nk_reduce_moments_u1_neon( //
|
|
3960
|
+
nk_u1x8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3961
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3962
|
+
count = nk_size_round_up_to_multiple_(count, 8);
|
|
3963
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3964
|
+
else if (stride_bytes == 1) nk_reduce_moments_u1_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3965
|
+
else nk_reduce_moments_u1_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3966
|
+
}
|
|
3967
|
+
|
|
3939
3968
|
#if defined(__clang__)
|
|
3940
3969
|
#pragma clang attribute pop
|
|
3941
3970
|
#elif defined(__GNUC__)
|
|
@@ -33,7 +33,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_contiguous_( //
|
|
|
33
33
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
34
34
|
|
|
35
35
|
// bf16 representation of 1.0 is 0x3F80 (same as upper 16 bits of f32 1.0)
|
|
36
|
-
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(
|
|
36
|
+
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
|
|
37
37
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
38
38
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
39
39
|
nk_size_t idx = 0;
|
|
@@ -61,7 +61,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_strided_( //
|
|
|
61
61
|
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
62
62
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
63
63
|
|
|
64
|
-
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(
|
|
64
|
+
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
|
|
65
65
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
66
66
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
67
67
|
nk_size_t idx = 0;
|
|
@@ -34,7 +34,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_contiguous_( //
|
|
|
34
34
|
|
|
35
35
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
36
36
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
37
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
37
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
38
38
|
nk_size_t idx = 0;
|
|
39
39
|
|
|
40
40
|
for (; idx + 8 <= count; idx += 8) {
|
|
@@ -67,7 +67,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
|
|
|
67
67
|
|
|
68
68
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
69
69
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
70
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
70
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
71
71
|
nk_size_t idx = 0;
|
|
72
72
|
|
|
73
73
|
if (stride_elements == 2) {
|
|
@@ -159,7 +159,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_contiguous_( //
|
|
|
159
159
|
|
|
160
160
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
161
161
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
162
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
162
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
163
163
|
nk_size_t idx = 0;
|
|
164
164
|
|
|
165
165
|
for (; idx + 8 <= count; idx += 8) {
|
|
@@ -192,7 +192,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
|
|
|
192
192
|
|
|
193
193
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
194
194
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
195
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
195
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
196
196
|
nk_size_t idx = 0;
|
|
197
197
|
|
|
198
198
|
if (stride_elements == 2) {
|
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
#define NK_REDUCE_SERIAL_H
|
|
15
15
|
|
|
16
16
|
#include "numkong/types.h"
|
|
17
|
-
#include "numkong/scalar/serial.h"
|
|
18
17
|
#include "numkong/cast/serial.h"
|
|
19
18
|
#include "numkong/scalar/serial.h"
|
|
20
19
|
|
|
@@ -22,6 +21,15 @@
|
|
|
22
21
|
extern "C" {
|
|
23
22
|
#endif
|
|
24
23
|
|
|
24
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
25
|
+
* See dots/serial.h for rationale. */
|
|
26
|
+
#if defined(__clang__)
|
|
27
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
28
|
+
#elif defined(__GNUC__)
|
|
29
|
+
#pragma GCC push_options
|
|
30
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
31
|
+
#endif
|
|
32
|
+
|
|
25
33
|
NK_INTERNAL nk_f64_t nk_reduce_sum_f64_serial_(nk_f64_t const *values, nk_f64_t const *compensations, int count) {
|
|
26
34
|
nk_f64_t running_sum = 0, accumulated_error = 0;
|
|
27
35
|
for (int i = 0; i < count; i++) {
|
|
@@ -746,6 +754,12 @@ NK_PUBLIC void nk_reduce_minmax_u1_serial( //
|
|
|
746
754
|
*max_value_ptr = max_value, *max_index_ptr = max_idx;
|
|
747
755
|
}
|
|
748
756
|
|
|
757
|
+
#if defined(__clang__)
|
|
758
|
+
#pragma clang attribute pop
|
|
759
|
+
#elif defined(__GNUC__)
|
|
760
|
+
#pragma GCC pop_options
|
|
761
|
+
#endif
|
|
762
|
+
|
|
749
763
|
#if defined(__cplusplus)
|
|
750
764
|
} // extern "C"
|
|
751
765
|
#endif
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SVE horizontal reduction helpers with MSan unpoisoning.
|
|
3
|
+
* @file include/numkong/reduce/sve.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date April 12, 2026
|
|
6
|
+
*
|
|
7
|
+
* LLVM's MSan does not instrument ARM SVE intrinsics — `svaddv` moves data
|
|
8
|
+
* from vector to scalar registers via architecture-specific paths invisible
|
|
9
|
+
* to the compiler, causing false-positive uninitialized-value reports.
|
|
10
|
+
* These macros wrap the reduction and unpoison the scalar result.
|
|
11
|
+
*
|
|
12
|
+
* The `svaddv` intrinsic stays inside a macro so it expands in the caller's
|
|
13
|
+
* target context — SVE and SME streaming translation units carry incompatible
|
|
14
|
+
* target attributes. The unpoisoning runs on the already-reduced scalar, so it
|
|
15
|
+
* lives in a target-agnostic `NK_INTERNAL` helper called from the macro.
|
|
16
|
+
*
|
|
17
|
+
* @sa include/numkong/reduce.h
|
|
18
|
+
*/
|
|
19
|
+
#ifndef NK_REDUCE_SVE_H
|
|
20
|
+
#define NK_REDUCE_SVE_H
|
|
21
|
+
|
|
22
|
+
#if NK_TARGET_ARM64_
|
|
23
|
+
#if NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
|
|
24
|
+
|
|
25
|
+
#include "numkong/types.h"
|
|
26
|
+
|
|
27
|
+
NK_INTERNAL nk_f64_t nk_unpoison_f64_(nk_f64_t v) NK_STREAMING_COMPATIBLE_ {
|
|
28
|
+
nk_unpoison_(&v, sizeof(v));
|
|
29
|
+
return v;
|
|
30
|
+
}
|
|
31
|
+
NK_INTERNAL nk_f32_t nk_unpoison_f32_(nk_f32_t v) NK_STREAMING_COMPATIBLE_ {
|
|
32
|
+
nk_unpoison_(&v, sizeof(v));
|
|
33
|
+
return v;
|
|
34
|
+
}
|
|
35
|
+
NK_INTERNAL nk_u64_t nk_unpoison_u64_(nk_u64_t v) NK_STREAMING_COMPATIBLE_ {
|
|
36
|
+
nk_unpoison_(&v, sizeof(v));
|
|
37
|
+
return v;
|
|
38
|
+
}
|
|
39
|
+
NK_INTERNAL nk_i64_t nk_unpoison_i64_(nk_i64_t v) NK_STREAMING_COMPATIBLE_ {
|
|
40
|
+
nk_unpoison_(&v, sizeof(v));
|
|
41
|
+
return v;
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
#define nk_svaddv_f64_(predicate, vector) nk_unpoison_f64_(svaddv_f64((predicate), (vector)))
|
|
45
|
+
#define nk_svaddv_f32_(predicate, vector) nk_unpoison_f32_(svaddv_f32((predicate), (vector)))
|
|
46
|
+
#define nk_svaddv_u32_(predicate, vector) nk_unpoison_u64_(svaddv_u32((predicate), (vector)))
|
|
47
|
+
#define nk_svaddv_s32_(predicate, vector) nk_unpoison_i64_(svaddv_s32((predicate), (vector)))
|
|
48
|
+
#define nk_svaddv_u8_(predicate, vector) nk_unpoison_u64_(svaddv_u8((predicate), (vector)))
|
|
49
|
+
|
|
50
|
+
#endif // NK_TARGET_SVE || NK_TARGET_SVE2 || NK_TARGET_SME
|
|
51
|
+
#endif // NK_TARGET_ARM64_
|
|
52
|
+
#endif // NK_REDUCE_SVE_H
|
package/include/numkong/reduce.h
CHANGED
|
@@ -389,6 +389,8 @@ NK_PUBLIC void nk_reduce_moments_i16_neon(nk_i16_t const *, nk_size_t, nk_size_t
|
|
|
389
389
|
/** @copydoc nk_reduce_moments_f64 */
|
|
390
390
|
NK_PUBLIC void nk_reduce_moments_u16_neon(nk_u16_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
|
|
391
391
|
/** @copydoc nk_reduce_moments_f64 */
|
|
392
|
+
NK_PUBLIC void nk_reduce_moments_u1_neon(nk_u1x8_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
|
|
393
|
+
/** @copydoc nk_reduce_moments_f64 */
|
|
392
394
|
NK_PUBLIC void nk_reduce_moments_i32_neon(nk_i32_t const *, nk_size_t, nk_size_t, nk_i64_t *, nk_u64_t *);
|
|
393
395
|
/** @copydoc nk_reduce_moments_f64 */
|
|
394
396
|
NK_PUBLIC void nk_reduce_moments_u32_neon(nk_u32_t const *, nk_size_t, nk_size_t, nk_u64_t *, nk_u64_t *);
|
|
@@ -1559,6 +1561,8 @@ NK_PUBLIC void nk_reduce_moments_u1(nk_u1x8_t const *d, nk_size_t n, nk_size_t s
|
|
|
1559
1561
|
nk_reduce_moments_u1_skylake(d, n, s, sum, sumsq);
|
|
1560
1562
|
#elif NK_TARGET_HASWELL
|
|
1561
1563
|
nk_reduce_moments_u1_haswell(d, n, s, sum, sumsq);
|
|
1564
|
+
#elif NK_TARGET_NEON
|
|
1565
|
+
nk_reduce_moments_u1_neon(d, n, s, sum, sumsq);
|
|
1562
1566
|
#else
|
|
1563
1567
|
nk_reduce_moments_u1_serial(d, n, s, sum, sumsq);
|
|
1564
1568
|
#endif
|
|
@@ -32,8 +32,9 @@
|
|
|
32
32
|
#if NK_TARGET_ARM64_
|
|
33
33
|
#if NK_TARGET_SVE
|
|
34
34
|
|
|
35
|
-
#include "numkong/types.h"
|
|
36
|
-
#include "numkong/
|
|
35
|
+
#include "numkong/types.h" // `nk_u1x8_t`
|
|
36
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
37
|
+
#include "numkong/set/neon.h" // `nk_hamming_u1_neon`
|
|
37
38
|
|
|
38
39
|
#if defined(__cplusplus)
|
|
39
40
|
extern "C" {
|
|
@@ -73,7 +74,7 @@ NK_PUBLIC void nk_hamming_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
|
|
|
73
74
|
i += words_per_register;
|
|
74
75
|
++cycle;
|
|
75
76
|
} while (i < n_bytes && cycle < 31);
|
|
76
|
-
differences +=
|
|
77
|
+
differences += nk_svaddv_u8_(all_predicate_b8x, popcount_u8x);
|
|
77
78
|
popcount_u8x = svdup_n_u8(0);
|
|
78
79
|
cycle = 0; // Reset the cycle counter.
|
|
79
80
|
}
|
|
@@ -110,9 +111,9 @@ NK_PUBLIC void nk_jaccard_u1_sve(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size
|
|
|
110
111
|
i += words_per_register;
|
|
111
112
|
++cycle;
|
|
112
113
|
} while (i < n_bytes && cycle < 31);
|
|
113
|
-
intersection_count +=
|
|
114
|
+
intersection_count += nk_svaddv_u8_(all_predicate_b8x, intersection_popcount_u8x);
|
|
114
115
|
intersection_popcount_u8x = svdup_n_u8(0);
|
|
115
|
-
union_count +=
|
|
116
|
+
union_count += nk_svaddv_u8_(all_predicate_b8x, union_popcount_u8x);
|
|
116
117
|
union_popcount_u8x = svdup_n_u8(0);
|
|
117
118
|
cycle = 0; // Reset the cycle counter.
|
|
118
119
|
}
|
|
@@ -41,8 +41,9 @@
|
|
|
41
41
|
#include "numkong/types.h"
|
|
42
42
|
#include "numkong/set/serial.h"
|
|
43
43
|
#include "numkong/sets/serial.h"
|
|
44
|
-
#include "numkong/
|
|
45
|
-
#include "numkong/reduce.h"
|
|
44
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
45
|
+
#include "numkong/reduce/neon.h" // `nk_reduce_moments_u1_neon`
|
|
46
|
+
#include "numkong/dots/sme.h" // `nk_sme_zero_za32_*`
|
|
46
47
|
|
|
47
48
|
#if defined(__cplusplus)
|
|
48
49
|
extern "C" {
|
|
@@ -100,7 +101,7 @@ NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data, nk_
|
|
|
100
101
|
svbool_t predicate_b8x = svwhilelt_b8_u64(offset, n_bytes);
|
|
101
102
|
acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_b8x, svld1_u8(predicate_b8x, data + offset)), ones_u8x);
|
|
102
103
|
}
|
|
103
|
-
return (nk_u32_t)
|
|
104
|
+
return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), acc_u32x);
|
|
104
105
|
}
|
|
105
106
|
|
|
106
107
|
#pragma region Hamming Distance
|
|
@@ -187,11 +188,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
|
|
|
187
188
|
// Compute per-row population counts
|
|
188
189
|
for (nk_size_t row = 0; row < row_count; row++) {
|
|
189
190
|
nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
norms_ptr[row] = (nk_u32_t)nk_local_sum_;
|
|
194
|
-
}
|
|
191
|
+
nk_u64_t nk_local_sum_, nk_local_sumsq_;
|
|
192
|
+
nk_reduce_moments_u1_neon(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
|
|
193
|
+
norms_ptr[row] = (nk_u32_t)nk_local_sum_;
|
|
195
194
|
}
|
|
196
195
|
}
|
|
197
196
|
|
|
@@ -203,9 +202,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
|
|
|
203
202
|
* Each ZA0.S batch covers 16 depth u32 steps (one full depth tile).
|
|
204
203
|
* BMOPA expansion=1 for u32: each u32 contributes 32 bits via XNOR+POPCNT.
|
|
205
204
|
*/
|
|
206
|
-
|
|
205
|
+
__arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_( //
|
|
207
206
|
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
208
|
-
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
207
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
|
|
209
208
|
|
|
210
209
|
nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
|
|
211
210
|
nk_size_t const row_tile_count_b = header->row_tile_count;
|
|
@@ -344,11 +343,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
|
|
|
344
343
|
}
|
|
345
344
|
}
|
|
346
345
|
|
|
347
|
-
NK_PUBLIC void nk_hammings_packed_u1_smebi32(
|
|
348
|
-
|
|
349
|
-
|
|
346
|
+
NK_PUBLIC void nk_hammings_packed_u1_smebi32( //
|
|
347
|
+
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
348
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
349
|
+
nk_sme_start_streaming_();
|
|
350
350
|
nk_hammings_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
|
|
351
351
|
c_stride_in_bytes);
|
|
352
|
+
nk_sme_stop_streaming_();
|
|
352
353
|
}
|
|
353
354
|
|
|
354
355
|
/**
|
|
@@ -357,9 +358,9 @@ NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
|
|
|
357
358
|
* ZA1-3.S = BMOPA accumulators (3 B column tiles in fast path).
|
|
358
359
|
* Mirrors the unpacked kernel nk_hammings_packed_u1_smebi32_streaming_ pattern.
|
|
359
360
|
*/
|
|
360
|
-
|
|
361
|
+
__arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_( //
|
|
361
362
|
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
362
|
-
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
363
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
363
364
|
|
|
364
365
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
365
366
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
@@ -545,12 +546,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_sme
|
|
|
545
546
|
}
|
|
546
547
|
}
|
|
547
548
|
|
|
548
|
-
NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
549
|
+
NK_PUBLIC void nk_hammings_symmetric_u1_smebi32( //
|
|
550
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
551
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
552
|
+
nk_sme_start_streaming_();
|
|
552
553
|
nk_hammings_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
553
554
|
result_stride_in_bytes, row_start, row_count);
|
|
555
|
+
nk_sme_stop_streaming_();
|
|
554
556
|
}
|
|
555
557
|
|
|
556
558
|
#pragma endregion Hamming Distance
|
|
@@ -581,9 +583,9 @@ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_siz
|
|
|
581
583
|
* union = (norm_a + norm_b + hamming) / 2
|
|
582
584
|
* jaccard = 1 - intersection / union (1.0 when union == 0)
|
|
583
585
|
*/
|
|
584
|
-
|
|
586
|
+
__arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_( //
|
|
585
587
|
nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
586
|
-
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
588
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
|
|
587
589
|
|
|
588
590
|
nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
|
|
589
591
|
nk_size_t const row_tile_count_b = header->row_tile_count;
|
|
@@ -796,11 +798,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
|
|
|
796
798
|
}
|
|
797
799
|
}
|
|
798
800
|
|
|
799
|
-
NK_PUBLIC void nk_jaccards_packed_u1_smebi32(
|
|
800
|
-
|
|
801
|
-
|
|
801
|
+
NK_PUBLIC void nk_jaccards_packed_u1_smebi32( //
|
|
802
|
+
nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
803
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
804
|
+
nk_sme_start_streaming_();
|
|
802
805
|
nk_jaccards_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
|
|
803
806
|
c_stride_in_bytes);
|
|
807
|
+
nk_sme_stop_streaming_();
|
|
804
808
|
}
|
|
805
809
|
|
|
806
810
|
/**
|
|
@@ -808,9 +812,9 @@ NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
|
|
|
808
812
|
* Fills upper triangle only (column_tile >= row_tile); caller sees result[i][j] for j >= i.
|
|
809
813
|
* Norms computed on-the-fly using streaming SVE popcount.
|
|
810
814
|
*/
|
|
811
|
-
|
|
815
|
+
__arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_( //
|
|
812
816
|
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
813
|
-
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
817
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
814
818
|
|
|
815
819
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
816
820
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
@@ -1104,12 +1108,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
|
|
|
1104
1108
|
}
|
|
1105
1109
|
}
|
|
1106
1110
|
|
|
1107
|
-
NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
+
NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32( //
|
|
1112
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
1113
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1114
|
+
nk_sme_start_streaming_();
|
|
1111
1115
|
nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
1112
1116
|
result_stride_in_bytes, row_start, row_count);
|
|
1117
|
+
nk_sme_stop_streaming_();
|
|
1113
1118
|
}
|
|
1114
1119
|
|
|
1115
1120
|
#pragma endregion Jaccard Distance
|
|
@@ -17,7 +17,7 @@ extern "C" {
|
|
|
17
17
|
#endif
|
|
18
18
|
|
|
19
19
|
#define nk_define_sparse_intersect_(input_type) \
|
|
20
|
-
|
|
20
|
+
NK_INTERNAL nk_size_t nk_sparse_intersect_##input_type##_galloping_search_( \
|
|
21
21
|
nk_##input_type##_t const *array, nk_size_t start, nk_size_t length, nk_##input_type##_t val) { \
|
|
22
22
|
nk_size_t low = start; \
|
|
23
23
|
nk_size_t high = start + 1; \
|
|
@@ -32,7 +32,7 @@ extern "C" {
|
|
|
32
32
|
} \
|
|
33
33
|
return low; \
|
|
34
34
|
} \
|
|
35
|
-
|
|
35
|
+
NK_INTERNAL nk_size_t nk_sparse_intersect_##input_type##_linear_scan_( \
|
|
36
36
|
nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_size_t a_length, nk_size_t b_length, \
|
|
37
37
|
nk_##input_type##_t *result) { \
|
|
38
38
|
nk_size_t intersection_size = 0; \
|
|
@@ -103,6 +103,15 @@ extern "C" {
|
|
|
103
103
|
*product = weights_product; \
|
|
104
104
|
}
|
|
105
105
|
|
|
106
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
107
|
+
* See dots/serial.h for rationale. */
|
|
108
|
+
#if defined(__clang__)
|
|
109
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
110
|
+
#elif defined(__GNUC__)
|
|
111
|
+
#pragma GCC push_options
|
|
112
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
113
|
+
#endif
|
|
114
|
+
|
|
106
115
|
nk_define_sparse_intersect_(u16) // nk_sparse_intersect_u16_serial
|
|
107
116
|
nk_define_sparse_intersect_(u32) // nk_sparse_intersect_u32_serial
|
|
108
117
|
nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
|
|
@@ -110,6 +119,12 @@ nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
|
|
|
110
119
|
nk_define_sparse_dot_(u16, bf16, f32, nk_bf16_to_f32_serial) // nk_sparse_dot_u16bf16_serial
|
|
111
120
|
nk_define_sparse_dot_(u32, f32, f64, nk_assign_from_to_) // nk_sparse_dot_u32f32_serial
|
|
112
121
|
|
|
122
|
+
#if defined(__clang__)
|
|
123
|
+
#pragma clang attribute pop
|
|
124
|
+
#elif defined(__GNUC__)
|
|
125
|
+
#pragma GCC pop_options
|
|
126
|
+
#endif
|
|
127
|
+
|
|
113
128
|
#if defined(__cplusplus)
|
|
114
129
|
} // extern "C"
|
|
115
130
|
#endif
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
#if NK_TARGET_ARM64_
|
|
13
13
|
|
|
14
14
|
#include "numkong/types.h"
|
|
15
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
15
16
|
|
|
16
17
|
#if defined(__cplusplus)
|
|
17
18
|
extern "C" {
|
|
@@ -395,7 +396,7 @@ NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
|
|
|
395
396
|
a_idx += a_step;
|
|
396
397
|
b_idx += b_step;
|
|
397
398
|
}
|
|
398
|
-
*product =
|
|
399
|
+
*product = nk_svaddv_f64_(predicate_all_b64x, product_f64x);
|
|
399
400
|
}
|
|
400
401
|
|
|
401
402
|
#if defined(__clang__)
|
|
@@ -485,7 +486,7 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
|
|
|
485
486
|
a_idx += a_step;
|
|
486
487
|
b_idx += b_step;
|
|
487
488
|
}
|
|
488
|
-
*product =
|
|
489
|
+
*product = nk_svaddv_f32_(svptrue_b32(), product_f32x);
|
|
489
490
|
}
|
|
490
491
|
|
|
491
492
|
#if defined(__clang__)
|
|
@@ -139,74 +139,6 @@ nk_angular_bf16_genoa_cycle:
|
|
|
139
139
|
*result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
140
140
|
}
|
|
141
141
|
|
|
142
|
-
NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
143
|
-
__m512 a_sq_f32x16 = _mm512_setzero_ps();
|
|
144
|
-
__m512 b_sq_f32x16 = _mm512_setzero_ps();
|
|
145
|
-
__m512 ab_f32x16 = _mm512_setzero_ps();
|
|
146
|
-
__m256i a_e5m2x32, b_e5m2x32;
|
|
147
|
-
|
|
148
|
-
nk_sqeuclidean_e5m2_genoa_cycle:
|
|
149
|
-
if (n < 32) {
|
|
150
|
-
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
151
|
-
a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
|
|
152
|
-
b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
|
|
153
|
-
n = 0;
|
|
154
|
-
}
|
|
155
|
-
else {
|
|
156
|
-
a_e5m2x32 = _mm256_loadu_epi8(a);
|
|
157
|
-
b_e5m2x32 = _mm256_loadu_epi8(b);
|
|
158
|
-
a += 32, b += 32, n -= 32;
|
|
159
|
-
}
|
|
160
|
-
__m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
|
|
161
|
-
__m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
|
|
162
|
-
a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
|
|
163
|
-
b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
164
|
-
ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
165
|
-
if (n) goto nk_sqeuclidean_e5m2_genoa_cycle;
|
|
166
|
-
|
|
167
|
-
// (a-b)² = a² + b² - 2ab
|
|
168
|
-
__m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
|
|
169
|
-
*result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
|
|
170
|
-
}
|
|
171
|
-
|
|
172
|
-
NK_PUBLIC void nk_euclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
173
|
-
nk_sqeuclidean_e5m2_genoa(a, b, n, result);
|
|
174
|
-
*result = nk_f32_sqrt_haswell(*result);
|
|
175
|
-
}
|
|
176
|
-
|
|
177
|
-
NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
178
|
-
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
179
|
-
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
180
|
-
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
181
|
-
__m256i a_e5m2x32, b_e5m2x32;
|
|
182
|
-
|
|
183
|
-
nk_angular_e5m2_genoa_cycle:
|
|
184
|
-
if (n < 32) {
|
|
185
|
-
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
186
|
-
a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
|
|
187
|
-
b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
|
|
188
|
-
n = 0;
|
|
189
|
-
}
|
|
190
|
-
else {
|
|
191
|
-
a_e5m2x32 = _mm256_loadu_epi8(a);
|
|
192
|
-
b_e5m2x32 = _mm256_loadu_epi8(b);
|
|
193
|
-
a += 32, b += 32, n -= 32;
|
|
194
|
-
}
|
|
195
|
-
__m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
|
|
196
|
-
__m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
|
|
197
|
-
dot_f32x16 = _mm512_dpbf16_ps(dot_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
198
|
-
a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
|
|
199
|
-
nk_m512bh_from_m512i_(a_bf16x32));
|
|
200
|
-
b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
|
|
201
|
-
nk_m512bh_from_m512i_(b_bf16x32));
|
|
202
|
-
if (n) goto nk_angular_e5m2_genoa_cycle;
|
|
203
|
-
|
|
204
|
-
nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
|
|
205
|
-
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
|
|
206
|
-
nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
|
|
207
|
-
*result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
208
|
-
}
|
|
209
|
-
|
|
210
142
|
#if defined(__clang__)
|
|
211
143
|
#pragma clang attribute pop
|
|
212
144
|
#elif defined(__GNUC__)
|