numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,12 +8,11 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_skylake_instructions Key AVX-512 Spatial Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm512_fmadd_ps
|
|
13
|
-
* _mm512_sub_ps
|
|
14
|
-
* _mm512_rsqrt14_ps
|
|
15
|
-
* _mm512_sqrt_ps
|
|
16
|
-
* _mm512_reduce_add_ps (sequence) ~8-10cy - -
|
|
11
|
+
* Intrinsic Instruction Skylake-X Genoa
|
|
12
|
+
* _mm512_fmadd_ps VFMADD132PS (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
13
|
+
* _mm512_sub_ps VSUBPS (ZMM, ZMM, ZMM) 4cy @ p05 3cy @ p23
|
|
14
|
+
* _mm512_rsqrt14_ps VRSQRT14PS (ZMM, ZMM) 7cy @ p0+p0+p05 5cy @ p01
|
|
15
|
+
* _mm512_sqrt_ps VSQRTPS (ZMM, ZMM) 20cy @ p0+p0+p05 15cy @ p01
|
|
17
16
|
*
|
|
18
17
|
* Distance computations benefit from Skylake-X's dual FMA units achieving 0.5cy throughput for
|
|
19
18
|
* fused multiply-add operations. VRSQRT14PS provides ~14-bit precision reciprocal square root;
|
|
@@ -43,21 +42,21 @@ extern "C" {
|
|
|
43
42
|
|
|
44
43
|
/** @brief Reciprocal square root of 16 floats with Newton-Raphson refinement (~28-bit precision). */
|
|
45
44
|
NK_INTERNAL __m512 nk_rsqrt_f32x16_skylake_(__m512 x) {
|
|
46
|
-
__m512
|
|
47
|
-
__m512
|
|
48
|
-
|
|
49
|
-
return _mm512_mul_ps(_mm512_mul_ps(_mm512_set1_ps(0.5f),
|
|
45
|
+
__m512 rsqrt_f32x16 = _mm512_rsqrt14_ps(x);
|
|
46
|
+
__m512 nr_f32x16 = _mm512_mul_ps(_mm512_mul_ps(x, rsqrt_f32x16), rsqrt_f32x16);
|
|
47
|
+
nr_f32x16 = _mm512_sub_ps(_mm512_set1_ps(3.0f), nr_f32x16);
|
|
48
|
+
return _mm512_mul_ps(_mm512_mul_ps(_mm512_set1_ps(0.5f), rsqrt_f32x16), nr_f32x16);
|
|
50
49
|
}
|
|
51
50
|
|
|
52
51
|
/** @brief Reciprocal square root of 8 doubles with Newton-Raphson refinement (~28-bit precision). */
|
|
53
52
|
NK_INTERNAL __m512d nk_rsqrt_f64x8_skylake_(__m512d x) {
|
|
54
|
-
__m512d
|
|
55
|
-
__m512d
|
|
56
|
-
|
|
57
|
-
return _mm512_mul_pd(_mm512_mul_pd(_mm512_set1_pd(0.5),
|
|
53
|
+
__m512d rsqrt_f64x8 = _mm512_rsqrt14_pd(x);
|
|
54
|
+
__m512d nr_f64x8 = _mm512_mul_pd(_mm512_mul_pd(x, rsqrt_f64x8), rsqrt_f64x8);
|
|
55
|
+
nr_f64x8 = _mm512_sub_pd(_mm512_set1_pd(3.0), nr_f64x8);
|
|
56
|
+
return _mm512_mul_pd(_mm512_mul_pd(_mm512_set1_pd(0.5), rsqrt_f64x8), nr_f64x8);
|
|
58
57
|
}
|
|
59
58
|
|
|
60
|
-
#pragma region
|
|
59
|
+
#pragma region F32 and F64 Floats
|
|
61
60
|
|
|
62
61
|
NK_PUBLIC void nk_sqeuclidean_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
63
62
|
// Upcast to f64 for higher precision accumulation
|
|
@@ -282,8 +281,8 @@ NK_INTERNAL void nk_euclidean_through_f64_from_dot_skylake_(nk_b128_vec_t dots,
|
|
|
282
281
|
results->xmm_ps = _mm256_cvtpd_ps(dist_f64x4);
|
|
283
282
|
}
|
|
284
283
|
|
|
285
|
-
#pragma endregion
|
|
286
|
-
#pragma region
|
|
284
|
+
#pragma endregion F32 and F64 Floats
|
|
285
|
+
#pragma region F16 and BF16 Floats
|
|
287
286
|
|
|
288
287
|
NK_PUBLIC void nk_sqeuclidean_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
289
288
|
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
@@ -348,22 +347,22 @@ nk_angular_f16_skylake_cycle:
|
|
|
348
347
|
|
|
349
348
|
NK_PUBLIC void nk_sqeuclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
350
349
|
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
351
|
-
__m128i
|
|
350
|
+
__m128i a_e4m3_u8x16, b_e4m3_u8x16;
|
|
352
351
|
|
|
353
352
|
nk_sqeuclidean_e4m3_skylake_cycle:
|
|
354
353
|
if (n < 16) {
|
|
355
354
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
356
|
-
|
|
357
|
-
|
|
355
|
+
a_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, a);
|
|
356
|
+
b_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, b);
|
|
358
357
|
n = 0;
|
|
359
358
|
}
|
|
360
359
|
else {
|
|
361
|
-
|
|
362
|
-
|
|
360
|
+
a_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
361
|
+
b_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
363
362
|
a += 16, b += 16, n -= 16;
|
|
364
363
|
}
|
|
365
|
-
__m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(
|
|
366
|
-
__m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(
|
|
364
|
+
__m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3_u8x16);
|
|
365
|
+
__m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3_u8x16);
|
|
367
366
|
__m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
|
|
368
367
|
sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
|
|
369
368
|
if (n) goto nk_sqeuclidean_e4m3_skylake_cycle;
|
|
@@ -380,22 +379,22 @@ NK_PUBLIC void nk_angular_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, n
|
|
|
380
379
|
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
381
380
|
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
382
381
|
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
383
|
-
__m128i
|
|
382
|
+
__m128i a_e4m3_u8x16, b_e4m3_u8x16;
|
|
384
383
|
|
|
385
384
|
nk_angular_e4m3_skylake_cycle:
|
|
386
385
|
if (n < 16) {
|
|
387
386
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
388
|
-
|
|
389
|
-
|
|
387
|
+
a_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, a);
|
|
388
|
+
b_e4m3_u8x16 = _mm_maskz_loadu_epi8(mask, b);
|
|
390
389
|
n = 0;
|
|
391
390
|
}
|
|
392
391
|
else {
|
|
393
|
-
|
|
394
|
-
|
|
392
|
+
a_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
393
|
+
b_e4m3_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
395
394
|
a += 16, b += 16, n -= 16;
|
|
396
395
|
}
|
|
397
|
-
__m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(
|
|
398
|
-
__m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(
|
|
396
|
+
__m512 a_f32x16 = nk_e4m3x16_to_f32x16_skylake_(a_e4m3_u8x16);
|
|
397
|
+
__m512 b_f32x16 = nk_e4m3x16_to_f32x16_skylake_(b_e4m3_u8x16);
|
|
399
398
|
dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
|
|
400
399
|
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
|
|
401
400
|
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
|
|
@@ -409,22 +408,22 @@ nk_angular_e4m3_skylake_cycle:
|
|
|
409
408
|
|
|
410
409
|
NK_PUBLIC void nk_sqeuclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
411
410
|
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
412
|
-
__m128i
|
|
411
|
+
__m128i a_e5m2_u8x16, b_e5m2_u8x16;
|
|
413
412
|
|
|
414
413
|
nk_sqeuclidean_e5m2_skylake_cycle:
|
|
415
414
|
if (n < 16) {
|
|
416
415
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
417
|
-
|
|
418
|
-
|
|
416
|
+
a_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, a);
|
|
417
|
+
b_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, b);
|
|
419
418
|
n = 0;
|
|
420
419
|
}
|
|
421
420
|
else {
|
|
422
|
-
|
|
423
|
-
|
|
421
|
+
a_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
422
|
+
b_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
424
423
|
a += 16, b += 16, n -= 16;
|
|
425
424
|
}
|
|
426
|
-
__m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(
|
|
427
|
-
__m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(
|
|
425
|
+
__m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2_u8x16);
|
|
426
|
+
__m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2_u8x16);
|
|
428
427
|
__m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
|
|
429
428
|
sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
|
|
430
429
|
if (n) goto nk_sqeuclidean_e5m2_skylake_cycle;
|
|
@@ -441,22 +440,22 @@ NK_PUBLIC void nk_angular_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, n
|
|
|
441
440
|
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
442
441
|
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
443
442
|
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
444
|
-
__m128i
|
|
443
|
+
__m128i a_e5m2_u8x16, b_e5m2_u8x16;
|
|
445
444
|
|
|
446
445
|
nk_angular_e5m2_skylake_cycle:
|
|
447
446
|
if (n < 16) {
|
|
448
447
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
449
|
-
|
|
450
|
-
|
|
448
|
+
a_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, a);
|
|
449
|
+
b_e5m2_u8x16 = _mm_maskz_loadu_epi8(mask, b);
|
|
451
450
|
n = 0;
|
|
452
451
|
}
|
|
453
452
|
else {
|
|
454
|
-
|
|
455
|
-
|
|
453
|
+
a_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
454
|
+
b_e5m2_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
456
455
|
a += 16, b += 16, n -= 16;
|
|
457
456
|
}
|
|
458
|
-
__m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(
|
|
459
|
-
__m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(
|
|
457
|
+
__m512 a_f32x16 = nk_e5m2x16_to_f32x16_skylake_(a_e5m2_u8x16);
|
|
458
|
+
__m512 b_f32x16 = nk_e5m2x16_to_f32x16_skylake_(b_e5m2_u8x16);
|
|
460
459
|
dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
|
|
461
460
|
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
|
|
462
461
|
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
|
|
@@ -470,22 +469,22 @@ nk_angular_e5m2_skylake_cycle:
|
|
|
470
469
|
|
|
471
470
|
NK_PUBLIC void nk_sqeuclidean_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
472
471
|
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
473
|
-
__m128i
|
|
472
|
+
__m128i a_e2m3_u8x16, b_e2m3_u8x16;
|
|
474
473
|
|
|
475
474
|
nk_sqeuclidean_e2m3_skylake_cycle:
|
|
476
475
|
if (n < 16) {
|
|
477
476
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
478
|
-
|
|
479
|
-
|
|
477
|
+
a_e2m3_u8x16 = _mm_maskz_loadu_epi8(mask, a);
|
|
478
|
+
b_e2m3_u8x16 = _mm_maskz_loadu_epi8(mask, b);
|
|
480
479
|
n = 0;
|
|
481
480
|
}
|
|
482
481
|
else {
|
|
483
|
-
|
|
484
|
-
|
|
482
|
+
a_e2m3_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
483
|
+
b_e2m3_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
485
484
|
a += 16, b += 16, n -= 16;
|
|
486
485
|
}
|
|
487
|
-
__m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(
|
|
488
|
-
__m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(
|
|
486
|
+
__m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_e2m3_u8x16);
|
|
487
|
+
__m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_e2m3_u8x16);
|
|
489
488
|
__m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
|
|
490
489
|
sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
|
|
491
490
|
if (n) goto nk_sqeuclidean_e2m3_skylake_cycle;
|
|
@@ -502,22 +501,22 @@ NK_PUBLIC void nk_angular_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, n
|
|
|
502
501
|
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
503
502
|
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
504
503
|
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
505
|
-
__m128i
|
|
504
|
+
__m128i a_e2m3_u8x16, b_e2m3_u8x16;
|
|
506
505
|
|
|
507
506
|
nk_angular_e2m3_skylake_cycle:
|
|
508
507
|
if (n < 16) {
|
|
509
508
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
510
|
-
|
|
511
|
-
|
|
509
|
+
a_e2m3_u8x16 = _mm_maskz_loadu_epi8(mask, a);
|
|
510
|
+
b_e2m3_u8x16 = _mm_maskz_loadu_epi8(mask, b);
|
|
512
511
|
n = 0;
|
|
513
512
|
}
|
|
514
513
|
else {
|
|
515
|
-
|
|
516
|
-
|
|
514
|
+
a_e2m3_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
515
|
+
b_e2m3_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
517
516
|
a += 16, b += 16, n -= 16;
|
|
518
517
|
}
|
|
519
|
-
__m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(
|
|
520
|
-
__m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(
|
|
518
|
+
__m512 a_f32x16 = nk_e2m3x16_to_f32x16_skylake_(a_e2m3_u8x16);
|
|
519
|
+
__m512 b_f32x16 = nk_e2m3x16_to_f32x16_skylake_(b_e2m3_u8x16);
|
|
521
520
|
dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
|
|
522
521
|
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
|
|
523
522
|
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
|
|
@@ -531,22 +530,22 @@ nk_angular_e2m3_skylake_cycle:
|
|
|
531
530
|
|
|
532
531
|
NK_PUBLIC void nk_sqeuclidean_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
533
532
|
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
534
|
-
__m128i
|
|
533
|
+
__m128i a_e3m2_u8x16, b_e3m2_u8x16;
|
|
535
534
|
|
|
536
535
|
nk_sqeuclidean_e3m2_skylake_cycle:
|
|
537
536
|
if (n < 16) {
|
|
538
537
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
539
|
-
|
|
540
|
-
|
|
538
|
+
a_e3m2_u8x16 = _mm_maskz_loadu_epi8(mask, a);
|
|
539
|
+
b_e3m2_u8x16 = _mm_maskz_loadu_epi8(mask, b);
|
|
541
540
|
n = 0;
|
|
542
541
|
}
|
|
543
542
|
else {
|
|
544
|
-
|
|
545
|
-
|
|
543
|
+
a_e3m2_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
544
|
+
b_e3m2_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
546
545
|
a += 16, b += 16, n -= 16;
|
|
547
546
|
}
|
|
548
|
-
__m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(
|
|
549
|
-
__m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(
|
|
547
|
+
__m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_e3m2_u8x16);
|
|
548
|
+
__m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_e3m2_u8x16);
|
|
550
549
|
__m512 diff_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
|
|
551
550
|
sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
|
|
552
551
|
if (n) goto nk_sqeuclidean_e3m2_skylake_cycle;
|
|
@@ -563,22 +562,22 @@ NK_PUBLIC void nk_angular_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, n
|
|
|
563
562
|
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
564
563
|
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
565
564
|
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
566
|
-
__m128i
|
|
565
|
+
__m128i a_e3m2_u8x16, b_e3m2_u8x16;
|
|
567
566
|
|
|
568
567
|
nk_angular_e3m2_skylake_cycle:
|
|
569
568
|
if (n < 16) {
|
|
570
569
|
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
571
|
-
|
|
572
|
-
|
|
570
|
+
a_e3m2_u8x16 = _mm_maskz_loadu_epi8(mask, a);
|
|
571
|
+
b_e3m2_u8x16 = _mm_maskz_loadu_epi8(mask, b);
|
|
573
572
|
n = 0;
|
|
574
573
|
}
|
|
575
574
|
else {
|
|
576
|
-
|
|
577
|
-
|
|
575
|
+
a_e3m2_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
576
|
+
b_e3m2_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
578
577
|
a += 16, b += 16, n -= 16;
|
|
579
578
|
}
|
|
580
|
-
__m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(
|
|
581
|
-
__m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(
|
|
579
|
+
__m512 a_f32x16 = nk_e3m2x16_to_f32x16_skylake_(a_e3m2_u8x16);
|
|
580
|
+
__m512 b_f32x16 = nk_e3m2x16_to_f32x16_skylake_(b_e3m2_u8x16);
|
|
582
581
|
dot_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, dot_f32x16);
|
|
583
582
|
a_norm_sq_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, a_norm_sq_f32x16);
|
|
584
583
|
b_norm_sq_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, b_norm_sq_f32x16);
|
|
@@ -600,7 +599,7 @@ nk_angular_e3m2_skylake_cycle:
|
|
|
600
599
|
} // extern "C"
|
|
601
600
|
#endif
|
|
602
601
|
|
|
603
|
-
#pragma endregion
|
|
602
|
+
#pragma endregion F16 and BF16 Floats
|
|
604
603
|
#endif // NK_TARGET_SKYLAKE
|
|
605
604
|
#endif // NK_TARGET_X86_
|
|
606
605
|
#endif // NK_SPATIAL_SKYLAKE_H
|
|
@@ -8,19 +8,19 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_sve_instructions ARM SVE Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* svld1_f32
|
|
13
|
-
* svsub_f32_x
|
|
14
|
-
* svmla_f32_x
|
|
15
|
-
* svaddv_f32
|
|
16
|
-
* svdupq_n_f32
|
|
17
|
-
* svwhilelt_b32
|
|
18
|
-
* svptrue_b32
|
|
19
|
-
* svcntw
|
|
20
|
-
* svld1_f64
|
|
21
|
-
* svsub_f64_x
|
|
22
|
-
* svmla_f64_x
|
|
23
|
-
* svaddv_f64
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* svld1_f32 LD1W (Z.S, P/Z, [Xn]) 4-6cy @ 2p
|
|
13
|
+
* svsub_f32_x FSUB (Z.S, P/M, Z.S, Z.S) 3cy @ 2p
|
|
14
|
+
* svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy @ 2p
|
|
15
|
+
* svaddv_f32 FADDV (S, P, Z.S) 6cy @ 1p
|
|
16
|
+
* svdupq_n_f32 DUP (Z.S, #imm) 1cy @ 2p
|
|
17
|
+
* svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy @ 1p
|
|
18
|
+
* svptrue_b32 PTRUE (P.S, pattern) 1cy @ 2p
|
|
19
|
+
* svcntw CNTW (Xd) 1cy @ 2p
|
|
20
|
+
* svld1_f64 LD1D (Z.D, P/Z, [Xn]) 4-6cy @ 2p
|
|
21
|
+
* svsub_f64_x FSUB (Z.D, P/M, Z.D, Z.D) 3cy @ 2p
|
|
22
|
+
* svmla_f64_x FMLA (Z.D, P/M, Z.D, Z.D) 4cy @ 2p
|
|
23
|
+
* svaddv_f64 FADDV (D, P, Z.D) 6cy @ 1p
|
|
24
24
|
*
|
|
25
25
|
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
26
26
|
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
@@ -63,10 +63,10 @@ extern "C" {
|
|
|
63
63
|
* @param x Input vector (must be positive for meaningful results)
|
|
64
64
|
* @return Approximate 1/sqrt(x) with ~23-bit mantissa accuracy
|
|
65
65
|
*/
|
|
66
|
-
NK_INTERNAL svfloat32_t nk_rsqrt_f32x_sve_(svbool_t
|
|
66
|
+
NK_INTERNAL svfloat32_t nk_rsqrt_f32x_sve_(svbool_t predicate_b32x, svfloat32_t x) NK_STREAMING_COMPATIBLE_ {
|
|
67
67
|
svfloat32_t r = svrsqrte_f32(x);
|
|
68
|
-
r = svmul_f32_x(
|
|
69
|
-
r = svmul_f32_x(
|
|
68
|
+
r = svmul_f32_x(predicate_b32x, r, svrsqrts_f32(svmul_f32_x(predicate_b32x, x, r), r));
|
|
69
|
+
r = svmul_f32_x(predicate_b32x, r, svrsqrts_f32(svmul_f32_x(predicate_b32x, x, r), r));
|
|
70
70
|
return r;
|
|
71
71
|
}
|
|
72
72
|
|
|
@@ -79,29 +79,39 @@ NK_INTERNAL svfloat32_t nk_rsqrt_f32x_sve_(svbool_t predicate, svfloat32_t x) NK
|
|
|
79
79
|
* Marked `__arm_streaming_compatible` so the helper is callable from both streaming
|
|
80
80
|
* (SME) and non-streaming (SVE) contexts without mode transitions.
|
|
81
81
|
*
|
|
82
|
-
* @param
|
|
82
|
+
* @param predicate_b32x Active-lane mask
|
|
83
83
|
* @param x Input vector (must be positive for meaningful results)
|
|
84
84
|
* @return Approximate 1/sqrt(x) with ~52-bit mantissa accuracy
|
|
85
85
|
*/
|
|
86
|
-
NK_INTERNAL svfloat64_t nk_rsqrt_f64x_sve_(svbool_t
|
|
86
|
+
NK_INTERNAL svfloat64_t nk_rsqrt_f64x_sve_(svbool_t predicate_b64x, svfloat64_t x) NK_STREAMING_COMPATIBLE_ {
|
|
87
87
|
svfloat64_t r = svrsqrte_f64(x);
|
|
88
|
-
r = svmul_f64_x(
|
|
89
|
-
r = svmul_f64_x(
|
|
90
|
-
r = svmul_f64_x(
|
|
88
|
+
r = svmul_f64_x(predicate_b64x, r, svrsqrts_f64(svmul_f64_x(predicate_b64x, x, r), r));
|
|
89
|
+
r = svmul_f64_x(predicate_b64x, r, svrsqrts_f64(svmul_f64_x(predicate_b64x, x, r), r));
|
|
90
|
+
r = svmul_f64_x(predicate_b64x, r, svrsqrts_f64(svmul_f64_x(predicate_b64x, x, r), r));
|
|
91
91
|
return r;
|
|
92
92
|
}
|
|
93
93
|
|
|
94
94
|
NK_PUBLIC void nk_sqeuclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
95
95
|
nk_size_t i = 0;
|
|
96
|
-
nk_size_t const vector_length = svcntd();
|
|
97
96
|
svfloat64_t dist_sq_f64x = svdupq_n_f64(0.0, 0.0);
|
|
98
|
-
for (; i < n; i +=
|
|
99
|
-
svbool_t
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
97
|
+
for (; i < n; i += svcntw()) {
|
|
98
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(i, n);
|
|
99
|
+
svfloat32_t a_f32x = svld1_f32(predicate_b32x, a + i);
|
|
100
|
+
svfloat32_t b_f32x = svld1_f32(predicate_b32x, b + i);
|
|
101
|
+
nk_size_t remaining = n - i < svcntw() ? n - i : svcntw();
|
|
102
|
+
|
|
103
|
+
// svcvt_f64_f32_x widens only even-indexed f32 elements; svext by 1 shifts odd into even.
|
|
104
|
+
svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (remaining + 1) / 2);
|
|
105
|
+
svfloat64_t a_even_f64x = svcvt_f64_f32_x(pred_even_b64x, a_f32x);
|
|
106
|
+
svfloat64_t b_even_f64x = svcvt_f64_f32_x(pred_even_b64x, b_f32x);
|
|
107
|
+
svfloat64_t diff_even_f64x = svsub_f64_x(pred_even_b64x, a_even_f64x, b_even_f64x);
|
|
108
|
+
dist_sq_f64x = svmla_f64_m(pred_even_b64x, dist_sq_f64x, diff_even_f64x, diff_even_f64x);
|
|
109
|
+
|
|
110
|
+
svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, remaining / 2);
|
|
111
|
+
svfloat64_t a_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_f32x, a_f32x, 1));
|
|
112
|
+
svfloat64_t b_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_f32x, b_f32x, 1));
|
|
113
|
+
svfloat64_t diff_odd_f64x = svsub_f64_x(pred_odd_b64x, a_odd_f64x, b_odd_f64x);
|
|
114
|
+
dist_sq_f64x = svmla_f64_m(pred_odd_b64x, dist_sq_f64x, diff_odd_f64x, diff_odd_f64x);
|
|
105
115
|
}
|
|
106
116
|
nk_f64_t dist_sq_f64 = svaddv_f64(svptrue_b64(), dist_sq_f64x);
|
|
107
117
|
*result = dist_sq_f64;
|
|
@@ -114,18 +124,29 @@ NK_PUBLIC void nk_euclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_siz
|
|
|
114
124
|
|
|
115
125
|
NK_PUBLIC void nk_angular_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
116
126
|
nk_size_t i = 0;
|
|
117
|
-
nk_size_t const vector_length = svcntd();
|
|
118
127
|
svfloat64_t ab_f64x = svdupq_n_f64(0.0, 0.0);
|
|
119
128
|
svfloat64_t a2_f64x = svdupq_n_f64(0.0, 0.0);
|
|
120
129
|
svfloat64_t b2_f64x = svdupq_n_f64(0.0, 0.0);
|
|
121
|
-
for (; i < n; i +=
|
|
122
|
-
svbool_t
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
130
|
+
for (; i < n; i += svcntw()) {
|
|
131
|
+
svbool_t predicate_b32x = svwhilelt_b32_u64(i, n);
|
|
132
|
+
svfloat32_t a_f32x = svld1_f32(predicate_b32x, a + i);
|
|
133
|
+
svfloat32_t b_f32x = svld1_f32(predicate_b32x, b + i);
|
|
134
|
+
nk_size_t remaining = n - i < svcntw() ? n - i : svcntw();
|
|
135
|
+
|
|
136
|
+
// svcvt_f64_f32_x widens only even-indexed f32 elements; svext by 1 shifts odd into even.
|
|
137
|
+
svbool_t pred_even_b64x = svwhilelt_b64_u64(0u, (remaining + 1) / 2);
|
|
138
|
+
svfloat64_t a_even_f64x = svcvt_f64_f32_x(pred_even_b64x, a_f32x);
|
|
139
|
+
svfloat64_t b_even_f64x = svcvt_f64_f32_x(pred_even_b64x, b_f32x);
|
|
140
|
+
ab_f64x = svmla_f64_m(pred_even_b64x, ab_f64x, a_even_f64x, b_even_f64x);
|
|
141
|
+
a2_f64x = svmla_f64_m(pred_even_b64x, a2_f64x, a_even_f64x, a_even_f64x);
|
|
142
|
+
b2_f64x = svmla_f64_m(pred_even_b64x, b2_f64x, b_even_f64x, b_even_f64x);
|
|
143
|
+
|
|
144
|
+
svbool_t pred_odd_b64x = svwhilelt_b64_u64(0u, remaining / 2);
|
|
145
|
+
svfloat64_t a_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(a_f32x, a_f32x, 1));
|
|
146
|
+
svfloat64_t b_odd_f64x = svcvt_f64_f32_x(pred_odd_b64x, svext_f32(b_f32x, b_f32x, 1));
|
|
147
|
+
ab_f64x = svmla_f64_m(pred_odd_b64x, ab_f64x, a_odd_f64x, b_odd_f64x);
|
|
148
|
+
a2_f64x = svmla_f64_m(pred_odd_b64x, a2_f64x, a_odd_f64x, a_odd_f64x);
|
|
149
|
+
b2_f64x = svmla_f64_m(pred_odd_b64x, b2_f64x, b_odd_f64x, b_odd_f64x);
|
|
129
150
|
}
|
|
130
151
|
|
|
131
152
|
nk_f64_t ab_f64 = svaddv_f64(svptrue_b64(), ab_f64x);
|
|
@@ -139,29 +160,29 @@ NK_PUBLIC void nk_sqeuclidean_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
139
160
|
nk_size_t i = 0;
|
|
140
161
|
svfloat64_t sum_f64x = svdupq_n_f64(0.0, 0.0);
|
|
141
162
|
svfloat64_t compensation_f64x = svdupq_n_f64(0.0, 0.0);
|
|
142
|
-
svbool_t
|
|
163
|
+
svbool_t predicate_all_b64x = svptrue_b64();
|
|
143
164
|
do {
|
|
144
|
-
svbool_t
|
|
145
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
146
|
-
svfloat64_t b_f64x = svld1_f64(
|
|
147
|
-
svfloat64_t diff_f64x = svsub_f64_x(
|
|
148
|
-
svfloat64_t diff_sq_f64x = svmul_f64_x(
|
|
165
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(i, n);
|
|
166
|
+
svfloat64_t a_f64x = svld1_f64(predicate_b64x, a + i);
|
|
167
|
+
svfloat64_t b_f64x = svld1_f64(predicate_b64x, b + i);
|
|
168
|
+
svfloat64_t diff_f64x = svsub_f64_x(predicate_b64x, a_f64x, b_f64x);
|
|
169
|
+
svfloat64_t diff_sq_f64x = svmul_f64_x(predicate_b64x, diff_f64x, diff_f64x);
|
|
149
170
|
// Neumaier: t = sum + x
|
|
150
|
-
svfloat64_t t_f64x =
|
|
151
|
-
svfloat64_t abs_sum_f64x = svabs_f64_x(
|
|
171
|
+
svfloat64_t t_f64x = svadd_f64_m(predicate_b64x, sum_f64x, diff_sq_f64x);
|
|
172
|
+
svfloat64_t abs_sum_f64x = svabs_f64_x(predicate_b64x, sum_f64x);
|
|
152
173
|
// diff_sq is already non-negative (it's a square), so svabs is unnecessary
|
|
153
|
-
svbool_t
|
|
174
|
+
svbool_t sum_ge_x_b64x = svcmpge_f64(predicate_b64x, abs_sum_f64x, diff_sq_f64x);
|
|
154
175
|
// When |sum| >= |x|: comp += (sum - t) + x; when |x| > |sum|: comp += (x - t) + sum
|
|
155
|
-
svfloat64_t comp_sum_large_f64x = svadd_f64_x(
|
|
176
|
+
svfloat64_t comp_sum_large_f64x = svadd_f64_x(predicate_b64x, svsub_f64_x(predicate_b64x, sum_f64x, t_f64x),
|
|
156
177
|
diff_sq_f64x);
|
|
157
|
-
svfloat64_t comp_x_large_f64x = svadd_f64_x(
|
|
178
|
+
svfloat64_t comp_x_large_f64x = svadd_f64_x(predicate_b64x, svsub_f64_x(predicate_b64x, diff_sq_f64x, t_f64x),
|
|
158
179
|
sum_f64x);
|
|
159
|
-
svfloat64_t comp_update_f64x = svsel_f64(
|
|
160
|
-
compensation_f64x =
|
|
180
|
+
svfloat64_t comp_update_f64x = svsel_f64(sum_ge_x_b64x, comp_sum_large_f64x, comp_x_large_f64x);
|
|
181
|
+
compensation_f64x = svadd_f64_m(predicate_b64x, compensation_f64x, comp_update_f64x);
|
|
161
182
|
sum_f64x = t_f64x;
|
|
162
183
|
i += svcntd();
|
|
163
184
|
} while (i < n);
|
|
164
|
-
*result = nk_dot_stable_sum_f64_sve_(
|
|
185
|
+
*result = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, sum_f64x, compensation_f64x);
|
|
165
186
|
}
|
|
166
187
|
|
|
167
188
|
NK_PUBLIC void nk_euclidean_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
@@ -177,35 +198,35 @@ NK_PUBLIC void nk_angular_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_
|
|
|
177
198
|
svfloat64_t ab_compensation_f64x = svdupq_n_f64(0.0, 0.0);
|
|
178
199
|
svfloat64_t a2_f64x = svdupq_n_f64(0.0, 0.0);
|
|
179
200
|
svfloat64_t b2_f64x = svdupq_n_f64(0.0, 0.0);
|
|
180
|
-
svbool_t
|
|
201
|
+
svbool_t predicate_all_b64x = svptrue_b64();
|
|
181
202
|
do {
|
|
182
|
-
svbool_t
|
|
183
|
-
svfloat64_t a_f64x = svld1_f64(
|
|
184
|
-
svfloat64_t b_f64x = svld1_f64(
|
|
203
|
+
svbool_t predicate_b64x = svwhilelt_b64_u64(i, n);
|
|
204
|
+
svfloat64_t a_f64x = svld1_f64(predicate_b64x, a + i);
|
|
205
|
+
svfloat64_t b_f64x = svld1_f64(predicate_b64x, b + i);
|
|
185
206
|
// TwoProd for ab: product = a*b, error = fma(a,b,-product) = -(product - a*b)
|
|
186
|
-
svfloat64_t product_f64x = svmul_f64_x(
|
|
187
|
-
svfloat64_t product_error_f64x = svneg_f64_x(
|
|
188
|
-
svnmls_f64_x(
|
|
207
|
+
svfloat64_t product_f64x = svmul_f64_x(predicate_b64x, a_f64x, b_f64x);
|
|
208
|
+
svfloat64_t product_error_f64x = svneg_f64_x(predicate_b64x,
|
|
209
|
+
svnmls_f64_x(predicate_b64x, product_f64x, a_f64x, b_f64x));
|
|
189
210
|
// TwoSum: (tentative_sum, sum_error) = TwoSum(sum, product)
|
|
190
|
-
svfloat64_t tentative_sum_f64x =
|
|
191
|
-
svfloat64_t virtual_addend_f64x = svsub_f64_x(
|
|
211
|
+
svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_b64x, ab_sum_f64x, product_f64x);
|
|
212
|
+
svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_b64x, tentative_sum_f64x, ab_sum_f64x);
|
|
192
213
|
svfloat64_t sum_error_f64x = svadd_f64_x(
|
|
193
|
-
|
|
194
|
-
svsub_f64_x(
|
|
195
|
-
svsub_f64_x(
|
|
196
|
-
svsub_f64_x(
|
|
214
|
+
predicate_b64x,
|
|
215
|
+
svsub_f64_x(predicate_b64x, ab_sum_f64x,
|
|
216
|
+
svsub_f64_x(predicate_b64x, tentative_sum_f64x, virtual_addend_f64x)),
|
|
217
|
+
svsub_f64_x(predicate_b64x, product_f64x, virtual_addend_f64x));
|
|
197
218
|
ab_sum_f64x = tentative_sum_f64x;
|
|
198
|
-
ab_compensation_f64x =
|
|
199
|
-
svadd_f64_x(
|
|
219
|
+
ab_compensation_f64x = svadd_f64_m(predicate_b64x, ab_compensation_f64x,
|
|
220
|
+
svadd_f64_x(predicate_b64x, sum_error_f64x, product_error_f64x));
|
|
200
221
|
// Simple FMA for self-products (no cancellation)
|
|
201
|
-
a2_f64x =
|
|
202
|
-
b2_f64x =
|
|
222
|
+
a2_f64x = svmla_f64_m(predicate_b64x, a2_f64x, a_f64x, a_f64x);
|
|
223
|
+
b2_f64x = svmla_f64_m(predicate_b64x, b2_f64x, b_f64x, b_f64x);
|
|
203
224
|
i += svcntd();
|
|
204
225
|
} while (i < n);
|
|
205
226
|
|
|
206
|
-
nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(
|
|
207
|
-
nk_f64_t a2_f64 = svaddv_f64(
|
|
208
|
-
nk_f64_t b2_f64 = svaddv_f64(
|
|
227
|
+
nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, ab_sum_f64x, ab_compensation_f64x);
|
|
228
|
+
nk_f64_t a2_f64 = svaddv_f64(predicate_all_b64x, a2_f64x);
|
|
229
|
+
nk_f64_t b2_f64 = svaddv_f64(predicate_all_b64x, b2_f64x);
|
|
209
230
|
*result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
|
|
210
231
|
}
|
|
211
232
|
|
|
@@ -8,19 +8,19 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_svebfdot_instructions ARM SVE+BF16 Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* svld1_bf16
|
|
13
|
-
* svld1_u16
|
|
14
|
-
* svbfdot_f32
|
|
15
|
-
* svmla_f32_x
|
|
16
|
-
* svsub_f32_x
|
|
17
|
-
* svaddv_f32
|
|
18
|
-
* svunpklo_u32
|
|
19
|
-
* svunpkhi_u32
|
|
20
|
-
* svlsl_n_u32_x
|
|
21
|
-
* svwhilelt_b16
|
|
22
|
-
* svwhilelt_b32
|
|
23
|
-
* svcnth
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* svld1_bf16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
|
|
13
|
+
* svld1_u16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
|
|
14
|
+
* svbfdot_f32 BFDOT (Z.S, Z.H, Z.H) 4cy @ 2p
|
|
15
|
+
* svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy @ 2p
|
|
16
|
+
* svsub_f32_x FSUB (Z.S, P/M, Z.S, Z.S) 3cy @ 2p
|
|
17
|
+
* svaddv_f32 FADDV (S, P, Z.S) 6cy @ 1p
|
|
18
|
+
* svunpklo_u32 UUNPKLO (Z.S, Z.H) 2cy @ 2p
|
|
19
|
+
* svunpkhi_u32 UUNPKHI (Z.S, Z.H) 2cy @ 2p
|
|
20
|
+
* svlsl_n_u32_x LSL (Z.S, P/M, Z.S, #imm) 2cy @ 2p
|
|
21
|
+
* svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy @ 1p
|
|
22
|
+
* svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy @ 1p
|
|
23
|
+
* svcnth CNTH (Xd) 1cy @ 2p
|
|
24
24
|
*
|
|
25
25
|
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
26
26
|
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
@@ -57,22 +57,22 @@ NK_PUBLIC void nk_sqeuclidean_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t c
|
|
|
57
57
|
nk_u16_t const *a = (nk_u16_t const *)(a_enum);
|
|
58
58
|
nk_u16_t const *b = (nk_u16_t const *)(b_enum);
|
|
59
59
|
do {
|
|
60
|
-
svbool_t
|
|
61
|
-
svuint16_t a_u16x = svld1_u16(
|
|
62
|
-
svuint16_t b_u16x = svld1_u16(
|
|
60
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(i, n);
|
|
61
|
+
svuint16_t a_u16x = svld1_u16(predicate_b16x, a + i);
|
|
62
|
+
svuint16_t b_u16x = svld1_u16(predicate_b16x, b + i);
|
|
63
63
|
|
|
64
64
|
// There is no `bf16` subtraction in SVE, so we need to convert to `u32` and shift.
|
|
65
|
-
svbool_t
|
|
66
|
-
svbool_t
|
|
67
|
-
svfloat32_t a_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(
|
|
68
|
-
svfloat32_t a_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(
|
|
69
|
-
svfloat32_t b_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(
|
|
70
|
-
svfloat32_t b_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(
|
|
65
|
+
svbool_t predicate_low_b32x = svwhilelt_b32_u64(i, n);
|
|
66
|
+
svbool_t predicate_high_b32x = svwhilelt_b32_u64(i + svcnth() / 2, n);
|
|
67
|
+
svfloat32_t a_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_low_b32x, svunpklo_u32(a_u16x), 16));
|
|
68
|
+
svfloat32_t a_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_high_b32x, svunpkhi_u32(a_u16x), 16));
|
|
69
|
+
svfloat32_t b_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_low_b32x, svunpklo_u32(b_u16x), 16));
|
|
70
|
+
svfloat32_t b_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_high_b32x, svunpkhi_u32(b_u16x), 16));
|
|
71
71
|
|
|
72
|
-
svfloat32_t a_minus_b_low_f32x = svsub_f32_x(
|
|
73
|
-
svfloat32_t a_minus_b_high_f32x = svsub_f32_x(
|
|
74
|
-
d2_low_f32x =
|
|
75
|
-
d2_high_f32x =
|
|
72
|
+
svfloat32_t a_minus_b_low_f32x = svsub_f32_x(predicate_low_b32x, a_low_f32x, b_low_f32x);
|
|
73
|
+
svfloat32_t a_minus_b_high_f32x = svsub_f32_x(predicate_high_b32x, a_high_f32x, b_high_f32x);
|
|
74
|
+
d2_low_f32x = svmla_f32_m(predicate_low_b32x, d2_low_f32x, a_minus_b_low_f32x, a_minus_b_low_f32x);
|
|
75
|
+
d2_high_f32x = svmla_f32_m(predicate_high_b32x, d2_high_f32x, a_minus_b_high_f32x, a_minus_b_high_f32x);
|
|
76
76
|
i += svcnth();
|
|
77
77
|
} while (i < n);
|
|
78
78
|
nk_f32_t d2 = svaddv_f32(svptrue_b32(), d2_low_f32x) + svaddv_f32(svptrue_b32(), d2_high_f32x);
|
|
@@ -92,9 +92,9 @@ NK_PUBLIC void nk_angular_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t const
|
|
|
92
92
|
nk_bf16_for_arm_simd_t const *a = (nk_bf16_for_arm_simd_t const *)(a_enum);
|
|
93
93
|
nk_bf16_for_arm_simd_t const *b = (nk_bf16_for_arm_simd_t const *)(b_enum);
|
|
94
94
|
do {
|
|
95
|
-
svbool_t
|
|
96
|
-
svbfloat16_t a_bf16x = svld1_bf16(
|
|
97
|
-
svbfloat16_t b_bf16x = svld1_bf16(
|
|
95
|
+
svbool_t predicate_b16x = svwhilelt_b16_u64(i, n);
|
|
96
|
+
svbfloat16_t a_bf16x = svld1_bf16(predicate_b16x, a + i);
|
|
97
|
+
svbfloat16_t b_bf16x = svld1_bf16(predicate_b16x, b + i);
|
|
98
98
|
ab_f32x = svbfdot_f32(ab_f32x, a_bf16x, b_bf16x);
|
|
99
99
|
a2_f32x = svbfdot_f32(a2_f32x, a_bf16x, a_bf16x);
|
|
100
100
|
b2_f32x = svbfdot_f32(b2_f32x, b_bf16x, b_bf16x);
|