numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,13 +8,13 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_alder_instructions AVX-VNNI Instructions Performance
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm256_dpbusd_epi32
|
|
13
|
-
* _mm256_sad_epu8
|
|
14
|
-
* _mm256_xor_si256
|
|
15
|
-
* _mm256_add_epi64
|
|
16
|
-
* _mm_rsqrt_ps
|
|
17
|
-
* _mm_sqrt_ss
|
|
11
|
+
* Intrinsic Instruction Alder Lake Raptor Lake
|
|
12
|
+
* _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
|
|
13
|
+
* _mm256_sad_epu8 VPSADBW (YMM, YMM, YMM) 3cy @ p5 3cy @ p5
|
|
14
|
+
* _mm256_xor_si256 VPXOR (YMM, YMM, YMM) 1cy @ p015 1cy @ p015
|
|
15
|
+
* _mm256_add_epi64 VPADDQ (YMM, YMM, YMM) 1cy @ p015 1cy @ p015
|
|
16
|
+
* _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0 5cy @ p0
|
|
17
|
+
* _mm_sqrt_ss VSQRTSS (XMM, XMM, XMM) 12cy @ p0 12cy @ p0
|
|
18
18
|
*
|
|
19
19
|
* All spatial kernels use the dpbusd norm-decomposition approach:
|
|
20
20
|
* ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
|
|
@@ -102,7 +102,8 @@ NK_PUBLIC void nk_angular_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t
|
|
|
102
102
|
b_norm_sq_i32 += b_element_i32 * b_element_i32;
|
|
103
103
|
}
|
|
104
104
|
|
|
105
|
-
*result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32,
|
|
105
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
106
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
106
107
|
}
|
|
107
108
|
|
|
108
109
|
NK_PUBLIC void nk_sqeuclidean_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
|
|
@@ -254,7 +255,8 @@ NK_PUBLIC void nk_angular_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t
|
|
|
254
255
|
b_norm_sq_i32 += b_element_i32 * b_element_i32;
|
|
255
256
|
}
|
|
256
257
|
|
|
257
|
-
*result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32,
|
|
258
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
|
|
259
|
+
(nk_f32_t)b_norm_sq_i32);
|
|
258
260
|
}
|
|
259
261
|
|
|
260
262
|
NK_PUBLIC void nk_angular_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -265,10 +267,10 @@ NK_PUBLIC void nk_angular_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const
|
|
|
265
267
|
// then normalize: angular = 1 - dot / sqrt(||a||^2 * ||b||^2).
|
|
266
268
|
// Final division by 256.0f for dot and norms cancels in the ratio.
|
|
267
269
|
//
|
|
268
|
-
__m256i const
|
|
269
|
-
|
|
270
|
-
__m256i const
|
|
271
|
-
|
|
270
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
|
|
271
|
+
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
272
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
273
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
272
274
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
273
275
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
274
276
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -296,16 +298,16 @@ nk_angular_e2m3_alder_cycle:
|
|
|
296
298
|
// Decode a: extract magnitude, dual-VPSHUFB LUT
|
|
297
299
|
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
298
300
|
__m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
299
|
-
__m256i
|
|
300
|
-
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
301
|
-
_mm256_shuffle_epi8(
|
|
301
|
+
__m256i a_high_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
302
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_idx),
|
|
303
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_idx), a_high_sel);
|
|
302
304
|
|
|
303
305
|
// Decode b: same LUT decode
|
|
304
306
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
305
307
|
__m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
306
|
-
__m256i
|
|
307
|
-
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
308
|
-
_mm256_shuffle_epi8(
|
|
308
|
+
__m256i b_high_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
309
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_idx),
|
|
310
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_idx), b_high_sel);
|
|
309
311
|
|
|
310
312
|
// Dot product with sign: combined sign from (a XOR b) & 0x20
|
|
311
313
|
__m256i sign_combined = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
|
|
@@ -325,7 +327,7 @@ nk_angular_e2m3_alder_cycle:
|
|
|
325
327
|
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
326
328
|
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
327
329
|
// The 256.0f factor cancels in the angular normalization ratio
|
|
328
|
-
*result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
|
|
330
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_i32, (nk_f32_t)a_norm_i32, (nk_f32_t)b_norm_i32);
|
|
329
331
|
}
|
|
330
332
|
|
|
331
333
|
NK_PUBLIC void nk_sqeuclidean_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
|
|
@@ -334,10 +336,10 @@ NK_PUBLIC void nk_sqeuclidean_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t c
|
|
|
334
336
|
// ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
|
|
335
337
|
// Each value × 16 is exact integer, so result = integer_result / 256.0f
|
|
336
338
|
//
|
|
337
|
-
__m256i const
|
|
338
|
-
|
|
339
|
-
__m256i const
|
|
340
|
-
|
|
339
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
|
|
340
|
+
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
341
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
342
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
341
343
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
342
344
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
343
345
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -365,15 +367,15 @@ nk_sqeuclidean_e2m3_alder_cycle:
|
|
|
365
367
|
// Decode a and b magnitudes via LUT
|
|
366
368
|
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
367
369
|
__m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
368
|
-
__m256i
|
|
369
|
-
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
370
|
-
_mm256_shuffle_epi8(
|
|
370
|
+
__m256i a_high_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
371
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_idx),
|
|
372
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_idx), a_high_sel);
|
|
371
373
|
|
|
372
374
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
373
375
|
__m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
374
|
-
__m256i
|
|
375
|
-
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
376
|
-
_mm256_shuffle_epi8(
|
|
376
|
+
__m256i b_high_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
|
|
377
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_idx),
|
|
378
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_idx), b_high_sel);
|
|
377
379
|
|
|
378
380
|
// Signed dot product: combined sign from (a XOR b) & 0x20
|
|
379
381
|
__m256i sign_combined = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
|
|
@@ -405,10 +407,10 @@ NK_PUBLIC void nk_angular_e3m2_alder(nk_e3m2_t const *a_scalars, nk_e3m2_t const
|
|
|
405
407
|
// Every e3m2 value × 16 is an exact integer (max magnitude 448), requiring i16.
|
|
406
408
|
// VPDPWSSD replaces Haswell's VPMADDWD + VPADDD, saving one instruction per accumulation.
|
|
407
409
|
//
|
|
408
|
-
__m256i const
|
|
410
|
+
__m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
|
|
409
411
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
410
412
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
411
|
-
__m256i const
|
|
413
|
+
__m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
|
|
412
414
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
413
415
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
414
416
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -443,46 +445,48 @@ nk_angular_e3m2_alder_cycle:
|
|
|
443
445
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
|
|
444
446
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
445
447
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
446
|
-
__m256i
|
|
447
|
-
|
|
448
|
-
__m256i
|
|
449
|
-
|
|
448
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
449
|
+
half_select_u8x32);
|
|
450
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
451
|
+
half_select_u8x32);
|
|
450
452
|
|
|
451
453
|
// Dual VPSHUFB: lookup low bytes in both halves, blend based on bit 4
|
|
452
|
-
__m256i
|
|
453
|
-
_mm256_shuffle_epi8(
|
|
454
|
-
|
|
455
|
-
__m256i
|
|
456
|
-
_mm256_shuffle_epi8(
|
|
457
|
-
|
|
454
|
+
__m256i a_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
|
|
455
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32),
|
|
456
|
+
a_high_select_u8x32);
|
|
457
|
+
__m256i b_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
|
|
458
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32),
|
|
459
|
+
b_high_select_u8x32);
|
|
458
460
|
|
|
459
461
|
// High byte: 1 iff magnitude >= 28 (signed compare safe: 27 < 128)
|
|
460
|
-
__m256i
|
|
461
|
-
|
|
462
|
+
__m256i a_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
|
|
463
|
+
ones_u8x32);
|
|
464
|
+
__m256i b_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
|
|
465
|
+
ones_u8x32);
|
|
462
466
|
|
|
463
467
|
// Interleave low and high bytes into i16 (little-endian: low byte first)
|
|
464
|
-
__m256i
|
|
465
|
-
__m256i
|
|
466
|
-
__m256i
|
|
467
|
-
__m256i
|
|
468
|
+
__m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
|
|
469
|
+
__m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
|
|
470
|
+
__m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
|
|
471
|
+
__m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
|
|
468
472
|
|
|
469
473
|
// Combined sign: (a ^ b) & 0x20, widen to i16 via unpack, create +1/-1 sign vector
|
|
470
474
|
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
|
|
471
475
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
|
|
472
|
-
__m256i
|
|
473
|
-
__m256i
|
|
474
|
-
__m256i
|
|
475
|
-
__m256i
|
|
476
|
-
__m256i
|
|
477
|
-
__m256i
|
|
476
|
+
__m256i negate_low_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
477
|
+
__m256i negate_high_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
478
|
+
__m256i sign_low_i16x16 = _mm256_or_si256(negate_low_i16x16, ones_i16x16);
|
|
479
|
+
__m256i sign_high_i16x16 = _mm256_or_si256(negate_high_i16x16, ones_i16x16);
|
|
480
|
+
__m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, sign_low_i16x16);
|
|
481
|
+
__m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, sign_high_i16x16);
|
|
478
482
|
|
|
479
483
|
// VPDPWSSD: i16×i16→i32 fused dot-product-accumulate (replaces VPMADDWD + VPADDD)
|
|
480
|
-
dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8,
|
|
481
|
-
dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8,
|
|
482
|
-
a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8,
|
|
483
|
-
a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8,
|
|
484
|
-
b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8,
|
|
485
|
-
b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8,
|
|
484
|
+
dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_low_i16x16, b_signed_low_i16x16);
|
|
485
|
+
dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_high_i16x16, b_signed_high_i16x16);
|
|
486
|
+
a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_low_i16x16, a_low_i16x16);
|
|
487
|
+
a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_high_i16x16, a_high_i16x16);
|
|
488
|
+
b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_low_i16x16, b_low_i16x16);
|
|
489
|
+
b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_high_i16x16, b_high_i16x16);
|
|
486
490
|
|
|
487
491
|
if (count_scalars) goto nk_angular_e3m2_alder_cycle;
|
|
488
492
|
|
|
@@ -490,19 +494,19 @@ nk_angular_e3m2_alder_cycle:
|
|
|
490
494
|
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
491
495
|
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
492
496
|
// The 256.0f factor cancels in the angular normalization ratio
|
|
493
|
-
*result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
|
|
497
|
+
*result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_i32, (nk_f32_t)a_norm_i32, (nk_f32_t)b_norm_i32);
|
|
494
498
|
}
|
|
495
499
|
|
|
496
500
|
NK_PUBLIC void nk_sqeuclidean_e3m2_alder(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
|
|
497
501
|
nk_size_t count_scalars, nk_f32_t *result) {
|
|
498
|
-
// Squared Euclidean distance for e3m2
|
|
499
|
-
//
|
|
500
|
-
//
|
|
502
|
+
// Squared Euclidean distance for e3m2 via direct difference squaring.
|
|
503
|
+
// Computes Σ(a_i − b_i)² using signed i16 subtraction + VPMADDWD self-multiply.
|
|
504
|
+
// 2 VPMADDWDs per 32 elements (one per i16 half). 0 ULP.
|
|
501
505
|
//
|
|
502
|
-
__m256i const
|
|
506
|
+
__m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
|
|
503
507
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
504
508
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
505
|
-
__m256i const
|
|
509
|
+
__m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
|
|
506
510
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
507
511
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
508
512
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -512,9 +516,7 @@ NK_PUBLIC void nk_sqeuclidean_e3m2_alder(nk_e3m2_t const *a_scalars, nk_e3m2_t c
|
|
|
512
516
|
__m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
|
|
513
517
|
__m256i const ones_u8x32 = _mm256_set1_epi8(1);
|
|
514
518
|
__m256i const ones_i16x16 = _mm256_set1_epi16(1);
|
|
515
|
-
__m256i
|
|
516
|
-
__m256i a_norm_i32x8 = _mm256_setzero_si256();
|
|
517
|
-
__m256i b_norm_i32x8 = _mm256_setzero_si256();
|
|
519
|
+
__m256i sum_i32x8 = _mm256_setzero_si256();
|
|
518
520
|
__m256i a_e3m2_u8x32, b_e3m2_u8x32;
|
|
519
521
|
|
|
520
522
|
nk_sqeuclidean_e3m2_alder_cycle:
|
|
@@ -532,59 +534,54 @@ nk_sqeuclidean_e3m2_alder_cycle:
|
|
|
532
534
|
a_scalars += 32, b_scalars += 32, count_scalars -= 32;
|
|
533
535
|
}
|
|
534
536
|
|
|
535
|
-
//
|
|
537
|
+
// Decode both to unsigned i16 via dual-VPSHUFB + interleave
|
|
536
538
|
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
|
|
537
539
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
|
|
538
540
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
539
541
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
540
|
-
__m256i
|
|
541
|
-
|
|
542
|
-
__m256i
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
__m256i
|
|
555
|
-
__m256i
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
__m256i
|
|
561
|
-
__m256i
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
__m256i
|
|
565
|
-
|
|
566
|
-
__m256i
|
|
567
|
-
|
|
568
|
-
__m256i
|
|
569
|
-
|
|
570
|
-
__m256i
|
|
571
|
-
__m256i
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
542
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
543
|
+
half_select_u8x32);
|
|
544
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
545
|
+
half_select_u8x32);
|
|
546
|
+
__m256i a_low_bytes_u8x32 = _mm256_blendv_epi8(
|
|
547
|
+
_mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
|
|
548
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32), a_high_select_u8x32);
|
|
549
|
+
__m256i b_low_bytes_u8x32 = _mm256_blendv_epi8(
|
|
550
|
+
_mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
|
|
551
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32), b_high_select_u8x32);
|
|
552
|
+
__m256i a_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
|
|
553
|
+
ones_u8x32);
|
|
554
|
+
__m256i b_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
|
|
555
|
+
ones_u8x32);
|
|
556
|
+
__m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
|
|
557
|
+
__m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
|
|
558
|
+
__m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
|
|
559
|
+
__m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
|
|
560
|
+
|
|
561
|
+
// Apply signs individually
|
|
562
|
+
__m256i a_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
563
|
+
__m256i b_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
564
|
+
__m256i a_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
|
|
565
|
+
ones_i16x16);
|
|
566
|
+
__m256i a_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
|
|
567
|
+
ones_i16x16);
|
|
568
|
+
__m256i b_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
|
|
569
|
+
ones_i16x16);
|
|
570
|
+
__m256i b_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
|
|
571
|
+
ones_i16x16);
|
|
572
|
+
__m256i a_signed_low_i16x16 = _mm256_sign_epi16(a_low_i16x16, a_sign_low_i16x16);
|
|
573
|
+
__m256i a_signed_high_i16x16 = _mm256_sign_epi16(a_high_i16x16, a_sign_high_i16x16);
|
|
574
|
+
__m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, b_sign_low_i16x16);
|
|
575
|
+
__m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, b_sign_high_i16x16);
|
|
576
|
+
|
|
577
|
+
// Direct difference squaring: (a−b)² via VPMADDWD
|
|
578
|
+
__m256i diff_low_i16x16 = _mm256_sub_epi16(a_signed_low_i16x16, b_signed_low_i16x16);
|
|
579
|
+
__m256i diff_high_i16x16 = _mm256_sub_epi16(a_signed_high_i16x16, b_signed_high_i16x16);
|
|
580
|
+
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(diff_low_i16x16, diff_low_i16x16));
|
|
581
|
+
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(diff_high_i16x16, diff_high_i16x16));
|
|
580
582
|
|
|
581
583
|
if (count_scalars) goto nk_sqeuclidean_e3m2_alder_cycle;
|
|
582
|
-
|
|
583
|
-
nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
|
|
584
|
-
nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
|
|
585
|
-
nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
|
|
586
|
-
// ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b), scaled by 256
|
|
587
|
-
*result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
|
|
584
|
+
*result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
|
|
588
585
|
}
|
|
589
586
|
|
|
590
587
|
NK_PUBLIC void nk_euclidean_e3m2_alder(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for Diamond Rapids.
|
|
3
|
+
* @file include/numkong/spatial/diamond.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* For L2 distance, uses the identity: (a−b)² = a² + b² − 2 × a × b,
|
|
10
|
+
* with VCVTHF82PH/VCVTBF82PH for 1-instruction FP8→FP16 conversion and
|
|
11
|
+
* VDPPHPS for FP16-pair dot products accumulating into FP32.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef NK_SPATIAL_DIAMOND_H
|
|
14
|
+
#define NK_SPATIAL_DIAMOND_H
|
|
15
|
+
|
|
16
|
+
#if NK_TARGET_X86_
|
|
17
|
+
#if NK_TARGET_DIAMOND
|
|
18
|
+
|
|
19
|
+
#include "numkong/types.h"
|
|
20
|
+
#include "numkong/spatial/haswell.h" // `nk_angular_normalize_f32_haswell_`, `nk_f32_sqrt_haswell`
|
|
21
|
+
#include "numkong/reduce/skylake.h" // `nk_reduce_add_f32x16_skylake_`
|
|
22
|
+
|
|
23
|
+
#if defined(__cplusplus)
|
|
24
|
+
extern "C" {
|
|
25
|
+
#endif
|
|
26
|
+
|
|
27
|
+
#if defined(__clang__)
|
|
28
|
+
#pragma clang attribute push( \
|
|
29
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx10.2-512,f16c,fma,bmi,bmi2"))), \
|
|
30
|
+
apply_to = function)
|
|
31
|
+
#elif defined(__GNUC__)
|
|
32
|
+
#pragma GCC push_options
|
|
33
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx10.2-512", "f16c", "fma", \
|
|
34
|
+
"bmi", "bmi2")
|
|
35
|
+
#endif
|
|
36
|
+
|
|
37
|
+
/**
 * @brief Squared Euclidean distance between two E4M3 vectors on Diamond Rapids.
 *
 * Each iteration widens 32 E4M3 bytes per input to FP16 with `_mm512_cvthf8_ph`. It then
 * accumulates a·a, b·b and a·b into FP32 lanes with `_mm512_dpph_ps`. The distance is
 * recovered at the end via (a−b)² = a² + b² − 2ab. A final chunk of fewer than 32 elements
 * is loaded with a zeroing mask, so the padding lanes contribute nothing.
 *
 * @param a       First input vector of `n` E4M3 values.
 * @param b       Second input vector of `n` E4M3 values.
 * @param n       Number of elements; `n == 0` yields 0.
 * @param result  Output: the squared distance, clamped so it is never negative.
 */
