numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -10,15 +10,18 @@
|
|
|
10
10
|
*
|
|
11
11
|
* Key NEON instructions for dot products:
|
|
12
12
|
*
|
|
13
|
-
* Intrinsic
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
* vmulq_f32
|
|
18
|
-
*
|
|
19
|
-
*
|
|
20
|
-
*
|
|
21
|
-
*
|
|
13
|
+
* Intrinsic Instruction A76 M5
|
|
14
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
15
|
+
* vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy @ 2p 4cy @ 4p
|
|
16
|
+
* vfmsq_f64 FMLS (V.2D, V.2D, V.2D) 4cy @ 2p 4cy @ 4p
|
|
17
|
+
* vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
18
|
+
* vmulq_f64 FMUL (V.2D, V.2D, V.2D) 3cy @ 2p 3cy @ 4p
|
|
19
|
+
* vaddvq_f32 FADDP+FADDP (reduce) 5cy @ 1p 8cy @ 1p
|
|
20
|
+
* vaddvq_f64 FADDP (V.2D to scalar) 3cy @ 1p 3cy @ 1p
|
|
21
|
+
* vpaddq_f32 FADDP (V.4S, V.4S, V.4S) 2cy @ 2p 3cy @ 4p
|
|
22
|
+
* vpaddq_f64 FADDP (V.2D, V.2D, V.2D) 2cy @ 2p 3cy @ 4p
|
|
23
|
+
* vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy @ 2p 3cy @ 2p
|
|
24
|
+
* vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy @ 1p 4cy @ 1p
|
|
22
25
|
*
|
|
23
26
|
* FMA throughput doubles on cores with 4 SIMD pipes (Apple M4+, Graviton3+, Oryon), but
|
|
24
27
|
* horizontal reductions remain at 1/cy on all cores and become the main bottleneck.
|
|
@@ -118,21 +121,25 @@ NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x2_neon_(float64x2_t sum_f64x2, float6
|
|
|
118
121
|
return tentative_sum + (lower_error + upper_error + rounding_error);
|
|
119
122
|
}
|
|
120
123
|
|
|
121
|
-
#pragma region
|
|
124
|
+
#pragma region F32 and F64 Floats
|
|
122
125
|
|
|
123
126
|
NK_PUBLIC void nk_dot_f32_neon(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
124
127
|
nk_f64_t *result) {
|
|
125
|
-
// Upcast f32 to f64
|
|
126
|
-
float64x2_t
|
|
128
|
+
// Upcast f32 to f64 via FCVTL/FCVTL2, two independent FMA chains for ILP
|
|
129
|
+
float64x2_t sum_low_f64x2 = vdupq_n_f64(0);
|
|
130
|
+
float64x2_t sum_high_f64x2 = vdupq_n_f64(0);
|
|
127
131
|
nk_size_t idx_scalars = 0;
|
|
128
|
-
for (; idx_scalars +
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
float64x2_t
|
|
132
|
-
float64x2_t
|
|
133
|
-
|
|
132
|
+
for (; idx_scalars + 4 <= count_scalars; idx_scalars += 4) {
|
|
133
|
+
float32x4_t a_f32x4 = vld1q_f32(a_scalars + idx_scalars);
|
|
134
|
+
float32x4_t b_f32x4 = vld1q_f32(b_scalars + idx_scalars);
|
|
135
|
+
float64x2_t a_low_f64x2 = vcvt_f64_f32(vget_low_f32(a_f32x4));
|
|
136
|
+
float64x2_t a_high_f64x2 = vcvt_high_f64_f32(a_f32x4);
|
|
137
|
+
float64x2_t b_low_f64x2 = vcvt_f64_f32(vget_low_f32(b_f32x4));
|
|
138
|
+
float64x2_t b_high_f64x2 = vcvt_high_f64_f32(b_f32x4);
|
|
139
|
+
sum_low_f64x2 = vfmaq_f64(sum_low_f64x2, a_low_f64x2, b_low_f64x2);
|
|
140
|
+
sum_high_f64x2 = vfmaq_f64(sum_high_f64x2, a_high_f64x2, b_high_f64x2);
|
|
134
141
|
}
|
|
135
|
-
nk_f64_t sum_f64 = vaddvq_f64(
|
|
142
|
+
nk_f64_t sum_f64 = vaddvq_f64(vaddq_f64(sum_low_f64x2, sum_high_f64x2));
|
|
136
143
|
for (; idx_scalars < count_scalars; ++idx_scalars)
|
|
137
144
|
sum_f64 += (nk_f64_t)a_scalars[idx_scalars] * (nk_f64_t)b_scalars[idx_scalars];
|
|
138
145
|
*result = sum_f64;
|
|
@@ -243,10 +250,10 @@ NK_INTERNAL void nk_dot_f32x2_finalize_neon(
|
|
|
243
250
|
nk_dot_f32x2_state_neon_t const *state_c, nk_dot_f32x2_state_neon_t const *state_d, //
|
|
244
251
|
nk_size_t total_dimensions, nk_b256_vec_t *result) {
|
|
245
252
|
nk_unused_(total_dimensions);
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
result->f64s[
|
|
249
|
-
result->f64s[
|
|
253
|
+
float64x2_t ab_f64x2 = vpaddq_f64(state_a->sum_f64x2, state_b->sum_f64x2);
|
|
254
|
+
float64x2_t cd_f64x2 = vpaddq_f64(state_c->sum_f64x2, state_d->sum_f64x2);
|
|
255
|
+
vst1q_f64(&result->f64s[0], ab_f64x2);
|
|
256
|
+
vst1q_f64(&result->f64s[2], cd_f64x2);
|
|
250
257
|
}
|
|
251
258
|
|
|
252
259
|
NK_PUBLIC void nk_dot_f64_neon(nk_f64_t const *a_scalars, nk_f64_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -302,11 +309,11 @@ nk_dot_f64c_neon_cycle:
|
|
|
302
309
|
nk_b128_vec_t a_tail, b_tail;
|
|
303
310
|
nk_partial_load_b64x2_serial_(a_pairs, &a_tail, count_pairs * 2);
|
|
304
311
|
nk_partial_load_b64x2_serial_(b_pairs, &b_tail, count_pairs * 2);
|
|
305
|
-
float64x2_t
|
|
306
|
-
a_real_f64x2 = vzip1q_f64(a_tail.f64x2,
|
|
307
|
-
a_imag_f64x2 = vzip2q_f64(a_tail.f64x2,
|
|
308
|
-
b_real_f64x2 = vzip1q_f64(b_tail.f64x2,
|
|
309
|
-
b_imag_f64x2 = vzip2q_f64(b_tail.f64x2,
|
|
312
|
+
float64x2_t zeros_f64x2 = vdupq_n_f64(0);
|
|
313
|
+
a_real_f64x2 = vzip1q_f64(a_tail.f64x2, zeros_f64x2);
|
|
314
|
+
a_imag_f64x2 = vzip2q_f64(a_tail.f64x2, zeros_f64x2);
|
|
315
|
+
b_real_f64x2 = vzip1q_f64(b_tail.f64x2, zeros_f64x2);
|
|
316
|
+
b_imag_f64x2 = vzip2q_f64(b_tail.f64x2, zeros_f64x2);
|
|
310
317
|
count_pairs = 0;
|
|
311
318
|
}
|
|
312
319
|
else {
|
|
@@ -385,11 +392,11 @@ nk_vdot_f64c_neon_cycle:
|
|
|
385
392
|
nk_b128_vec_t a_tail, b_tail;
|
|
386
393
|
nk_partial_load_b64x2_serial_(a_pairs, &a_tail, count_pairs * 2);
|
|
387
394
|
nk_partial_load_b64x2_serial_(b_pairs, &b_tail, count_pairs * 2);
|
|
388
|
-
float64x2_t
|
|
389
|
-
a_real_f64x2 = vzip1q_f64(a_tail.f64x2,
|
|
390
|
-
a_imag_f64x2 = vzip2q_f64(a_tail.f64x2,
|
|
391
|
-
b_real_f64x2 = vzip1q_f64(b_tail.f64x2,
|
|
392
|
-
b_imag_f64x2 = vzip2q_f64(b_tail.f64x2,
|
|
395
|
+
float64x2_t zeros_f64x2 = vdupq_n_f64(0);
|
|
396
|
+
a_real_f64x2 = vzip1q_f64(a_tail.f64x2, zeros_f64x2);
|
|
397
|
+
a_imag_f64x2 = vzip2q_f64(a_tail.f64x2, zeros_f64x2);
|
|
398
|
+
b_real_f64x2 = vzip1q_f64(b_tail.f64x2, zeros_f64x2);
|
|
399
|
+
b_imag_f64x2 = vzip2q_f64(b_tail.f64x2, zeros_f64x2);
|
|
393
400
|
count_pairs = 0;
|
|
394
401
|
}
|
|
395
402
|
else {
|
|
@@ -505,9 +512,9 @@ NK_INTERNAL void nk_dot_f64x2_finalize_neon(
|
|
|
505
512
|
result->f64s[3] = nk_dot_stable_sum_f64x2_neon_(state_d->sum_f64x2, state_d->compensation_f64x2);
|
|
506
513
|
}
|
|
507
514
|
|
|
508
|
-
#pragma endregion
|
|
515
|
+
#pragma endregion F32 and F64 Floats
|
|
509
516
|
|
|
510
|
-
#pragma region
|
|
517
|
+
#pragma region F16 and BF16 Floats
|
|
511
518
|
|
|
512
519
|
NK_PUBLIC void nk_dot_bf16_neon(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
513
520
|
nk_f32_t *result) {
|
|
@@ -528,9 +535,9 @@ nk_dot_bf16_neon_cycle:
|
|
|
528
535
|
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
529
536
|
}
|
|
530
537
|
float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
|
|
531
|
-
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(
|
|
538
|
+
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(a_u16x8, 16));
|
|
532
539
|
float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
|
|
533
|
-
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(
|
|
540
|
+
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(b_u16x8, 16));
|
|
534
541
|
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
535
542
|
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
536
543
|
if (count_scalars) goto nk_dot_bf16_neon_cycle;
|
|
@@ -555,9 +562,9 @@ NK_INTERNAL void nk_dot_bf16x8_update_neon(nk_dot_bf16x8_state_neon_t *state, nk
|
|
|
555
562
|
nk_unused_(active_dimensions);
|
|
556
563
|
// Convert bf16 to f32 via USHLL shift-16 (low and high halves)
|
|
557
564
|
float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a.u16x8), 16));
|
|
558
|
-
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(
|
|
565
|
+
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(a.u16x8, 16));
|
|
559
566
|
float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.u16x8), 16));
|
|
560
|
-
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(
|
|
567
|
+
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(b.u16x8, 16));
|
|
561
568
|
state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
562
569
|
state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
563
570
|
}
|
|
@@ -567,10 +574,9 @@ NK_INTERNAL void nk_dot_bf16x8_finalize_neon(
|
|
|
567
574
|
nk_dot_bf16x8_state_neon_t const *state_c, nk_dot_bf16x8_state_neon_t const *state_d, //
|
|
568
575
|
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
569
576
|
nk_unused_(total_dimensions);
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
result->
|
|
573
|
-
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
577
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
578
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
579
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
574
580
|
}
|
|
575
581
|
|
|
576
582
|
NK_PUBLIC void nk_dot_f16_neon(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -591,10 +597,12 @@ nk_dot_f16_neon_cycle:
|
|
|
591
597
|
b_u16x8 = vld1q_u16((nk_u16_t const *)b_scalars);
|
|
592
598
|
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
593
599
|
}
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
float32x4_t
|
|
597
|
-
float32x4_t
|
|
600
|
+
float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_u16x8);
|
|
601
|
+
float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_u16x8);
|
|
602
|
+
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
603
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
604
|
+
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
605
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
598
606
|
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
599
607
|
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
600
608
|
if (count_scalars) goto nk_dot_f16_neon_cycle;
|
|
@@ -604,8 +612,8 @@ nk_dot_f16_neon_cycle:
|
|
|
604
612
|
/**
|
|
605
613
|
* @brief Running state for 128-bit dot accumulation over f16 scalars on plain NEON.
|
|
606
614
|
*
|
|
607
|
-
* Processes 8 f16 values at a time (128 bits), converting to f32 via
|
|
608
|
-
*
|
|
615
|
+
* Processes 8 f16 values at a time (128 bits), converting to f32 via FCVTL
|
|
616
|
+
* for accumulation without requiring the ARMv8.2-A FP16 arithmetic extension.
|
|
609
617
|
*/
|
|
610
618
|
typedef struct nk_dot_f16x8_state_neon_t {
|
|
611
619
|
float32x4_t sum_f32x4;
|
|
@@ -617,11 +625,13 @@ NK_INTERNAL void nk_dot_f16x8_update_neon(nk_dot_f16x8_state_neon_t *state, nk_b
|
|
|
617
625
|
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
618
626
|
nk_unused_(depth_offset);
|
|
619
627
|
nk_unused_(active_dimensions);
|
|
620
|
-
// Convert f16 to f32 via
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
float32x4_t
|
|
624
|
-
float32x4_t
|
|
628
|
+
// Convert f16 to f32 via FCVTL / FCVTL2 (low and high halves)
|
|
629
|
+
float16x8_t a_f16x8 = vreinterpretq_f16_u16(a.u16x8);
|
|
630
|
+
float16x8_t b_f16x8 = vreinterpretq_f16_u16(b.u16x8);
|
|
631
|
+
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
632
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
633
|
+
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
634
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
625
635
|
state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
626
636
|
state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
627
637
|
}
|
|
@@ -631,10 +641,9 @@ NK_INTERNAL void nk_dot_f16x8_finalize_neon(
|
|
|
631
641
|
nk_dot_f16x8_state_neon_t const *state_c, nk_dot_f16x8_state_neon_t const *state_d, //
|
|
632
642
|
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
633
643
|
nk_unused_(total_dimensions);
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
result->
|
|
637
|
-
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
644
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
645
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
646
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
638
647
|
}
|
|
639
648
|
|
|
640
649
|
NK_PUBLIC void nk_dot_e4m3_neon(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -656,9 +665,9 @@ nk_dot_e4m3_neon_cycle:
|
|
|
656
665
|
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
657
666
|
}
|
|
658
667
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
659
|
-
float32x4_t a_high_f32x4 =
|
|
668
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
660
669
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
661
|
-
float32x4_t b_high_f32x4 =
|
|
670
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
662
671
|
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
663
672
|
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
664
673
|
if (count_scalars) goto nk_dot_e4m3_neon_cycle;
|
|
@@ -684,9 +693,9 @@ nk_dot_e5m2_neon_cycle:
|
|
|
684
693
|
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
685
694
|
}
|
|
686
695
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
687
|
-
float32x4_t a_high_f32x4 =
|
|
696
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
688
697
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
689
|
-
float32x4_t b_high_f32x4 =
|
|
698
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
690
699
|
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, b_low_f32x4);
|
|
691
700
|
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, b_high_f32x4);
|
|
692
701
|
if (count_scalars) goto nk_dot_e5m2_neon_cycle;
|
|
@@ -713,12 +722,10 @@ nk_dot_e2m3_neon_cycle:
|
|
|
713
722
|
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
714
723
|
}
|
|
715
724
|
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_low_f16x8)), vcvt_f32_f16(vget_low_f16(b_low_f16x8)));
|
|
716
|
-
sum_f32x4 = vfmaq_f32(sum_f32x4,
|
|
717
|
-
vcvt_f32_f16(vget_high_f16(b_low_f16x8)));
|
|
725
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_high_f32_f16(a_low_f16x8), vcvt_high_f32_f16(b_low_f16x8));
|
|
718
726
|
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_high_f16x8)),
|
|
719
727
|
vcvt_f32_f16(vget_low_f16(b_high_f16x8)));
|
|
720
|
-
sum_f32x4 = vfmaq_f32(sum_f32x4,
|
|
721
|
-
vcvt_f32_f16(vget_high_f16(b_high_f16x8)));
|
|
728
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_high_f32_f16(a_high_f16x8), vcvt_high_f32_f16(b_high_f16x8));
|
|
722
729
|
if (count_scalars) goto nk_dot_e2m3_neon_cycle;
|
|
723
730
|
*result = vaddvq_f32(sum_f32x4);
|
|
724
731
|
}
|
|
@@ -743,19 +750,17 @@ nk_dot_e3m2_neon_cycle:
|
|
|
743
750
|
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
744
751
|
}
|
|
745
752
|
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_low_f16x8)), vcvt_f32_f16(vget_low_f16(b_low_f16x8)));
|
|
746
|
-
sum_f32x4 = vfmaq_f32(sum_f32x4,
|
|
747
|
-
vcvt_f32_f16(vget_high_f16(b_low_f16x8)));
|
|
753
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_high_f32_f16(a_low_f16x8), vcvt_high_f32_f16(b_low_f16x8));
|
|
748
754
|
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_f32_f16(vget_low_f16(a_high_f16x8)),
|
|
749
755
|
vcvt_f32_f16(vget_low_f16(b_high_f16x8)));
|
|
750
|
-
sum_f32x4 = vfmaq_f32(sum_f32x4,
|
|
751
|
-
vcvt_f32_f16(vget_high_f16(b_high_f16x8)));
|
|
756
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, vcvt_high_f32_f16(a_high_f16x8), vcvt_high_f32_f16(b_high_f16x8));
|
|
752
757
|
if (count_scalars) goto nk_dot_e3m2_neon_cycle;
|
|
753
758
|
*result = vaddvq_f32(sum_f32x4);
|
|
754
759
|
}
|
|
755
760
|
|
|
756
|
-
#pragma endregion
|
|
761
|
+
#pragma endregion F16 and BF16 Floats
|
|
757
762
|
|
|
758
|
-
#pragma region
|
|
763
|
+
#pragma region Binary
|
|
759
764
|
|
|
760
765
|
NK_PUBLIC void nk_dot_u1_neon(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
|
|
761
766
|
nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
|
|
@@ -801,7 +806,53 @@ NK_INTERNAL void nk_dot_u1x128_finalize_neon( //
|
|
|
801
806
|
result->u32x4 = vpaddq_u32(ab_sum_u32x4, cd_sum_u32x4);
|
|
802
807
|
}
|
|
803
808
|
|
|
804
|
-
#pragma endregion
|
|
809
|
+
#pragma endregion Binary
|
|
810
|
+
|
|
811
|
+
/**
 *  @brief Non-conjugated complex dot product of two f16 complex vectors, accumulated in f32.
 *
 *  For every pair i it accumulates a[i] * b[i]:
 *      real += a.real * b.real - a.imag * b.imag
 *      imag += a.real * b.imag + a.imag * b.real
 *  Four pairs are processed per vector iteration; the remaining (count_pairs % 4) pairs are
 *  delegated to `nk_dot_f16c_serial`, whose result is added to the horizontally reduced lanes.
 *
 *  @param a_pairs      First input, `count_pairs` complex f16 elements.
 *  @param b_pairs      Second input, `count_pairs` complex f16 elements.
 *  @param count_pairs  Number of complex elements in each input.
 *  @param result       Output: complex f32 sum.
 */
NK_PUBLIC void nk_dot_f16c_neon(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
                                nk_f32c_t *result) {
    float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
    float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
    while (count_pairs >= 4) {
        // VLD2 de-interleaves 4 pairs: val[0] holds the real halves, val[1] the imaginary halves.
        // Load through `int16_t const *` (the intrinsic's parameter type) rather than casting const away.
        int16x4x2_t a_i16x4x2 = vld2_s16((int16_t const *)a_pairs);
        int16x4x2_t b_i16x4x2 = vld2_s16((int16_t const *)b_pairs);
        float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
        float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
        float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
        float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
        // real += ar*br - ai*bi  (vfmsq_f32 computes acc - x*y)
        sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
        sum_real_f32x4 = vfmsq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
        // imag += ar*bi + ai*br
        sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
        sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
        count_pairs -= 4, a_pairs += 4, b_pairs += 4;
    }
    // Remaining 0..3 pairs go through the scalar kernel.
    nk_f32c_t tail_result;
    nk_dot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
    result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
    result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
}
|
|
833
|
+
|
|
834
|
+
/**
 *  @brief Conjugated complex dot product, conj(a) . b, of two f16 complex vectors, accumulated in f32.
 *
 *  For every pair i it accumulates conj(a[i]) * b[i]:
 *      real += a.real * b.real + a.imag * b.imag
 *      imag += a.real * b.imag - a.imag * b.real
 *  Four pairs are processed per vector iteration; the remaining (count_pairs % 4) pairs are
 *  delegated to `nk_vdot_f16c_serial`, whose result is added to the horizontally reduced lanes.
 *
 *  @param a_pairs      First input (conjugated), `count_pairs` complex f16 elements.
 *  @param b_pairs      Second input, `count_pairs` complex f16 elements.
 *  @param count_pairs  Number of complex elements in each input.
 *  @param result       Output: complex f32 sum.
 */
NK_PUBLIC void nk_vdot_f16c_neon(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
                                 nk_f32c_t *result) {
    float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
    float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
    while (count_pairs >= 4) {
        // VLD2 de-interleaves 4 pairs: val[0] holds the real halves, val[1] the imaginary halves.
        // Load through `int16_t const *` (the intrinsic's parameter type) rather than casting const away.
        int16x4x2_t a_i16x4x2 = vld2_s16((int16_t const *)a_pairs);
        int16x4x2_t b_i16x4x2 = vld2_s16((int16_t const *)b_pairs);
        float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
        float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
        float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
        float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
        // real += ar*br + ai*bi
        sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
        sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
        // imag += ar*bi - ai*br  (vfmsq_f32 computes acc - x*y)
        sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
        sum_imag_f32x4 = vfmsq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
        count_pairs -= 4, a_pairs += 4, b_pairs += 4;
    }
    // Remaining 0..3 pairs go through the scalar kernel.
    nk_f32c_t tail_result;
    nk_vdot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
    result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
    result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
}
|
|
805
856
|
|
|
806
857
|
#if defined(__clang__)
|
|
807
858
|
#pragma clang attribute pop
|
|
@@ -8,14 +8,14 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
* vfmaq_f32
|
|
18
|
-
* vfmsq_f32
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy @ 2p 2cy @ 1p
|
|
13
|
+
* vcvt_f32_bf16 SHLL (V.4S, V.4H, #16) 2cy @ 2p 2cy @ 4p
|
|
14
|
+
* vld1q_bf16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
|
|
15
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy @ 1p 8cy @ 1p
|
|
16
|
+
* vpaddq_f32 FADDP (V.4S, V.4S, V.4S) 2cy @ 2p 3cy @ 4p
|
|
17
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
18
|
+
* vfmsq_f32 FMLS (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
19
19
|
*
|
|
20
20
|
* The ARMv8.6-BF16 extension provides the BFDOT instruction for accelerated BF16 dot products,
|
|
21
21
|
* targeting machine learning inference workloads. BF16 trades mantissa precision (7 bits vs 10 in
|
|
@@ -223,10 +223,9 @@ NK_INTERNAL void nk_dot_bf16x8_finalize_neonbfdot(
|
|
|
223
223
|
nk_dot_bf16x8_state_neonbfdot_t const *state_c, nk_dot_bf16x8_state_neonbfdot_t const *state_d, //
|
|
224
224
|
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
225
225
|
nk_unused_(total_dimensions);
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
result->
|
|
229
|
-
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
226
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
227
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
228
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
230
229
|
}
|
|
231
230
|
|
|
232
231
|
#if defined(__clang__)
|
|
@@ -8,14 +8,15 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_neonfhm_instructions ARM NEON FP16 Matrix Instructions (ARMv8.4-FHM)
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vfmlalq_low_f16 FMLAL (V.4S, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
|
|
13
|
+
* vfmlalq_high_f16 FMLAL2 (V.4S, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
|
|
14
|
+
* vfmlslq_low_f16 FMLSL (V.4S, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
|
|
15
|
+
* vfmlslq_high_f16 FMLSL2 (V.4S, V.8H, V.8H) 4cy @ 2p 4cy @ 4p
|
|
16
|
+
* vld1q_f16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
|
|
17
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy @ 1p 8cy @ 1p
|
|
18
|
+
* vpaddq_f32 FADDP (V.4S, V.4S, V.4S) 2cy @ 2p 3cy @ 4p
|
|
19
|
+
* vshll_n_u8 SHLL (V.8H, V.8B, #8) 2cy @ 2p 2cy @ 4p
|
|
19
20
|
*
|
|
20
21
|
* The ARMv8.4-FHM extension (FEAT_FHM) provides FMLAL/FMLSL instructions that fuse FP16 to FP32
|
|
21
22
|
* widening with multiply-accumulate in a single operation. FMLAL executes as a single fused op
|
|
@@ -90,8 +91,8 @@ nk_dot_f16_neonfhm_cycle:
|
|
|
90
91
|
count_scalars = 0;
|
|
91
92
|
}
|
|
92
93
|
else {
|
|
93
|
-
a_f16x8 =
|
|
94
|
-
b_f16x8 =
|
|
94
|
+
a_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(a_scalars)));
|
|
95
|
+
b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(b_scalars)));
|
|
95
96
|
a_scalars += 8, b_scalars += 8, count_scalars -= 8;
|
|
96
97
|
}
|
|
97
98
|
// FMLAL: widening multiply-accumulate fp16 → f32
|
|
@@ -124,10 +125,9 @@ NK_INTERNAL void nk_dot_f16x8_finalize_neonfhm(
|
|
|
124
125
|
nk_dot_f16x8_state_neonfhm_t const *state_c, nk_dot_f16x8_state_neonfhm_t const *state_d, //
|
|
125
126
|
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
126
127
|
nk_unused_(total_dimensions);
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
result->
|
|
130
|
-
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
128
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
129
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
130
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
131
131
|
}
|
|
132
132
|
|
|
133
133
|
NK_PUBLIC void nk_dot_f16c_neonfhm(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
@@ -220,58 +220,58 @@ NK_PUBLIC void nk_vdot_f16c_neonfhm(nk_f16c_t const *a_pairs, nk_f16c_t const *b
|
|
|
220
220
|
|
|
221
221
|
NK_PUBLIC void nk_dot_e4m3_neonfhm(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
222
222
|
nk_f32_t *result) {
|
|
223
|
-
float16x8_t
|
|
223
|
+
float16x8_t a_low_f16x8, a_high_f16x8, b_low_f16x8, b_high_f16x8;
|
|
224
224
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
225
225
|
nk_dot_e4m3_neonfhm_cycle:
|
|
226
226
|
if (count_scalars < 16) {
|
|
227
227
|
nk_b128_vec_t a_vec, b_vec;
|
|
228
228
|
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
229
229
|
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
230
|
-
nk_e4m3x16_to_f16x8x2_neon_(a_vec.u8x16, &
|
|
231
|
-
nk_e4m3x16_to_f16x8x2_neon_(b_vec.u8x16, &
|
|
230
|
+
nk_e4m3x16_to_f16x8x2_neon_(a_vec.u8x16, &a_low_f16x8, &a_high_f16x8);
|
|
231
|
+
nk_e4m3x16_to_f16x8x2_neon_(b_vec.u8x16, &b_low_f16x8, &b_high_f16x8);
|
|
232
232
|
count_scalars = 0;
|
|
233
233
|
}
|
|
234
234
|
else {
|
|
235
|
-
nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(a_scalars), &
|
|
236
|
-
nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(b_scalars), &
|
|
235
|
+
nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(a_scalars), &a_low_f16x8, &a_high_f16x8);
|
|
236
|
+
nk_e4m3x16_to_f16x8x2_neon_(vld1q_u8(b_scalars), &b_low_f16x8, &b_high_f16x8);
|
|
237
237
|
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
238
238
|
}
|
|
239
|
-
sum_f32x4 = vfmlalq_low_f16(sum_f32x4,
|
|
240
|
-
sum_f32x4 = vfmlalq_high_f16(sum_f32x4,
|
|
241
|
-
sum_f32x4 = vfmlalq_low_f16(sum_f32x4,
|
|
242
|
-
sum_f32x4 = vfmlalq_high_f16(sum_f32x4,
|
|
239
|
+
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_low_f16x8, b_low_f16x8);
|
|
240
|
+
sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_low_f16x8, b_low_f16x8);
|
|
241
|
+
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_high_f16x8, b_high_f16x8);
|
|
242
|
+
sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_high_f16x8, b_high_f16x8);
|
|
243
243
|
if (count_scalars) goto nk_dot_e4m3_neonfhm_cycle;
|
|
244
244
|
*result = vaddvq_f32(sum_f32x4);
|
|
245
245
|
}
|
|
246
246
|
|
|
247
247
|
NK_PUBLIC void nk_dot_e5m2_neonfhm(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
248
248
|
nk_f32_t *result) {
|
|
249
|
-
float16x8_t
|
|
249
|
+
float16x8_t a_low_f16x8, a_high_f16x8, b_low_f16x8, b_high_f16x8;
|
|
250
250
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
251
251
|
nk_dot_e5m2_neonfhm_cycle:
|
|
252
252
|
if (count_scalars < 16) {
|
|
253
253
|
nk_b128_vec_t a_vec, b_vec;
|
|
254
254
|
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
255
255
|
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
256
|
+
a_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a_vec.u8x16), 8));
|
|
257
|
+
a_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(a_vec.u8x16, 8));
|
|
258
|
+
b_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b_vec.u8x16), 8));
|
|
259
|
+
b_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(b_vec.u8x16, 8));
|
|
260
260
|
count_scalars = 0;
|
|
261
261
|
}
|
|
262
262
|
else {
|
|
263
263
|
uint8x16_t a_u8x16 = vld1q_u8(a_scalars);
|
|
264
264
|
uint8x16_t b_u8x16 = vld1q_u8(b_scalars);
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
265
|
+
a_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a_u8x16), 8));
|
|
266
|
+
a_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(a_u8x16, 8));
|
|
267
|
+
b_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b_u8x16), 8));
|
|
268
|
+
b_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(b_u8x16, 8));
|
|
269
269
|
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
270
270
|
}
|
|
271
|
-
sum_f32x4 = vfmlalq_low_f16(sum_f32x4,
|
|
272
|
-
sum_f32x4 = vfmlalq_high_f16(sum_f32x4,
|
|
273
|
-
sum_f32x4 = vfmlalq_low_f16(sum_f32x4,
|
|
274
|
-
sum_f32x4 = vfmlalq_high_f16(sum_f32x4,
|
|
271
|
+
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_low_f16x8, b_low_f16x8);
|
|
272
|
+
sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_low_f16x8, b_low_f16x8);
|
|
273
|
+
sum_f32x4 = vfmlalq_low_f16(sum_f32x4, a_high_f16x8, b_high_f16x8);
|
|
274
|
+
sum_f32x4 = vfmlalq_high_f16(sum_f32x4, a_high_f16x8, b_high_f16x8);
|
|
275
275
|
if (count_scalars) goto nk_dot_e5m2_neonfhm_cycle;
|
|
276
276
|
*result = vaddvq_f32(sum_f32x4);
|
|
277
277
|
}
|
|
@@ -304,10 +304,9 @@ NK_INTERNAL void nk_dot_e4m3x16_finalize_neonfhm(
|
|
|
304
304
|
nk_dot_e4m3x16_state_neonfhm_t const *state_c, nk_dot_e4m3x16_state_neonfhm_t const *state_d, //
|
|
305
305
|
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
306
306
|
nk_unused_(total_dimensions);
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
result->
|
|
310
|
-
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
307
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
308
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
309
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
311
310
|
}
|
|
312
311
|
|
|
313
312
|
typedef struct nk_dot_e5m2x16_state_neonfhm_t {
|
|
@@ -324,9 +323,9 @@ NK_INTERNAL void nk_dot_e5m2x16_update_neonfhm(nk_dot_e5m2x16_state_neonfhm_t *s
|
|
|
324
323
|
nk_unused_(active_dimensions);
|
|
325
324
|
// Convert e5m2 → f16 via SHLL: widen u8→u16 and shift left 8 in one instruction
|
|
326
325
|
float16x8_t a_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(a.u8x16), 8));
|
|
327
|
-
float16x8_t a_high_f16x8 = vreinterpretq_f16_u16(
|
|
326
|
+
float16x8_t a_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(a.u8x16, 8));
|
|
328
327
|
float16x8_t b_low_f16x8 = vreinterpretq_f16_u16(vshll_n_u8(vget_low_u8(b.u8x16), 8));
|
|
329
|
-
float16x8_t b_high_f16x8 = vreinterpretq_f16_u16(
|
|
328
|
+
float16x8_t b_high_f16x8 = vreinterpretq_f16_u16(vshll_high_n_u8(b.u8x16, 8));
|
|
330
329
|
// FMLAL: widening multiply-accumulate fp16 → f32
|
|
331
330
|
state->sum_f32x4 = vfmlalq_low_f16(state->sum_f32x4, a_low_f16x8, b_low_f16x8);
|
|
332
331
|
state->sum_f32x4 = vfmlalq_high_f16(state->sum_f32x4, a_low_f16x8, b_low_f16x8);
|
|
@@ -339,10 +338,9 @@ NK_INTERNAL void nk_dot_e5m2x16_finalize_neonfhm(
|
|
|
339
338
|
nk_dot_e5m2x16_state_neonfhm_t const *state_c, nk_dot_e5m2x16_state_neonfhm_t const *state_d, //
|
|
340
339
|
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
341
340
|
nk_unused_(total_dimensions);
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
result->
|
|
345
|
-
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
341
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
342
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
343
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
346
344
|
}
|
|
347
345
|
|
|
348
346
|
#if defined(__clang__)
|