numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for NEON FP8DOT4.
|
|
3
|
+
* @file include/numkong/dot/neonfp8.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dot.h
|
|
8
|
+
*
|
|
9
|
+
* @section dot_neonfp8_instructions ARM NEON FP8DOT4 Instructions (FEAT_FP8DOT4)
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction V1
|
|
12
|
+
* vdotq_f32_mf8_fpm FDOT (V.4S, V.16B, V.16B) 4cy @ 2p
|
|
13
|
+
* vld1q_u8 LD1 (V.16B) 4cy @ 2p
|
|
14
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy @ 1p
|
|
15
|
+
* vpaddq_f32 FADDP (V.4S, V.4S, V.4S) 2cy @ 2p
|
|
16
|
+
*
|
|
17
|
+
* FEAT_FP8DOT4 adds NEON FDOT instructions that take two 128-bit vectors of FP8 (E4M3 or E5M2),
|
|
18
|
+
* and perform a 4-way multiply-accumulate into each FP32 lane. Each FDOT processes 16 FP8 elements
|
|
19
|
+
* into 4 FP32 accumulators. The FP8 format is selected by the FPMR register.
|
|
20
|
+
*
|
|
21
|
+
* FP6 types (E2M3, E3M2) are losslessly promoted to FP8 (E4M3, E5M2) by rebiasing the exponent.
|
|
22
|
+
* Normal values: magnitude += 48. Subnormal values (exp=0): 8-entry or 4-entry TBL lookup.
|
|
23
|
+
*
|
|
24
|
+
* @section dot_neonfp8_stateful Stateful Streaming Logic
|
|
25
|
+
*
|
|
26
|
+
* Defines stateful init/update/finalize helpers for tiled GEMM via the dots/ macros:
|
|
27
|
+
* - nk_dot_e4m3x16_state_neonfp8_t, nk_dot_e5m2x16_state_neonfp8_t
|
|
28
|
+
* - nk_dot_e2m3x16_state_neonfp8_t, nk_dot_e3m2x16_state_neonfp8_t
|
|
29
|
+
*/
|
|
30
|
+
#ifndef NK_DOT_NEONFP8_H
|
|
31
|
+
#define NK_DOT_NEONFP8_H
|
|
32
|
+
|
|
33
|
+
#if NK_TARGET_ARM_
|
|
34
|
+
#if NK_TARGET_NEONFP8
|
|
35
|
+
|
|
36
|
+
#include "numkong/types.h"
|
|
37
|
+
#include "numkong/cast/serial.h" // `nk_partial_load_b8x16_serial_`
|
|
38
|
+
|
|
39
|
+
/**
 * @brief FPM operand for E4M3 × E4M3 dot products.
 * Both source-format fields (bits 2:0 for the first operand, bits 5:3 for the second) hold code 1 = E4M3.
 */
#define NK_FPM_E4M3_ ((fpm_t)(0x01ull | 0x08ull))
/**
 * @brief FPM operand for E5M2 × E5M2 dot products.
 * E5M2 is format code 0, so both source-format fields, and the whole operand, are zero.
 */