NK_PUBLIC void nk_sqeuclidean_e4m3_diamond(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 a_sq_f32x16 = _mm512_setzero_ps();
    __m512 b_sq_f32x16 = _mm512_setzero_ps();
    __m512 ab_f32x16 = _mm512_setzero_ps();
    __m256i a_e4m3x32, b_e4m3x32;

nk_sqeuclidean_e4m3_diamond_cycle:
    if (n < 32) {
        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
        a_e4m3x32 = _mm256_maskz_loadu_epi8(mask, a);
        b_e4m3x32 = _mm256_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e4m3x32 = _mm256_loadu_epi8(a);
        b_e4m3x32 = _mm256_loadu_epi8(b);
        a += 32, b += 32, n -= 32;
    }
    __m512h a_f16x32 = _mm512_cvthf8_ph(a_e4m3x32);
    __m512h b_f16x32 = _mm512_cvthf8_ph(b_e4m3x32);
    a_sq_f32x16 = _mm512_dpph_ps(a_sq_f32x16, a_f16x32, a_f16x32);
    b_sq_f32x16 = _mm512_dpph_ps(b_sq_f32x16, b_f16x32, b_f16x32);
    ab_f32x16 = _mm512_dpph_ps(ab_f32x16, a_f16x32, b_f16x32);
    if (n) goto nk_sqeuclidean_e4m3_diamond_cycle;

    __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
    nk_f32_t distance_sq_f32 = nk_reduce_add_f32x16_skylake_(
        _mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
    // a² + b² − 2ab cancels catastrophically when a ≈ b. Rounding in the three FP32
    // accumulators can push the true (tiny, non-negative) value below zero, and the
    // Euclidean wrapper would then take the square root of a negative number. Clamp
    // only real negatives: `x < 0.0f` is false for NaN, so NaN/Inf still propagate.
    *result = distance_sq_f32 < 0.0f ? 0.0f : distance_sq_f32;
}
|
|
65
|
+
|
|
66
|
+
/**
 * @brief Euclidean distance between two E4M3 vectors on Diamond Rapids.
 *
 * Delegates to `nk_sqeuclidean_e4m3_diamond` and takes the square root of its output
 * with `nk_f32_sqrt_haswell`.
 *
 * @param a       First input vector of `n` E4M3 values.
 * @param b       Second input vector of `n` E4M3 values.
 * @param n       Number of elements.
 * @param result  Output: the Euclidean distance.
 */
NK_PUBLIC void nk_euclidean_e4m3_diamond(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance = 0;
    nk_sqeuclidean_e4m3_diamond(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
70
|
+
|
|
71
|
+
/**
 * @brief Angular distance between two E4M3 vectors on Diamond Rapids.
 *
 * Streams both inputs 32 bytes at a time and widens them to FP16 with `_mm512_cvthf8_ph`.
 * It accumulates the cross product a·b and the self products a·a and b·b into FP32 lanes
 * with `_mm512_dpph_ps`. The three sums are reduced horizontally and passed to
 * `nk_angular_normalize_f32_haswell_`.
 *
 * @param a       First input vector of `n` E4M3 values.
 * @param b       Second input vector of `n` E4M3 values.
 * @param n       Number of elements.
 * @param result  Output: the value produced by `nk_angular_normalize_f32_haswell_`.
 */
NK_PUBLIC void nk_angular_e4m3_diamond(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 cross_f32x16 = _mm512_setzero_ps();
    __m512 self_a_f32x16 = _mm512_setzero_ps();
    __m512 self_b_f32x16 = _mm512_setzero_ps();

    // The body runs at least once, so `n == 0` still goes through a (fully masked) pass.
    do {
        __m256i a_bytes_u8x32, b_bytes_u8x32;
        if (n >= 32) {
            a_bytes_u8x32 = _mm256_loadu_epi8(a);
            b_bytes_u8x32 = _mm256_loadu_epi8(b);
            a += 32;
            b += 32;
            n -= 32;
        }
        else {
            // Final partial chunk: bytes past `n` are zeroed instead of read.
            __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
            a_bytes_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, a);
            b_bytes_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }
        __m512h const a_f16x32 = _mm512_cvthf8_ph(a_bytes_u8x32);
        __m512h const b_f16x32 = _mm512_cvthf8_ph(b_bytes_u8x32);
        cross_f32x16 = _mm512_dpph_ps(cross_f32x16, a_f16x32, b_f16x32);
        self_a_f32x16 = _mm512_dpph_ps(self_a_f32x16, a_f16x32, a_f16x32);
        self_b_f32x16 = _mm512_dpph_ps(self_b_f32x16, b_f16x32, b_f16x32);
    } while (n);

    nk_f32_t const cross_f32 = nk_reduce_add_f32x16_skylake_(cross_f32x16);
    nk_f32_t const self_a_f32 = nk_reduce_add_f32x16_skylake_(self_a_f32x16);
    nk_f32_t const self_b_f32 = nk_reduce_add_f32x16_skylake_(self_b_f32x16);
    *result = nk_angular_normalize_f32_haswell_(cross_f32, self_a_f32, self_b_f32);
}
|
|
101
|
+
|
|
102
|
+
/**
 * @brief Squared Euclidean distance between two E5M2 vectors on Diamond Rapids.
 *
 * Each iteration widens 32 E5M2 bytes per input to FP16 with `_mm512_cvtbf8_ph`. It then
 * accumulates a·a, b·b and a·b into FP32 lanes with `_mm512_dpph_ps`. The distance is
 * recovered at the end via (a−b)² = a² + b² − 2ab. A final chunk of fewer than 32 elements
 * is loaded with a zeroing mask, so the padding lanes contribute nothing.
 *
 * @param a       First input vector of `n` E5M2 values.
 * @param b       Second input vector of `n` E5M2 values.
 * @param n       Number of elements; `n == 0` yields 0.
 * @param result  Output: the squared distance, clamped so it is never negative.
 */
NK_PUBLIC void nk_sqeuclidean_e5m2_diamond(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 a_sq_f32x16 = _mm512_setzero_ps();
    __m512 b_sq_f32x16 = _mm512_setzero_ps();
    __m512 ab_f32x16 = _mm512_setzero_ps();
    __m256i a_e5m2x32, b_e5m2x32;

nk_sqeuclidean_e5m2_diamond_cycle:
    if (n < 32) {
        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
        a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
        b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
        n = 0;
    }
    else {
        a_e5m2x32 = _mm256_loadu_epi8(a);
        b_e5m2x32 = _mm256_loadu_epi8(b);
        a += 32, b += 32, n -= 32;
    }
    __m512h a_f16x32 = _mm512_cvtbf8_ph(a_e5m2x32);
    __m512h b_f16x32 = _mm512_cvtbf8_ph(b_e5m2x32);
    a_sq_f32x16 = _mm512_dpph_ps(a_sq_f32x16, a_f16x32, a_f16x32);
    b_sq_f32x16 = _mm512_dpph_ps(b_sq_f32x16, b_f16x32, b_f16x32);
    ab_f32x16 = _mm512_dpph_ps(ab_f32x16, a_f16x32, b_f16x32);
    if (n) goto nk_sqeuclidean_e5m2_diamond_cycle;

    __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
    nk_f32_t distance_sq_f32 = nk_reduce_add_f32x16_skylake_(
        _mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
    // a² + b² − 2ab cancels catastrophically when a ≈ b. Rounding in the three FP32
    // accumulators can push the true (tiny, non-negative) value below zero, and the
    // Euclidean wrapper would then take the square root of a negative number. Clamp
    // only real negatives: `x < 0.0f` is false for NaN, so NaN/Inf still propagate.
    *result = distance_sq_f32 < 0.0f ? 0.0f : distance_sq_f32;
}
|
|
130
|
+
|
|
131
|
+
/**
 * @brief Euclidean distance between two E5M2 vectors on Diamond Rapids.
 *
 * Delegates to `nk_sqeuclidean_e5m2_diamond` and takes the square root of its output
 * with `nk_f32_sqrt_haswell`.
 *
 * @param a       First input vector of `n` E5M2 values.
 * @param b       Second input vector of `n` E5M2 values.
 * @param n       Number of elements.
 * @param result  Output: the Euclidean distance.
 */
NK_PUBLIC void nk_euclidean_e5m2_diamond(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance = 0;
    nk_sqeuclidean_e5m2_diamond(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
135
|
+
|
|
136
|
+
/**
 * @brief Angular distance between two E5M2 vectors on Diamond Rapids.
 *
 * Streams both inputs 32 bytes at a time and widens them to FP16 with `_mm512_cvtbf8_ph`.
 * It accumulates the cross product a·b and the self products a·a and b·b into FP32 lanes
 * with `_mm512_dpph_ps`. The three sums are reduced horizontally and passed to
 * `nk_angular_normalize_f32_haswell_`.
 *
 * @param a       First input vector of `n` E5M2 values.
 * @param b       Second input vector of `n` E5M2 values.
 * @param n       Number of elements.
 * @param result  Output: the value produced by `nk_angular_normalize_f32_haswell_`.
 */
NK_PUBLIC void nk_angular_e5m2_diamond(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 cross_f32x16 = _mm512_setzero_ps();
    __m512 self_a_f32x16 = _mm512_setzero_ps();
    __m512 self_b_f32x16 = _mm512_setzero_ps();

    // The body runs at least once, so `n == 0` still goes through a (fully masked) pass.
    do {
        __m256i a_bytes_u8x32, b_bytes_u8x32;
        if (n >= 32) {
            a_bytes_u8x32 = _mm256_loadu_epi8(a);
            b_bytes_u8x32 = _mm256_loadu_epi8(b);
            a += 32;
            b += 32;
            n -= 32;
        }
        else {
            // Final partial chunk: bytes past `n` are zeroed instead of read.
            __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
            a_bytes_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, a);
            b_bytes_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, b);
            n = 0;
        }
        __m512h const a_f16x32 = _mm512_cvtbf8_ph(a_bytes_u8x32);
        __m512h const b_f16x32 = _mm512_cvtbf8_ph(b_bytes_u8x32);
        cross_f32x16 = _mm512_dpph_ps(cross_f32x16, a_f16x32, b_f16x32);
        self_a_f32x16 = _mm512_dpph_ps(self_a_f32x16, a_f16x32, a_f16x32);
        self_b_f32x16 = _mm512_dpph_ps(self_b_f32x16, b_f16x32, b_f16x32);
    } while (n);

    nk_f32_t const cross_f32 = nk_reduce_add_f32x16_skylake_(cross_f32x16);
    nk_f32_t const self_a_f32 = nk_reduce_add_f32x16_skylake_(self_a_f32x16);
    nk_f32_t const self_b_f32 = nk_reduce_add_f32x16_skylake_(self_b_f32x16);
    *result = nk_angular_normalize_f32_haswell_(cross_f32, self_a_f32, self_b_f32);
}
|
|
166
|
+
|
|
167
|
+
/**
 * @brief Squared Euclidean distance between two FP16 vectors on Diamond Rapids.
 *
 * Loads 32 FP16 values per input each iteration and accumulates a·a, b·b and a·b into
 * FP32 lanes with `_mm512_dpph_ps`. The distance is recovered at the end via
 * (a−b)² = a² + b² − 2ab. A final chunk of fewer than 32 elements is loaded with a zeroing
 * 16-bit mask, so the padding lanes contribute nothing.
 *
 * @param a       First input vector of `n` FP16 values.
 * @param b       Second input vector of `n` FP16 values.
 * @param n       Number of elements; `n == 0` yields 0.
 * @param result  Output: the squared distance, clamped so it is never negative.
 */
NK_PUBLIC void nk_sqeuclidean_f16_diamond(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 a_sq_f32x16 = _mm512_setzero_ps();
    __m512 b_sq_f32x16 = _mm512_setzero_ps();
    __m512 ab_f32x16 = _mm512_setzero_ps();
    __m512h a_f16x32, b_f16x32;

nk_sqeuclidean_f16_diamond_cycle:
    if (n < 32) {
        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
        a_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a));
        b_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b));
        n = 0;
    }
    else {
        a_f16x32 = _mm512_castsi512_ph(_mm512_loadu_epi16(a));
        b_f16x32 = _mm512_castsi512_ph(_mm512_loadu_epi16(b));
        a += 32, b += 32, n -= 32;
    }
    a_sq_f32x16 = _mm512_dpph_ps(a_sq_f32x16, a_f16x32, a_f16x32);
    b_sq_f32x16 = _mm512_dpph_ps(b_sq_f32x16, b_f16x32, b_f16x32);
    ab_f32x16 = _mm512_dpph_ps(ab_f32x16, a_f16x32, b_f16x32);
    if (n) goto nk_sqeuclidean_f16_diamond_cycle;

    __m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
    nk_f32_t distance_sq_f32 = nk_reduce_add_f32x16_skylake_(
        _mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
    // a² + b² − 2ab cancels catastrophically when a ≈ b. Rounding in the three FP32
    // accumulators can push the true (tiny, non-negative) value below zero, and the
    // Euclidean wrapper would then take the square root of a negative number. Clamp
    // only real negatives: `x < 0.0f` is false for NaN, so NaN/Inf still propagate.
    *result = distance_sq_f32 < 0.0f ? 0.0f : distance_sq_f32;
}
|
|
193
|
+
|
|
194
|
+
/**
 * @brief Euclidean distance between two FP16 vectors on Diamond Rapids.
 *
 * Delegates to `nk_sqeuclidean_f16_diamond` and takes the square root of its output
 * with `nk_f32_sqrt_haswell`.
 *
 * @param a       First input vector of `n` FP16 values.
 * @param b       Second input vector of `n` FP16 values.
 * @param n       Number of elements.
 * @param result  Output: the Euclidean distance.
 */