#define NK_FPM_E5M2_ ((fpm_t)0ull)
|
|
43
|
+
|
|
44
|
+
#if defined(__cplusplus)
|
|
45
|
+
extern "C" {
|
|
46
|
+
#endif
|
|
47
|
+
|
|
48
|
+
#if defined(__clang__)
|
|
49
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8-a+simd+fp8dot4"))), apply_to = function)
|
|
50
|
+
#elif defined(__GNUC__)
|
|
51
|
+
#pragma GCC push_options
|
|
52
|
+
#pragma GCC target("arch=armv8-a+simd+fp8dot4")
|
|
53
|
+
#endif
|
|
54
|
+
|
|
55
|
+
/**
|
|
56
|
+
* @brief Convert 16 E2M3 bytes (0b00SEEMMM) to E4M3 bytes (0bSEEEEMMM).
|
|
57
|
+
*
|
|
58
|
+
* Normal values (exp>0, mag>=8): rebias exponent by +6 → magnitude += 48.
|
|
59
|
+
* Subnormal values (exp=0, mag<8): 8-entry TBL lookup for normalization.
|
|
60
|
+
* Zero (mag=0): maps to E4M3 zero. Sign moved from bit 5 to bit 7.
|
|
61
|
+
*/
|
|
62
|
+
NK_INTERNAL uint8x16_t nk_e2m3x16_to_e4m3x16_neonfp8_(uint8x16_t raw_u8x16) {
|
|
63
|
+
uint8x16_t sign_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x20));
|
|
64
|
+
uint8x16_t mag_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
|
|
65
|
+
|
|
66
|
+
// Normal path: rebias exponent by +6 → add 48 to magnitude
|
|
67
|
+
uint8x16_t normal_mag_u8x16 = vaddq_u8(mag_u8x16, vdupq_n_u8(48));
|
|
68
|
+
|
|
69
|
+
// Subnormal path: 8-entry LUT for mag 0-7
|
|
70
|
+
// 0→0, 1→32, 2→40, 3→44, 4→48, 5→50, 6→52, 7→54
|
|
71
|
+
uint8x16_t sub_lut_u8x16 = vcombine_u8(vcreate_u8(0x363432302c282000ull), vcreate_u8(0));
|
|
72
|
+
uint8x16_t sub_mag_u8x16 = vqtbl1q_u8(sub_lut_u8x16, mag_u8x16);
|
|
73
|
+
|
|
74
|
+
// Select: subnormal (mag < 8) uses LUT, normal uses +48
|
|
75
|
+
uint8x16_t is_normal_u8x16 = vcgeq_u8(mag_u8x16, vdupq_n_u8(8));
|
|
76
|
+
uint8x16_t result_mag_u8x16 = vbslq_u8(is_normal_u8x16, normal_mag_u8x16, sub_mag_u8x16);
|
|
77
|
+
|
|
78
|
+
// Move sign from bit 5 to bit 7
|
|
79
|
+
uint8x16_t sign_shifted_u8x16 = vshlq_n_u8(sign_u8x16, 2);
|
|
80
|
+
return vorrq_u8(sign_shifted_u8x16, result_mag_u8x16);
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/**
|
|
84
|
+
* @brief Convert 16 E3M2 bytes (0b00SEEEMM) to E5M2 bytes (0bSEEEEEMM).
|
|
85
|
+
*
|
|
86
|
+
* Normal values (exp>0, mag>=4): rebias exponent by +12 → magnitude += 48.
|
|
87
|
+
* Subnormal values (exp=0, mag<4): 4-entry TBL lookup for normalization.
|
|
88
|
+
* Zero (mag=0): maps to E5M2 zero. Sign moved from bit 5 to bit 7.
|
|
89
|
+
*/
|
|
90
|
+
NK_INTERNAL uint8x16_t nk_e3m2x16_to_e5m2x16_neonfp8_(uint8x16_t raw_u8x16) {
|
|
91
|
+
uint8x16_t sign_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x20));
|
|
92
|
+
uint8x16_t mag_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
|
|
93
|
+
|
|
94
|
+
// Normal path: rebias exponent by +12 → add 48 to magnitude
|
|
95
|
+
uint8x16_t normal_mag_u8x16 = vaddq_u8(mag_u8x16, vdupq_n_u8(48));
|
|
96
|
+
|
|
97
|
+
// Subnormal path: 4-entry LUT for mag 0-3
|
|
98
|
+
// 0→0, 1→44, 2→48, 3→50
|
|
99
|
+
uint8x16_t sub_lut_u8x16 = vcombine_u8(vcreate_u8(0x0000000032302c00ull), vcreate_u8(0));
|
|
100
|
+
uint8x16_t sub_mag_u8x16 = vqtbl1q_u8(sub_lut_u8x16, mag_u8x16);
|
|
101
|
+
|
|
102
|
+
// Select: subnormal (mag < 4) uses LUT, normal uses +48
|
|
103
|
+
uint8x16_t is_normal_u8x16 = vcgeq_u8(mag_u8x16, vdupq_n_u8(4));
|
|
104
|
+
uint8x16_t result_mag_u8x16 = vbslq_u8(is_normal_u8x16, normal_mag_u8x16, sub_mag_u8x16);
|
|
105
|
+
|
|
106
|
+
// Move sign from bit 5 to bit 7
|
|
107
|
+
uint8x16_t sign_shifted_u8x16 = vshlq_n_u8(sign_u8x16, 2);
|
|
108
|
+
return vorrq_u8(sign_shifted_u8x16, result_mag_u8x16);
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
NK_PUBLIC void nk_dot_e4m3_neonfp8(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars, nk_size_t count_scalars,
|
|
112
|
+
nk_f32_t *result) {
|
|
113
|
+
mfloat8x16_t a_mf8x16, b_mf8x16;
|
|
114
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
115
|
+
nk_dot_e4m3_neonfp8_cycle:
|
|
116
|
+
if (count_scalars < 16) {
|
|
117
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
118
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
119
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
120
|
+
a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
|
|
121
|
+
b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
|
|
122
|
+
count_scalars = 0;
|
|
123
|
+
}
|
|
124
|
+
else {
|
|
125
|
+
a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a_scalars));
|
|
126
|
+
b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b_scalars));
|
|
127
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
128
|
+
}
|
|
129
|
+
sum_f32x4 = vdotq_f32_mf8_fpm(sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
|
|
130
|
+
if (count_scalars) goto nk_dot_e4m3_neonfp8_cycle;
|
|
131
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
NK_PUBLIC void nk_dot_e5m2_neonfp8(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
135
|
+
nk_f32_t *result) {
|
|
136
|
+
mfloat8x16_t a_mf8x16, b_mf8x16;
|
|
137
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
138
|
+
nk_dot_e5m2_neonfp8_cycle:
|
|
139
|
+
if (count_scalars < 16) {
|
|
140
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
141
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
142
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
143
|
+
a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
|
|
144
|
+
b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
|
|
145
|
+
count_scalars = 0;
|
|
146
|
+
}
|
|
147
|
+
else {
|
|
148
|
+
a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a_scalars));
|
|
149
|
+
b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b_scalars));
|
|
150
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
151
|
+
}
|
|
152
|
+
sum_f32x4 = vdotq_f32_mf8_fpm(sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
|
|
153
|
+
if (count_scalars) goto nk_dot_e5m2_neonfp8_cycle;
|
|
154
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
NK_PUBLIC void nk_dot_e2m3_neonfp8(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
158
|
+
nk_f32_t *result) {
|
|
159
|
+
mfloat8x16_t a_mf8x16, b_mf8x16;
|
|
160
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
161
|
+
nk_dot_e2m3_neonfp8_cycle:
|
|
162
|
+
if (count_scalars < 16) {
|
|
163
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
164
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
165
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
166
|
+
a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(a_vec.u8x16));
|
|
167
|
+
b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(b_vec.u8x16));
|
|
168
|
+
count_scalars = 0;
|
|
169
|
+
}
|
|
170
|
+
else {
|
|
171
|
+
a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)a_scalars)));
|
|
172
|
+
b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)b_scalars)));
|
|
173
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
174
|
+
}
|
|
175
|
+
sum_f32x4 = vdotq_f32_mf8_fpm(sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
|
|
176
|
+
if (count_scalars) goto nk_dot_e2m3_neonfp8_cycle;
|
|
177
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
NK_PUBLIC void nk_dot_e3m2_neonfp8(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
|
|
181
|
+
nk_f32_t *result) {
|
|
182
|
+
mfloat8x16_t a_mf8x16, b_mf8x16;
|
|
183
|
+
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
184
|
+
nk_dot_e3m2_neonfp8_cycle:
|
|
185
|
+
if (count_scalars < 16) {
|
|
186
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
187
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
188
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
189
|
+
a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(a_vec.u8x16));
|
|
190
|
+
b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(b_vec.u8x16));
|
|
191
|
+
count_scalars = 0;
|
|
192
|
+
}
|
|
193
|
+
else {
|
|
194
|
+
a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)a_scalars)));
|
|
195
|
+
b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)b_scalars)));
|
|
196
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
197
|
+
}
|
|
198
|
+
sum_f32x4 = vdotq_f32_mf8_fpm(sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
|
|
199
|
+
if (count_scalars) goto nk_dot_e3m2_neonfp8_cycle;
|
|
200
|
+
*result = vaddvq_f32(sum_f32x4);
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
typedef struct nk_dot_e4m3x16_state_neonfp8_t {
|
|
204
|
+
float32x4_t sum_f32x4;
|
|
205
|
+
} nk_dot_e4m3x16_state_neonfp8_t;
|
|
206
|
+
|
|
207
|
+
NK_INTERNAL void nk_dot_e4m3x16_init_neonfp8(nk_dot_e4m3x16_state_neonfp8_t *state) {
|
|
208
|
+
state->sum_f32x4 = vdupq_n_f32(0);
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
NK_INTERNAL void nk_dot_e4m3x16_update_neonfp8(nk_dot_e4m3x16_state_neonfp8_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
212
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
213
|
+
nk_unused_(depth_offset);
|
|
214
|
+
nk_unused_(active_dimensions);
|
|
215
|
+
mfloat8x16_t a_mf8x16 = vreinterpretq_mf8_u8(a.u8x16);
|
|
216
|
+
mfloat8x16_t b_mf8x16 = vreinterpretq_mf8_u8(b.u8x16);
|
|
217
|
+
state->sum_f32x4 = vdotq_f32_mf8_fpm(state->sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
|
|
218
|
+
}
|
|
219
|
+
|
|
220
|
+
NK_INTERNAL void nk_dot_e4m3x16_finalize_neonfp8( //
|
|
221
|
+
nk_dot_e4m3x16_state_neonfp8_t const *state_a, nk_dot_e4m3x16_state_neonfp8_t const *state_b, //
|
|
222
|
+
nk_dot_e4m3x16_state_neonfp8_t const *state_c, nk_dot_e4m3x16_state_neonfp8_t const *state_d, //
|
|
223
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
224
|
+
nk_unused_(total_dimensions);
|
|
225
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
226
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
227
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
typedef struct nk_dot_e5m2x16_state_neonfp8_t {
|
|
231
|
+
float32x4_t sum_f32x4;
|
|
232
|
+
} nk_dot_e5m2x16_state_neonfp8_t;
|
|
233
|
+
|
|
234
|
+
NK_INTERNAL void nk_dot_e5m2x16_init_neonfp8(nk_dot_e5m2x16_state_neonfp8_t *state) {
|
|
235
|
+
state->sum_f32x4 = vdupq_n_f32(0);
|
|
236
|
+
}
|
|
237
|
+
|
|
238
|
+
NK_INTERNAL void nk_dot_e5m2x16_update_neonfp8(nk_dot_e5m2x16_state_neonfp8_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
239
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
240
|
+
nk_unused_(depth_offset);
|
|
241
|
+
nk_unused_(active_dimensions);
|
|
242
|
+
mfloat8x16_t a_mf8x16 = vreinterpretq_mf8_u8(a.u8x16);
|
|
243
|
+
mfloat8x16_t b_mf8x16 = vreinterpretq_mf8_u8(b.u8x16);
|
|
244
|
+
state->sum_f32x4 = vdotq_f32_mf8_fpm(state->sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
|
|
245
|
+
}
|
|
246
|
+
|
|
247
|
+
NK_INTERNAL void nk_dot_e5m2x16_finalize_neonfp8( //
|
|
248
|
+
nk_dot_e5m2x16_state_neonfp8_t const *state_a, nk_dot_e5m2x16_state_neonfp8_t const *state_b, //
|
|
249
|
+
nk_dot_e5m2x16_state_neonfp8_t const *state_c, nk_dot_e5m2x16_state_neonfp8_t const *state_d, //
|
|
250
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
251
|
+
nk_unused_(total_dimensions);
|
|
252
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
253
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
254
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
typedef struct nk_dot_e2m3x16_state_neonfp8_t {
|
|
258
|
+
float32x4_t sum_f32x4;
|
|
259
|
+
} nk_dot_e2m3x16_state_neonfp8_t;
|
|
260
|
+
|
|
261
|
+
NK_INTERNAL void nk_dot_e2m3x16_init_neonfp8(nk_dot_e2m3x16_state_neonfp8_t *state) {
|
|
262
|
+
state->sum_f32x4 = vdupq_n_f32(0);
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
NK_INTERNAL void nk_dot_e2m3x16_update_neonfp8(nk_dot_e2m3x16_state_neonfp8_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
266
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
267
|
+
nk_unused_(depth_offset);
|
|
268
|
+
nk_unused_(active_dimensions);
|
|
269
|
+
mfloat8x16_t a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(a.u8x16));
|
|
270
|
+
mfloat8x16_t b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(b.u8x16));
|
|
271
|
+
state->sum_f32x4 = vdotq_f32_mf8_fpm(state->sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
|
|
272
|
+
}
|
|
273
|
+
|
|
274
|
+
NK_INTERNAL void nk_dot_e2m3x16_finalize_neonfp8( //
|
|
275
|
+
nk_dot_e2m3x16_state_neonfp8_t const *state_a, nk_dot_e2m3x16_state_neonfp8_t const *state_b, //
|
|
276
|
+
nk_dot_e2m3x16_state_neonfp8_t const *state_c, nk_dot_e2m3x16_state_neonfp8_t const *state_d, //
|
|
277
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
278
|
+
nk_unused_(total_dimensions);
|
|
279
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
280
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
281
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
typedef struct nk_dot_e3m2x16_state_neonfp8_t {
|
|
285
|
+
float32x4_t sum_f32x4;
|
|
286
|
+
} nk_dot_e3m2x16_state_neonfp8_t;
|
|
287
|
+
|
|
288
|
+
NK_INTERNAL void nk_dot_e3m2x16_init_neonfp8(nk_dot_e3m2x16_state_neonfp8_t *state) {
|
|
289
|
+
state->sum_f32x4 = vdupq_n_f32(0);
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
NK_INTERNAL void nk_dot_e3m2x16_update_neonfp8(nk_dot_e3m2x16_state_neonfp8_t *state, nk_b128_vec_t a, nk_b128_vec_t b,
|
|
293
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
294
|
+
nk_unused_(depth_offset);
|
|
295
|
+
nk_unused_(active_dimensions);
|
|
296
|
+
mfloat8x16_t a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(a.u8x16));
|
|
297
|
+
mfloat8x16_t b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(b.u8x16));
|
|
298
|
+
state->sum_f32x4 = vdotq_f32_mf8_fpm(state->sum_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
|
|
299
|
+
}
|
|
300
|
+
|
|
301
|
+
NK_INTERNAL void nk_dot_e3m2x16_finalize_neonfp8( //
|
|
302
|
+
nk_dot_e3m2x16_state_neonfp8_t const *state_a, nk_dot_e3m2x16_state_neonfp8_t const *state_b, //
|
|
303
|
+
nk_dot_e3m2x16_state_neonfp8_t const *state_c, nk_dot_e3m2x16_state_neonfp8_t const *state_d, //
|
|
304
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
305
|
+
nk_unused_(total_dimensions);
|
|
306
|
+
float32x4_t ab_f32x4 = vpaddq_f32(state_a->sum_f32x4, state_b->sum_f32x4);
|
|
307
|
+
float32x4_t cd_f32x4 = vpaddq_f32(state_c->sum_f32x4, state_d->sum_f32x4);
|
|
308
|
+
result->f32x4 = vpaddq_f32(ab_f32x4, cd_f32x4);
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
#if defined(__clang__)
|
|
312
|
+
#pragma clang attribute pop
|
|
313
|
+
#elif defined(__GNUC__)
|
|
314
|
+
#pragma GCC pop_options
|
|
315
|
+
#endif
|
|
316
|
+
|
|
317
|
+
#if defined(__cplusplus)
|
|
318
|
+
} // extern "C"
|
|
319
|
+
#endif
|
|
320
|
+
|
|
321
|
+
#endif // NK_TARGET_NEONFP8
|
|
322
|
+
#endif // NK_TARGET_ARM_
|
|
323
|
+
#endif // NK_DOT_NEONFP8_H
|