NK_PUBLIC void nk_euclidean_f16_diamond(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance = 0;
    nk_sqeuclidean_f16_diamond(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_haswell(squared_distance);
}
|
|
198
|
+
|
|
199
|
+
/**
 * @brief Angular distance between two FP16 vectors on Diamond Rapids.
 *
 * Streams both inputs 32 FP16 values at a time. It accumulates the cross product a·b and
 * the self products a·a and b·b into FP32 lanes with `_mm512_dpph_ps`. The three sums are
 * reduced horizontally and passed to `nk_angular_normalize_f32_haswell_`.
 *
 * @param a       First input vector of `n` FP16 values.
 * @param b       Second input vector of `n` FP16 values.
 * @param n       Number of elements.
 * @param result  Output: the value produced by `nk_angular_normalize_f32_haswell_`.
 */
NK_PUBLIC void nk_angular_f16_diamond(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    __m512 cross_f32x16 = _mm512_setzero_ps();
    __m512 self_a_f32x16 = _mm512_setzero_ps();
    __m512 self_b_f32x16 = _mm512_setzero_ps();

    // The body runs at least once, so `n == 0` still goes through a (fully masked) pass.
    do {
        __m512h a_chunk_f16x32, b_chunk_f16x32;
        if (n >= 32) {
            a_chunk_f16x32 = _mm512_castsi512_ph(_mm512_loadu_epi16(a));
            b_chunk_f16x32 = _mm512_castsi512_ph(_mm512_loadu_epi16(b));
            a += 32;
            b += 32;
            n -= 32;
        }
        else {
            // Final partial chunk: 16-bit lanes past `n` are zeroed instead of read.
            __mmask32 const tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
            a_chunk_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(tail_mask, a));
            b_chunk_f16x32 = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(tail_mask, b));
            n = 0;
        }
        cross_f32x16 = _mm512_dpph_ps(cross_f32x16, a_chunk_f16x32, b_chunk_f16x32);
        self_a_f32x16 = _mm512_dpph_ps(self_a_f32x16, a_chunk_f16x32, a_chunk_f16x32);
        self_b_f32x16 = _mm512_dpph_ps(self_b_f32x16, b_chunk_f16x32, b_chunk_f16x32);
    } while (n);

    nk_f32_t const cross_f32 = nk_reduce_add_f32x16_skylake_(cross_f32x16);
    nk_f32_t const self_a_f32 = nk_reduce_add_f32x16_skylake_(self_a_f32x16);
    nk_f32_t const self_b_f32 = nk_reduce_add_f32x16_skylake_(self_b_f32x16);
    *result = nk_angular_normalize_f32_haswell_(cross_f32, self_a_f32, self_b_f32);
}
|
|
227
|
+
|
|
228
|
+
#if defined(__clang__)
|
|
229
|
+
#pragma clang attribute pop
|
|
230
|
+
#elif defined(__GNUC__)
|
|
231
|
+
#pragma GCC pop_options
|
|
232
|
+
#endif
|
|
233
|
+
|
|
234
|
+
#if defined(__cplusplus)
|
|
235
|
+
} // extern "C"
|
|
236
|
+
#endif
|
|
237
|
+
|
|
238
|
+
#endif // NK_TARGET_DIAMOND
|
|
239
|
+
#endif // NK_TARGET_X86_
|
|
240
|
+
#endif // NK_SPATIAL_DIAMOND_H
|