numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -84,14 +84,14 @@
|
|
|
84
84
|
* rounding (notably 3x faster on Genoa than Ice Lake). VFPCLASS detects NaN/Inf inputs for special
|
|
85
85
|
* case handling. Division appears in tangent's final step but isn't on the critical path.
|
|
86
86
|
*
|
|
87
|
-
* Intrinsic Instruction
|
|
88
|
-
* _mm512_roundscale_ps VRNDSCALEPS (ZMM, ZMM, I8)
|
|
89
|
-
* _mm512_roundscale_pd VRNDSCALEPD (ZMM, ZMM, I8)
|
|
90
|
-
* _mm512_fpclass_ps_mask VFPCLASSPS (K, ZMM, I8)
|
|
91
|
-
* _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM)
|
|
92
|
-
* _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM)
|
|
93
|
-
* _mm256_div_ps VDIVPS (YMM, YMM, YMM)
|
|
94
|
-
* _mm256_div_pd VDIVPD (YMM, YMM, YMM)
|
|
87
|
+
* Intrinsic Instruction Icelake Genoa
|
|
88
|
+
* _mm512_roundscale_ps VRNDSCALEPS (ZMM, ZMM, I8) 8cy @ p0+p0 3cy @ p23
|
|
89
|
+
* _mm512_roundscale_pd VRNDSCALEPD (ZMM, ZMM, I8) 8cy @ p0+p0 3cy @ p23
|
|
90
|
+
* _mm512_fpclass_ps_mask VFPCLASSPS (K, ZMM, I8) 3cy @ p5 5cy @ p01
|
|
91
|
+
* _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM) 4cy @ p0 4cy @ p01
|
|
92
|
+
* _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4cy @ p01 4cy @ p01
|
|
93
|
+
* _mm256_div_ps VDIVPS (YMM, YMM, YMM) ~11cy @ p0 ~11cy @ p01
|
|
94
|
+
* _mm256_div_pd VDIVPD (YMM, YMM, YMM) ~13cy @ p0 ~13cy @ p01
|
|
95
95
|
*
|
|
96
96
|
* @section arm_instructions Relevant ARM NEON/SVE Instructions
|
|
97
97
|
*
|
|
@@ -99,14 +99,14 @@
|
|
|
99
99
|
* fast rounding for range reduction. The 4-cycle FMA latency with 4 inst/cycle throughput allows
|
|
100
100
|
* excellent pipelining when processing multiple elements.
|
|
101
101
|
*
|
|
102
|
-
* Intrinsic
|
|
103
|
-
* vfmaq_f32
|
|
104
|
-
* vfmaq_f64
|
|
105
|
-
* vrndaq_f32
|
|
102
|
+
* Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
|
|
103
|
+
* vfmaq_f32 FMLA.S (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
|
|
104
|
+
* vfmaq_f64 FMLA.D (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
|
|
105
|
+
* vrndaq_f32 FRINTA.S 2cy @ V0123 2cy @ V01 2cy @ V01
|
|
106
106
|
*
|
|
107
107
|
* @section references References
|
|
108
108
|
*
|
|
109
|
-
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
109
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
|
|
110
110
|
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
111
111
|
*
|
|
112
112
|
*/
|
|
@@ -91,7 +91,7 @@ void atan(in_type_ const *in, std::size_t n, in_type_ *out) noexcept {
|
|
|
91
91
|
|
|
92
92
|
namespace ashvardanian::numkong {
|
|
93
93
|
|
|
94
|
-
#pragma region
|
|
94
|
+
#pragma region Tensor Trigonometric
|
|
95
95
|
|
|
96
96
|
/** @brief Elementwise sin into pre-allocated output. */
|
|
97
97
|
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
|
|
@@ -159,7 +159,7 @@ tensor<value_type_, allocator_type_, max_rank_> try_atan(tensor_view<value_type_
|
|
|
159
159
|
return result;
|
|
160
160
|
}
|
|
161
161
|
|
|
162
|
-
#pragma endregion
|
|
162
|
+
#pragma endregion Tensor Trigonometric
|
|
163
163
|
|
|
164
164
|
} // namespace ashvardanian::numkong
|
|
165
165
|
|
package/include/numkong/types.h
CHANGED
|
@@ -36,6 +36,29 @@
|
|
|
36
36
|
* @see https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
|
|
37
37
|
* @see FP8 Formats for Deep Learning: https://arxiv.org/pdf/2209.05433
|
|
38
38
|
* @see ONNX Float8 Types: https://onnx.ai/onnx/technical/float8.html
|
|
39
|
+
*
|
|
40
|
+
* @section fp6_types FP6 Numeric Types
|
|
41
|
+
*
|
|
42
|
+
* The OCP Microscaling (MX) v1.0 specification defines two 6-bit floating-point formats
|
|
43
|
+
* for block-scaled quantization. Both are "FN" (finite-numeric): all bit patterns map
|
|
44
|
+
* to real numbers with no Inf or NaN codes. Stored byte-aligned with 2 bits of padding.
|
|
45
|
+
*
|
|
46
|
+
* Format Bias Sign Exp Mant Range Subnormals Infinity NaN Standard
|
|
47
|
+
* E2M3 1 1 2 3 ±7.5 14 of 64 ❌ No ❌ OCP MX v1.0
|
|
48
|
+
* E3M2 3 1 3 2 ±28 6 of 64 ❌ No ❌ OCP MX v1.0
|
|
49
|
+
*
|
|
50
|
+
* E2M3 favors mantissa precision (3 bits) for narrow dynamic range — ideal for activations.
|
|
51
|
+
* E3M2 favors exponent range (3 bits) for wider dynamic range — suited for weights.
|
|
52
|
+
* Both follow IEEE 754 subnormal rules: when exp=0, the implicit leading bit is 0,
|
|
53
|
+
* giving value = (-1)^s × 0.mmm × 2^(1-bias). This provides gradual underflow to zero.
|
|
54
|
+
*
|
|
55
|
+
* No hardware directly computes on FP6. On Arm with FEAT_FP8DOT4, E2M3 values can be
|
|
56
|
+
* losslessly promoted to E4M3 (same mantissa width, rebias exponent by +6) and E3M2 to
|
|
57
|
+
* E5M2 (same mantissa width, rebias exponent by +12), then fed to FDOT instructions.
|
|
58
|
+
* Subnormal values (exp=0) require normalization during this promotion.
|
|
59
|
+
*
|
|
60
|
+
* @see https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
|
61
|
+
* @see https://arxiv.org/abs/2401.14112 (FP6-LLM paper)
|
|
39
62
|
*/
|
|
40
63
|
#ifndef NK_TYPES_H
|
|
41
64
|
#define NK_TYPES_H
|
|
@@ -68,6 +91,9 @@
|
|
|
68
91
|
#if defined(__GNUC__) || defined(__clang__)
|
|
69
92
|
#define NK_PUBLIC __attribute__((unused)) inline static
|
|
70
93
|
#define NK_INTERNAL __attribute__((always_inline)) inline static
|
|
94
|
+
#elif defined(_MSC_VER)
|
|
95
|
+
#define NK_PUBLIC inline static
|
|
96
|
+
#define NK_INTERNAL __forceinline static
|
|
71
97
|
#else
|
|
72
98
|
#define NK_PUBLIC inline static
|
|
73
99
|
#define NK_INTERNAL inline static
|
|
@@ -85,6 +111,14 @@
|
|
|
85
111
|
#define NK_DYNAMIC NK_PUBLIC
|
|
86
112
|
#endif // NK_DYNAMIC_DISPATCH
|
|
87
113
|
|
|
114
|
+
// Vector union types use type punning by design (write as f16, read as f32, etc.).
|
|
115
|
+
// Without this, GCC at -O2 assumes strict aliasing and may optimize away valid accesses.
|
|
116
|
+
#if defined(__GNUC__) || defined(__clang__)
|
|
117
|
+
#define NK_MAY_ALIAS_ __attribute__((may_alias))
|
|
118
|
+
#else
|
|
119
|
+
#define NK_MAY_ALIAS_
|
|
120
|
+
#endif
|
|
121
|
+
|
|
88
122
|
// Allow SIMD kernels to redirect small inputs to serial implementations.
|
|
89
123
|
// Enabled by default for production use. Tests and benchmarks may disable
|
|
90
124
|
// this to isolate SIMD path behavior on small inputs.
|
|
@@ -93,6 +127,7 @@
|
|
|
93
127
|
#endif
|
|
94
128
|
|
|
95
129
|
// Compiling for Arm: NK_TARGET_ARM_
|
|
130
|
+
// https://arm-software.github.io/acle/main/acle.html
|
|
96
131
|
#if !defined(NK_TARGET_ARM_)
|
|
97
132
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
98
133
|
#define NK_TARGET_ARM_ 1
|
|
@@ -102,6 +137,7 @@
|
|
|
102
137
|
#endif // !defined(NK_TARGET_ARM_)
|
|
103
138
|
|
|
104
139
|
// Compiling for x86: NK_TARGET_X86_
|
|
140
|
+
// https://www.intel.com/content/www/us/en/docs/dpcpp-cpp-compiler/developer-guide-reference/2024-2/additional-predefined-macros.html
|
|
105
141
|
#if !defined(NK_TARGET_X86_)
|
|
106
142
|
#if defined(__x86_64__) || defined(_M_X64)
|
|
107
143
|
#define NK_TARGET_X86_ 1
|
|
@@ -119,6 +155,24 @@
|
|
|
119
155
|
#endif // defined(__riscv) && (__riscv_xlen == 64)
|
|
120
156
|
#endif // !defined(NK_TARGET_RISCV_)
|
|
121
157
|
|
|
158
|
+
// Compiling for LoongArch: NK_TARGET_LOONGARCH_
|
|
159
|
+
#if !defined(NK_TARGET_LOONGARCH_)
|
|
160
|
+
#if defined(__loongarch__)
|
|
161
|
+
#define NK_TARGET_LOONGARCH_ 1
|
|
162
|
+
#else
|
|
163
|
+
#define NK_TARGET_LOONGARCH_ 0
|
|
164
|
+
#endif // defined(__loongarch__)
|
|
165
|
+
#endif // !defined(NK_TARGET_LOONGARCH_)
|
|
166
|
+
|
|
167
|
+
// Compiling for Power: NK_TARGET_POWER_
|
|
168
|
+
#if !defined(NK_TARGET_POWER_)
|
|
169
|
+
#if defined(__powerpc64__) || defined(__ppc64__) || defined(_ARCH_PPC64)
|
|
170
|
+
#define NK_TARGET_POWER_ 1
|
|
171
|
+
#else
|
|
172
|
+
#define NK_TARGET_POWER_ 0
|
|
173
|
+
#endif // defined(__powerpc64__) || defined(__ppc64__) || defined(_ARCH_PPC64)
|
|
174
|
+
#endif // !defined(NK_TARGET_POWER_)
|
|
175
|
+
|
|
122
176
|
// Compiling for WASM: NK_TARGET_WASM_
|
|
123
177
|
#if !defined(NK_TARGET_WASM_)
|
|
124
178
|
#if defined(__wasm__) || defined(__EMSCRIPTEN__)
|
|
@@ -191,56 +245,93 @@
|
|
|
191
245
|
#endif // defined(__riscv_zvbb) && (__riscv_zvbb > 0)
|
|
192
246
|
#endif // !defined(NK_TARGET_RVVBB) || ...
|
|
193
247
|
|
|
248
|
+
// Compiling for LoongArch LASX (256-bit SIMD): NK_TARGET_LOONGSONASX
|
|
249
|
+
// LASX provides 32 × 256-bit vector registers, widening integer multiply-accumulate,
|
|
250
|
+
// and f32-to-f64 conversion (xvfcvtl_d_s / xvfcvth_d_s) but no widening FMA.
|
|
251
|
+
#if !defined(NK_TARGET_LOONGSONASX) || (NK_TARGET_LOONGSONASX && !NK_TARGET_LOONGARCH_)
|
|
252
|
+
#if defined(__loongarch_asx)
|
|
253
|
+
#define NK_TARGET_LOONGSONASX 1
|
|
254
|
+
#else
|
|
255
|
+
#undef NK_TARGET_LOONGSONASX
|
|
256
|
+
#define NK_TARGET_LOONGSONASX 0
|
|
257
|
+
#endif // defined(__loongarch_asx)
|
|
258
|
+
#endif // !defined(NK_TARGET_LOONGSONASX) || ...
|
|
259
|
+
|
|
260
|
+
// Compiling for Power VSX (128-bit SIMD, POWER9+ baseline): NK_TARGET_POWERVSX
|
|
261
|
+
// VSX provides 64 × 128-bit registers, FMA (vec_madd), vec_msum (multiply-sum), hardware f16
|
|
262
|
+
// conversion (vec_extract_fp32_from_shorth/l), length-limited loads (vec_xl_len), per-byte
|
|
263
|
+
// popcount (vec_popcnt), and vec_cmpne. Requires POWER9 (ISA 3.0) or newer.
|
|
264
|
+
#if !defined(NK_TARGET_POWERVSX) || (NK_TARGET_POWERVSX && !NK_TARGET_POWER_)
|
|
265
|
+
#if defined(__VSX__) && defined(__POWER9_VECTOR__)
|
|
266
|
+
#define NK_TARGET_POWERVSX 1
|
|
267
|
+
#else
|
|
268
|
+
#undef NK_TARGET_POWERVSX
|
|
269
|
+
#define NK_TARGET_POWERVSX 0
|
|
270
|
+
#endif // defined(__VSX__) && defined(__POWER9_VECTOR__)
|
|
271
|
+
#endif // !defined(NK_TARGET_POWERVSX) || ...
|
|
272
|
+
|
|
194
273
|
// Compiling for Arm: NK_TARGET_NEON
|
|
195
274
|
#if !defined(NK_TARGET_NEON) || (NK_TARGET_NEON && !NK_TARGET_ARM_)
|
|
196
|
-
#if defined(__ARM_NEON)
|
|
275
|
+
#if defined(__ARM_NEON) || (defined(_MSC_VER) && defined(_M_ARM64))
|
|
197
276
|
#define NK_TARGET_NEON 1
|
|
198
277
|
#else
|
|
199
278
|
#undef NK_TARGET_NEON
|
|
200
279
|
#define NK_TARGET_NEON 0
|
|
201
|
-
#endif // defined(__ARM_NEON)
|
|
280
|
+
#endif // defined(__ARM_NEON) || ...
|
|
202
281
|
#endif // !defined(NK_TARGET_NEON) || ...
|
|
203
282
|
|
|
204
|
-
// Compiling for Arm: NK_TARGET_NEONSDOT
|
|
283
|
+
// Compiling for Arm: NK_TARGET_NEONSDOT (FEAT_DotProd, optional from ARMv8.1, mandatory at ARMv8.4 with AdvSIMD)
|
|
205
284
|
#if !defined(NK_TARGET_NEONSDOT) || (NK_TARGET_NEONSDOT && !NK_TARGET_ARM_)
|
|
206
|
-
#if defined(
|
|
285
|
+
#if defined(__ARM_FEATURE_DOTPROD) || (defined(_MSC_VER) && defined(_M_ARM64) && __ARM_ARCH >= 804)
|
|
207
286
|
#define NK_TARGET_NEONSDOT 1
|
|
208
287
|
#else
|
|
209
288
|
#undef NK_TARGET_NEONSDOT
|
|
210
289
|
#define NK_TARGET_NEONSDOT 0
|
|
211
|
-
#endif
|
|
290
|
+
#endif
|
|
212
291
|
#endif // !defined(NK_TARGET_NEONSDOT) || ...
|
|
213
292
|
|
|
214
|
-
// Compiling for Arm: NK_TARGET_NEONHALF
|
|
293
|
+
// Compiling for Arm: NK_TARGET_NEONHALF (FEAT_FP16, optional from ARMv8.2, mandatory at ARMv9.0 with AdvSIMD)
|
|
215
294
|
#if !defined(NK_TARGET_NEONHALF) || (NK_TARGET_NEONHALF && !NK_TARGET_ARM_)
|
|
216
|
-
#if defined(
|
|
295
|
+
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || (defined(_MSC_VER) && defined(_M_ARM64) && __ARM_ARCH >= 802)
|
|
217
296
|
#define NK_TARGET_NEONHALF 1
|
|
218
297
|
#else
|
|
219
298
|
#undef NK_TARGET_NEONHALF
|
|
220
299
|
#define NK_TARGET_NEONHALF 0
|
|
221
|
-
#endif
|
|
300
|
+
#endif
|
|
222
301
|
#endif // !defined(NK_TARGET_NEONHALF) || ...
|
|
223
302
|
|
|
224
|
-
// Compiling for Arm: NK_TARGET_NEONFHM (FEAT_FHM
|
|
303
|
+
// Compiling for Arm: NK_TARGET_NEONFHM (FEAT_FHM, optional from ARMv8.1, mandatory at ARMv8.4 with FP16)
|
|
225
304
|
#if !defined(NK_TARGET_NEONFHM) || (NK_TARGET_NEONFHM && !NK_TARGET_ARM_)
|
|
226
|
-
#if defined(
|
|
305
|
+
#if defined(__ARM_FEATURE_FP16_FML) || (defined(_MSC_VER) && defined(_M_ARM64) && __ARM_ARCH >= 804)
|
|
227
306
|
#define NK_TARGET_NEONFHM 1
|
|
228
307
|
#else
|
|
229
308
|
#undef NK_TARGET_NEONFHM
|
|
230
309
|
#define NK_TARGET_NEONFHM 0
|
|
231
|
-
#endif
|
|
310
|
+
#endif
|
|
232
311
|
#endif // !defined(NK_TARGET_NEONFHM) || ...
|
|
233
312
|
|
|
234
|
-
// Compiling for Arm: NK_TARGET_NEONBFDOT
|
|
313
|
+
// Compiling for Arm: NK_TARGET_NEONBFDOT (FEAT_BF16, optional from ARMv8.2, mandatory at ARMv8.6 with FP)
|
|
235
314
|
#if !defined(NK_TARGET_NEONBFDOT) || (NK_TARGET_NEONBFDOT && !NK_TARGET_ARM_)
|
|
236
|
-
#if defined(
|
|
315
|
+
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || (defined(_MSC_VER) && defined(_M_ARM64) && __ARM_ARCH >= 806)
|
|
237
316
|
#define NK_TARGET_NEONBFDOT 1
|
|
238
317
|
#else
|
|
239
318
|
#undef NK_TARGET_NEONBFDOT
|
|
240
319
|
#define NK_TARGET_NEONBFDOT 0
|
|
241
|
-
#endif
|
|
320
|
+
#endif
|
|
242
321
|
#endif // !defined(NK_TARGET_NEONBFDOT) || ...
|
|
243
322
|
|
|
323
|
+
// Compiling for Arm: NK_TARGET_NEONFP8 (NEON FP8 extensions, FEAT_FP8DOT4)
|
|
324
|
+
// ACLE macro __ARM_FEATURE_FP8DOT4 defined by GCC 15+ and Clang 21+ when +fp8dot4 is enabled.
|
|
325
|
+
// Older compilers lack mfloat8x16_t and the fp8dot4 target attribute entirely.
|
|
326
|
+
#if !defined(NK_TARGET_NEONFP8) || (NK_TARGET_NEONFP8 && !NK_TARGET_ARM_)
|
|
327
|
+
#if defined(__ARM_FEATURE_FP8DOT4)
|
|
328
|
+
#define NK_TARGET_NEONFP8 1
|
|
329
|
+
#else
|
|
330
|
+
#undef NK_TARGET_NEONFP8
|
|
331
|
+
#define NK_TARGET_NEONFP8 0
|
|
332
|
+
#endif // defined(__ARM_FEATURE_FP8DOT4)
|
|
333
|
+
#endif // !defined(NK_TARGET_NEONFP8) || ...
|
|
334
|
+
|
|
244
335
|
// Compiling for Arm: NK_TARGET_SVE
|
|
245
336
|
#if !defined(NK_TARGET_SVE) || (NK_TARGET_SVE && !NK_TARGET_ARM_)
|
|
246
337
|
#if defined(__ARM_FEATURE_SVE)
|
|
@@ -316,20 +407,26 @@
|
|
|
316
407
|
#endif // defined(__ARM_FEATURE_SME2)
|
|
317
408
|
#endif // !defined(NK_TARGET_SME2) || ...
|
|
318
409
|
|
|
410
|
+
// Compiling for Arm: NK_TARGET_SME2P1 (FEAT_SME2p1)
|
|
411
|
+
// ACLE macro: __ARM_FEATURE_SME2p1 (note lowercase 'p')
|
|
319
412
|
#if !defined(NK_TARGET_SME2P1) || (NK_TARGET_SME2P1 && !NK_TARGET_ARM_)
|
|
413
|
+
#if defined(__ARM_FEATURE_SME2p1)
|
|
414
|
+
#define NK_TARGET_SME2P1 1
|
|
415
|
+
#else
|
|
320
416
|
#undef NK_TARGET_SME2P1
|
|
321
417
|
#define NK_TARGET_SME2P1 0
|
|
322
|
-
#endif
|
|
418
|
+
#endif // defined(__ARM_FEATURE_SME2p1)
|
|
419
|
+
#endif // !defined(NK_TARGET_SME2P1) || ...
|
|
323
420
|
|
|
324
421
|
// AppleClang 17 exposes SME sub-features through `arm_sme.h` builtin aliases,
|
|
325
422
|
// not dedicated `__ARM_FEATURE_*` predefines for every matrix subtype.
|
|
326
423
|
#if !defined(NK_TARGET_SMEF64) || (NK_TARGET_SMEF64 && !NK_TARGET_ARM_)
|
|
327
|
-
#if defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za64_f64_m)
|
|
424
|
+
#if defined(__ARM_FEATURE_SME_F64F64) || (defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za64_f64_m))
|
|
328
425
|
#define NK_TARGET_SMEF64 1
|
|
329
426
|
#else
|
|
330
427
|
#undef NK_TARGET_SMEF64
|
|
331
428
|
#define NK_TARGET_SMEF64 0
|
|
332
|
-
#endif // defined(
|
|
429
|
+
#endif // defined(__ARM_FEATURE_SME_F64F64) || ...
|
|
333
430
|
#endif // !defined(NK_TARGET_SMEF64) || ...
|
|
334
431
|
|
|
335
432
|
#if !defined(NK_TARGET_SMEBI32) || (NK_TARGET_SMEBI32 && !NK_TARGET_ARM_)
|
|
@@ -342,7 +439,7 @@
|
|
|
342
439
|
#endif // !defined(NK_TARGET_SMEBI32) || ...
|
|
343
440
|
|
|
344
441
|
#if !defined(NK_TARGET_SMEHALF) || (NK_TARGET_SMEHALF && !NK_TARGET_ARM_)
|
|
345
|
-
#if defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za32_f16_m)
|
|
442
|
+
#if defined(__ARM_FEATURE_SME_F16F16) || (defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za32_f16_m))
|
|
346
443
|
#define NK_TARGET_SMEHALF 1
|
|
347
444
|
#else
|
|
348
445
|
#undef NK_TARGET_SMEHALF
|
|
@@ -368,10 +465,15 @@
|
|
|
368
465
|
#endif // defined(__has_builtin) && __has_builtin(__builtin_sme_svluti2_lane_zt_u8)
|
|
369
466
|
#endif // !defined(NK_TARGET_SMELUT2) || ...
|
|
370
467
|
|
|
468
|
+
// Compiling for Arm: NK_TARGET_SMEFA64 (FEAT_SME_FA64, full SVE2 in streaming mode)
|
|
371
469
|
#if !defined(NK_TARGET_SMEFA64) || (NK_TARGET_SMEFA64 && !NK_TARGET_ARM_)
|
|
470
|
+
#if defined(__ARM_FEATURE_SME_FA64)
|
|
471
|
+
#define NK_TARGET_SMEFA64 1
|
|
472
|
+
#else
|
|
372
473
|
#undef NK_TARGET_SMEFA64
|
|
373
474
|
#define NK_TARGET_SMEFA64 0
|
|
374
|
-
#endif
|
|
475
|
+
#endif // defined(__ARM_FEATURE_SME_FA64)
|
|
476
|
+
#endif // !defined(NK_TARGET_SMEFA64) || ...
|
|
375
477
|
|
|
376
478
|
// Compiling for x86: NK_TARGET_HASWELL
|
|
377
479
|
//
|
|
@@ -433,9 +535,22 @@
|
|
|
433
535
|
#else
|
|
434
536
|
#undef NK_TARGET_GENOA
|
|
435
537
|
#define NK_TARGET_GENOA 0
|
|
436
|
-
#endif
|
|
538
|
+
#endif // defined(__AVX512BF16__) || ...
|
|
437
539
|
#endif // !defined(NK_TARGET_GENOA) || ...
|
|
438
540
|
|
|
541
|
+
// Compiling for x86: NK_TARGET_DIAMOND (AVX10.2, Diamond Rapids)
|
|
542
|
+
// GCC 14+: defines __AVX10_2__ with -mavx10.2-512
|
|
543
|
+
// Clang 19+: defines __AVX10_2__ with -mavx10.2-512
|
|
544
|
+
// MSVC: defines __AVX10_VER__ >= 2 with /arch:AVX10.2 (VS 2026+, not yet released)
|
|
545
|
+
#if !defined(NK_TARGET_DIAMOND) || (NK_TARGET_DIAMOND && !NK_TARGET_X86_)
|
|
546
|
+
#if defined(__AVX10_2__) || (defined(__AVX10_VER__) && __AVX10_VER__ >= 2)
|
|
547
|
+
#define NK_TARGET_DIAMOND 1
|
|
548
|
+
#else
|
|
549
|
+
#undef NK_TARGET_DIAMOND
|
|
550
|
+
#define NK_TARGET_DIAMOND 0
|
|
551
|
+
#endif // defined(__AVX10_2__) || ...
|
|
552
|
+
#endif // !defined(NK_TARGET_DIAMOND) || ...
|
|
553
|
+
|
|
439
554
|
#if !defined(NK_TARGET_SAPPHIRE) || (NK_TARGET_SAPPHIRE && !NK_TARGET_X86_)
|
|
440
555
|
#if defined(__AVX512FP16__) || (defined(_MSC_VER) && _MSC_VER >= 1944)
|
|
441
556
|
#define NK_TARGET_SAPPHIRE 1
|
|
@@ -490,10 +605,10 @@
|
|
|
490
605
|
#endif
|
|
491
606
|
#endif // !defined(NK_TARGET_SIERRA) || ...
|
|
492
607
|
|
|
493
|
-
// Include the relevant intrinsics
|
|
608
|
+
// Include the relevant intrinsics headers
|
|
494
609
|
#if defined(_MSC_VER)
|
|
495
610
|
#include <intrin.h>
|
|
496
|
-
#
|
|
611
|
+
#endif
|
|
497
612
|
#if NK_TARGET_NEON
|
|
498
613
|
#include <arm_neon.h>
|
|
499
614
|
#endif
|
|
@@ -503,11 +618,20 @@
|
|
|
503
618
|
#if NK_TARGET_SME || NK_TARGET_SME2 || NK_TARGET_SMEBI32
|
|
504
619
|
#include <arm_sme.h>
|
|
505
620
|
#endif
|
|
506
|
-
#
|
|
621
|
+
#if NK_TARGET_HASWELL || NK_TARGET_SKYLAKE
|
|
507
622
|
#include <immintrin.h>
|
|
508
|
-
#
|
|
623
|
+
#endif
|
|
624
|
+
#if NK_TARGET_RVV
|
|
509
625
|
#include <riscv_vector.h>
|
|
510
|
-
#
|
|
626
|
+
#endif
|
|
627
|
+
#if NK_TARGET_LOONGSONASX
|
|
628
|
+
#include <lsxintrin.h> // `__m128i` for LSX SIMD
|
|
629
|
+
#include <lasxintrin.h> // `__m256i` for LASX SIMD
|
|
630
|
+
#endif
|
|
631
|
+
#if NK_TARGET_POWERVSX
|
|
632
|
+
#include <altivec.h>
|
|
633
|
+
#endif
|
|
634
|
+
#if NK_TARGET_V128RELAXED
|
|
511
635
|
#include <wasm_simd128.h>
|
|
512
636
|
#endif
|
|
513
637
|
|
|
@@ -516,11 +640,11 @@
|
|
|
516
640
|
#endif
|
|
517
641
|
|
|
518
642
|
#if !defined(NK_F32_DIVISION_EPSILON)
|
|
519
|
-
#define NK_F32_DIVISION_EPSILON (1e-
|
|
643
|
+
#define NK_F32_DIVISION_EPSILON (1e-7f)
|
|
520
644
|
#endif
|
|
521
645
|
|
|
522
646
|
#if !defined(NK_F16_DIVISION_EPSILON)
|
|
523
|
-
#define NK_F16_DIVISION_EPSILON (1e-
|
|
647
|
+
#define NK_F16_DIVISION_EPSILON (1e-3f)
|
|
524
648
|
#endif
|
|
525
649
|
|
|
526
650
|
/**
|
|
@@ -576,6 +700,27 @@
|
|
|
576
700
|
#endif
|
|
577
701
|
#endif
|
|
578
702
|
|
|
703
|
+
/* AltiVec defines `bool`, `vector`, and `pixel` as macros, which conflict with C++.
|
|
704
|
+
* We use `__vector` directly in our code, so undef the problematic macros.
|
|
705
|
+
*/
|
|
706
|
+
#if NK_TARGET_POWERVSX
|
|
707
|
+
#ifdef __cplusplus
|
|
708
|
+
#undef bool
|
|
709
|
+
#undef vector
|
|
710
|
+
#undef pixel
|
|
711
|
+
#endif
|
|
712
|
+
typedef __vector unsigned char nk_vu8x16_t;
|
|
713
|
+
typedef __vector unsigned short nk_vu16x8_t;
|
|
714
|
+
typedef __vector unsigned int nk_vu32x4_t;
|
|
715
|
+
typedef __vector unsigned long long nk_vu64x2_t;
|
|
716
|
+
typedef __vector signed char nk_vi8x16_t;
|
|
717
|
+
typedef __vector signed short nk_vi16x8_t;
|
|
718
|
+
typedef __vector signed int nk_vi32x4_t;
|
|
719
|
+
typedef __vector signed long long nk_vi64x2_t;
|
|
720
|
+
typedef __vector float nk_vf32x4_t;
|
|
721
|
+
typedef __vector double nk_vf64x2_t;
|
|
722
|
+
#endif // NK_TARGET_POWERVSX
|
|
723
|
+
|
|
579
724
|
/** Copy 16 bits (2 bytes) from source to destination */
|
|
580
725
|
#if defined(__GNUC__) || defined(__clang__)
|
|
581
726
|
#define nk_copy_bytes_(destination_ptr, source_ptr, count) __builtin_memcpy((destination_ptr), (source_ptr), count)
|
|
@@ -632,10 +777,16 @@ typedef unsigned char nk_e4m3_t;
|
|
|
632
777
|
* 122 of 248 finite values (49.2%) fall in [−1, +1]. */
|
|
633
778
|
typedef unsigned char nk_e5m2_t;
|
|
634
779
|
/** @brief 6-bit E2M3 micro-float (OCP MX v1.0): sign(1) + exponent(2) + mantissa(3), bias=1.
|
|
635
|
-
* Range: ±7.5, no infinities or NaN.
|
|
780
|
+
* Stored as 0b00SEEMMM with 2 bits of padding. Range: ±7.5, no infinities or NaN.
|
|
781
|
+
* 64 total codes: 48 normal, 14 subnormal (exp=0, mant≠0), 2 zeros (±0).
|
|
782
|
+
* 18 of 64 values (28.1%) fall in [−1, +1]. Subnormal values span [±0.125, ±0.875].
|
|
783
|
+
* Losslessly promotable to E4M3 by rebiasing exponent +6 (normals) or normalizing (subnormals). */
|
|
636
784
|
typedef unsigned char nk_e2m3_t;
|
|
637
785
|
/** @brief 6-bit E3M2 micro-float (OCP MX v1.0): sign(1) + exponent(3) + mantissa(2), bias=3.
|
|
638
|
-
*
|
|
786
|
+
* Stored as 0b00SEEEMM with 2 bits of padding. Range: ±28, no infinities or NaN.
|
|
787
|
+
* 64 total codes: 56 normal, 6 subnormal (exp=0, mant≠0), 2 zeros (±0).
|
|
788
|
+
* 26 of 64 values (40.6%) fall in [−1, +1]. Subnormal values span [±0.0625, ±0.1875].
|
|
789
|
+
* Losslessly promotable to E5M2 by rebiasing exponent +12 (normals) or normalizing (subnormals). */
|
|
639
790
|
typedef unsigned char nk_e3m2_t;
|
|
640
791
|
|
|
641
792
|
/** @brief Signed 8-bit integer. Range: [−128, +127]. */
|
|
@@ -670,7 +821,7 @@ typedef float nk_f32_t;
|
|
|
670
821
|
/** @brief Double-precision (64-bit) IEEE 754 float. sign(1) + exponent(11) + mantissa(52), bias=1023. */
|
|
671
822
|
typedef double nk_f64_t;
|
|
672
823
|
|
|
673
|
-
#if NK_TARGET_X86_ || NK_TARGET_ARM_ || NK_TARGET_RISCV_
|
|
824
|
+
#if NK_TARGET_X86_ || NK_TARGET_ARM_ || NK_TARGET_RISCV_ || NK_TARGET_POWER_ || NK_TARGET_LOONGARCH_
|
|
674
825
|
#define NK_IS_64BIT_ 1
|
|
675
826
|
#else
|
|
676
827
|
#define NK_IS_64BIT_ 0
|
|
@@ -712,11 +863,17 @@ typedef nk_f64_t nk_fmax_t;
|
|
|
712
863
|
#define NK_U8_MAX 255U
|
|
713
864
|
#define NK_U8_MIN 0x0U
|
|
714
865
|
|
|
715
|
-
#define
|
|
716
|
-
#define
|
|
866
|
+
#define NK_F16_MAX_AS_U16 0x7BFF // IEEE 754 binary16: +65504.0
|
|
867
|
+
#define NK_F16_MIN_AS_U16 0xFBFF // IEEE 754 binary16: -65504.0
|
|
717
868
|
|
|
718
|
-
#define
|
|
719
|
-
#define
|
|
869
|
+
#define NK_F16_MAX nk_u16_as_f16_(0x7BFF)
|
|
870
|
+
#define NK_F16_MIN nk_u16_as_f16_(0xFBFF)
|
|
871
|
+
|
|
872
|
+
#define NK_BF16_MAX_AS_U16 0x7F7F // BFloat16: ~+3.39e38
|
|
873
|
+
#define NK_BF16_MIN_AS_U16 0xFF7F // BFloat16: ~-3.39e38
|
|
874
|
+
|
|
875
|
+
#define NK_BF16_MAX nk_u16_as_bf16_(0x7F7F)
|
|
876
|
+
#define NK_BF16_MIN nk_u16_as_bf16_(0xFF7F)
|
|
720
877
|
|
|
721
878
|
#define NK_E4M3_MAX 0x7E // FP8 E4M3: +448.0
|
|
722
879
|
#define NK_E4M3_MIN 0xFE // FP8 E4M3: -448.0
|
|
@@ -842,7 +999,7 @@ NK_PUBLIC nk_size_t nk_dtype_bits(nk_dtype_t dtype) {
|
|
|
842
999
|
/** @brief Returns how many logical dimensions are packed into one storage value.
|
|
843
1000
|
* For sub-byte types multiple dimensions share a single byte container.
|
|
844
1001
|
* For byte-or-larger types this is always 1. */
|
|
845
|
-
NK_PUBLIC nk_size_t
|
|
1002
|
+
NK_PUBLIC nk_size_t nk_dimensions_per_value(nk_dtype_t dtype) {
|
|
846
1003
|
switch (dtype) {
|
|
847
1004
|
case nk_u1_k: return 8;
|
|
848
1005
|
case nk_i4_k: return 2;
|
|
@@ -975,7 +1132,7 @@ NK_STATIC_ASSERT(sizeof(nk_bf16_t) == 2, nk_bf16_t_must_be_2_bytes);
|
|
|
975
1132
|
#define nk_assign_from_to_(src, dest) (*(dest) = *(src))
|
|
976
1133
|
|
|
977
1134
|
/** @brief 16-bit union for f16/bf16/u16/i16 bit manipulation. */
|
|
978
|
-
typedef union {
|
|
1135
|
+
typedef union NK_MAY_ALIAS_ {
|
|
979
1136
|
nk_u16_t u;
|
|
980
1137
|
nk_i16_t i;
|
|
981
1138
|
nk_f16_t f;
|
|
@@ -983,14 +1140,14 @@ typedef union {
|
|
|
983
1140
|
} nk_fui16_t;
|
|
984
1141
|
|
|
985
1142
|
/** @brief 32-bit union for f32/u32/i32 bit manipulation. */
|
|
986
|
-
typedef union {
|
|
1143
|
+
typedef union NK_MAY_ALIAS_ {
|
|
987
1144
|
nk_u32_t u;
|
|
988
1145
|
nk_i32_t i;
|
|
989
1146
|
nk_f32_t f;
|
|
990
1147
|
} nk_fui32_t;
|
|
991
1148
|
|
|
992
1149
|
/** @brief 64-bit union for f64/u64/i64 bit manipulation. */
|
|
993
|
-
typedef union {
|
|
1150
|
+
typedef union NK_MAY_ALIAS_ {
|
|
994
1151
|
nk_u64_t u;
|
|
995
1152
|
nk_i64_t i;
|
|
996
1153
|
nk_f64_t f;
|
|
@@ -1021,7 +1178,7 @@ typedef struct {
|
|
|
1021
1178
|
} nk_f64c_t;
|
|
1022
1179
|
|
|
1023
1180
|
/** @brief Small 4-byte memory slice viewable as different types. */
|
|
1024
|
-
typedef union nk_b32_vec_t {
|
|
1181
|
+
typedef union NK_MAY_ALIAS_ nk_b32_vec_t {
|
|
1025
1182
|
nk_u32_t u32;
|
|
1026
1183
|
nk_i32_t i32;
|
|
1027
1184
|
nk_f32_t f32;
|
|
@@ -1034,7 +1191,7 @@ typedef union nk_b32_vec_t {
|
|
|
1034
1191
|
} nk_b32_vec_t;
|
|
1035
1192
|
|
|
1036
1193
|
/** @brief Small 8-byte memory slice viewable as different types. */
|
|
1037
|
-
typedef union nk_b64_vec_t {
|
|
1194
|
+
typedef union NK_MAY_ALIAS_ nk_b64_vec_t {
|
|
1038
1195
|
#if NK_TARGET_NEON
|
|
1039
1196
|
uint8x8_t u8x8;
|
|
1040
1197
|
uint16x4_t u16x4;
|
|
@@ -1061,8 +1218,8 @@ typedef union nk_b64_vec_t {
|
|
|
1061
1218
|
} nk_b64_vec_t;
|
|
1062
1219
|
|
|
1063
1220
|
/** @brief Small 16-byte memory slice viewable as different types. */
|
|
1064
|
-
typedef union nk_b128_vec_t {
|
|
1065
|
-
#if NK_TARGET_HASWELL
|
|
1221
|
+
typedef union NK_MAY_ALIAS_ nk_b128_vec_t {
|
|
1222
|
+
#if NK_TARGET_HASWELL || NK_TARGET_LOONGSONASX
|
|
1066
1223
|
__m128i xmm;
|
|
1067
1224
|
__m128d xmm_pd;
|
|
1068
1225
|
__m128 xmm_ps;
|
|
@@ -1082,6 +1239,22 @@ typedef union nk_b128_vec_t {
|
|
|
1082
1239
|
float32x4_t f32x4;
|
|
1083
1240
|
float64x2_t f64x2;
|
|
1084
1241
|
#endif
|
|
1242
|
+
#if NK_TARGET_NEONHALF
|
|
1243
|
+
float16x8_t f16x8;
|
|
1244
|
+
#endif
|
|
1245
|
+
#if NK_TARGET_POWERVSX
|
|
1246
|
+
nk_vu8x16_t vu8x16;
|
|
1247
|
+
nk_vu16x8_t vu16x8;
|
|
1248
|
+
nk_vu32x4_t vu32x4;
|
|
1249
|
+
nk_vu64x2_t vu64x2;
|
|
1250
|
+
nk_vi8x16_t vi8x16;
|
|
1251
|
+
nk_vi16x8_t vi16x8;
|
|
1252
|
+
nk_vi32x4_t vi32x4;
|
|
1253
|
+
nk_vi64x2_t vi64x2;
|
|
1254
|
+
nk_vf32x4_t vf32x4;
|
|
1255
|
+
nk_vf64x2_t vf64x2;
|
|
1256
|
+
#endif
|
|
1257
|
+
|
|
1085
1258
|
nk_u8_t u8s[16];
|
|
1086
1259
|
nk_u16_t u16s[8];
|
|
1087
1260
|
nk_u32_t u32s[4];
|
|
@@ -1101,8 +1274,8 @@ typedef union nk_b128_vec_t {
|
|
|
1101
1274
|
} nk_b128_vec_t;
|
|
1102
1275
|
|
|
1103
1276
|
/** @brief Small 32-byte memory slice viewable as different types. */
|
|
1104
|
-
typedef union nk_b256_vec_t {
|
|
1105
|
-
#if NK_TARGET_HASWELL
|
|
1277
|
+
typedef union NK_MAY_ALIAS_ nk_b256_vec_t {
|
|
1278
|
+
#if NK_TARGET_HASWELL || NK_TARGET_LOONGSONASX
|
|
1106
1279
|
__m256i ymm;
|
|
1107
1280
|
__m256d ymm_pd;
|
|
1108
1281
|
__m256 ymm_ps;
|
|
@@ -1123,6 +1296,19 @@ typedef union nk_b256_vec_t {
|
|
|
1123
1296
|
float32x4_t f32x4s[2];
|
|
1124
1297
|
float64x2_t f64x2s[2];
|
|
1125
1298
|
#endif
|
|
1299
|
+
#if NK_TARGET_POWERVSX
|
|
1300
|
+
nk_vu8x16_t vu8x16s[2];
|
|
1301
|
+
nk_vu16x8_t vu16x8s[2];
|
|
1302
|
+
nk_vu32x4_t vu32x4s[2];
|
|
1303
|
+
nk_vu64x2_t vu64x2s[2];
|
|
1304
|
+
nk_vi8x16_t vi8x16s[2];
|
|
1305
|
+
nk_vi16x8_t vi16x8s[2];
|
|
1306
|
+
nk_vi32x4_t vi32x4s[2];
|
|
1307
|
+
nk_vi64x2_t vi64x2s[2];
|
|
1308
|
+
nk_vf32x4_t vf32x4s[2];
|
|
1309
|
+
nk_vf64x2_t vf64x2s[2];
|
|
1310
|
+
#endif
|
|
1311
|
+
|
|
1126
1312
|
nk_u8_t u8s[32];
|
|
1127
1313
|
nk_u16_t u16s[16];
|
|
1128
1314
|
nk_u32_t u32s[8];
|
|
@@ -1148,7 +1334,7 @@ typedef union nk_b256_vec_t {
|
|
|
1148
1334
|
* of this is that the argument of such type is passed to functions using the calling convention of the first
|
|
1149
1335
|
* member of the union, which in our case is a register-based calling convention for SIMD types.
|
|
1150
1336
|
*/
|
|
1151
|
-
typedef union nk_b512_vec_t {
|
|
1337
|
+
typedef union NK_MAY_ALIAS_ nk_b512_vec_t {
|
|
1152
1338
|
#if NK_TARGET_SKYLAKE
|
|
1153
1339
|
__m512i zmm;
|
|
1154
1340
|
__m512d zmm_pd;
|
|
@@ -1353,17 +1539,28 @@ NK_INTERNAL nk_i8_t nk_i4x2_get_(nk_i4x2_t byte_val, int n) {
|
|
|
1353
1539
|
/** @brief Extract bit at position n (0-7) from packed u1x8 byte. */
|
|
1354
1540
|
NK_INTERNAL nk_u8_t nk_u1x8_get_(nk_u1x8_t byte_val, int n) { return (byte_val >> (n & 7)) & 1; }
|
|
1355
1541
|
|
|
1356
|
-
NK_INTERNAL nk_f16_t
|
|
1542
|
+
/** @brief Reinterprets a raw 16-bit pattern as `nk_f16_t` through the `nk_fui16_t` union.
 *  The bits are copied unchanged; no numeric conversion takes place. */
NK_INTERNAL nk_f16_t nk_u16_as_f16_(nk_u16_t bits) {
    nk_fui16_t pun;
    pun.u = bits; // write the integer view
    return pun.f; // read the same storage back as a half-precision value
}
|
|
1361
|
-
NK_INTERNAL
|
|
1547
|
+
/** @brief Returns the raw 16-bit pattern of an `nk_f16_t` through the `nk_fui16_t` union.
 *  The bits are copied unchanged; no numeric conversion takes place. */
NK_INTERNAL nk_u16_t nk_f16_as_u16_(nk_f16_t x) {
    nk_fui16_t pun;
    pun.f = x;    // write the half-precision view
    return pun.u; // read the same storage back as an unsigned integer
}
|
|
1552
|
+
/** @brief Reinterprets a raw 16-bit pattern as `nk_bf16_t` through the `nk_fui16_t` union.
 *  The bits are copied unchanged; no numeric conversion takes place. */
NK_INTERNAL nk_bf16_t nk_u16_as_bf16_(nk_u16_t bits) {
    nk_fui16_t pun;
    pun.u = bits;  // write the integer view
    return pun.bf; // read the same storage back as a bfloat16 value
}
|
|
1366
1557
|
|
|
1558
|
+
NK_INTERNAL void nk_f64_from_i64_(nk_i64_t const *src, nk_f64_t *dest) { *dest = (nk_f64_t)*src; }
|
|
1559
|
+
NK_INTERNAL void nk_f64_from_u64_(nk_u64_t const *src, nk_f64_t *dest) { *dest = (nk_f64_t)*src; }
|
|
1560
|
+
NK_INTERNAL void nk_f32_from_i32_(nk_i32_t const *src, nk_f32_t *dest) { *dest = (nk_f32_t)*src; }
|
|
1561
|
+
NK_INTERNAL void nk_f32_from_u32_(nk_u32_t const *src, nk_f32_t *dest) { *dest = (nk_f32_t)*src; }
|
|
1562
|
+
NK_INTERNAL void nk_f32_from_f64_(nk_f64_t const *src, nk_f32_t *dest) { *dest = (nk_f32_t)*src; }
|
|
1563
|
+
|
|
1367
1564
|
/** @brief E4M3: NaN when (raw & 0x7F) == 0x7F (two NaN values: 0x7F, 0xFF). */
|
|
1368
1565
|
NK_INTERNAL int nk_e4m3_is_nan_(nk_e4m3_t x) { return (x & 0x7F) == 0x7F; }
|
|
1369
1566
|
|
|
@@ -1372,10 +1569,51 @@ NK_INTERNAL int nk_e4m3_is_nan_(nk_e4m3_t x) { return (x & 0x7F) == 0x7F; }
|
|
|
1372
1569
|
NK_INTERNAL int nk_e5m2_is_nan_(nk_e5m2_t x) {
    // Discard the sign bit, then report NaN for every pattern strictly above 0x7C.
    int const without_sign = x & 0x7F;
    return without_sign > 0x7C;
}
|
|
1373
1570
|
|
|
1374
1571
|
/** @brief F16: NaN when (raw & 0x7FFF) > 0x7C00. */
|
|
1375
|
-
NK_INTERNAL int nk_f16_is_nan_(
|
|
1572
|
+
NK_INTERNAL int nk_f16_is_nan_(nk_f16_t x) {
    // View the half-precision value as raw bits through the 16-bit union.
    nk_fui16_t pun;
    pun.f = x;
    // Discard the sign bit; NaN is every pattern strictly above 0x7C00.
    int const without_sign = pun.u & 0x7FFF;
    return without_sign > 0x7C00;
}
|
|
1376
1577
|
|
|
1377
1578
|
/** @brief BF16: NaN when (raw & 0x7FFF) > 0x7F80. */
|
|
1378
|
-
NK_INTERNAL int nk_bf16_is_nan_(
|
|
1579
|
+
NK_INTERNAL int nk_bf16_is_nan_(nk_bf16_t x) {
    // View the bfloat16 value as raw bits through the 16-bit union.
    nk_fui16_t pun;
    pun.bf = x;
    // Discard the sign bit; NaN is every pattern strictly above 0x7F80.
    int const without_sign = pun.u & 0x7FFF;
    return without_sign > 0x7F80;
}
|
|
1584
|
+
|
|
1585
|
+
/* Safe SVE vector-length queries usable from non-streaming context.
|
|
1586
|
+
* On Apple M4 (and other SME-only-SVE cores), SVE instructions like CNTW/CNTH/CNTB
|
|
1587
|
+
* trap with SIGILL outside streaming mode. These helpers bracket the query with
|
|
1588
|
+
* SMSTART SM / SMSTOP SM so the calling function's ABI is unchanged.
|
|
1589
|
+
* Inside `__arm_locally_streaming` functions the plain `svcntXX()` intrinsics are fine.
|
|
1590
|
+
*/
|
|
1591
|
+
#if NK_TARGET_ARM_ && NK_TARGET_SME
/*  Registers invalidated by toggling PSTATE.SM.
 *  Both `SMSTART SM` and `SMSTOP SM` reset the scalable vector registers Z0-Z31, whose low
 *  128 bits are the NEON/FP registers V0-V31, and the predicate registers P0-P15. Without
 *  these clobbers the compiler is free to keep live floating-point or SIMD values, including
 *  the callee-saved D8-D15, in registers that the asm statement silently destroys.
 *  NOTE(review): the transition also resets FPSR, which cannot be expressed as a clobber.
 */
#define NK_SME_STREAMING_TOGGLE_CLOBBERS_                                                             \
    "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",    \
        "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",    \
        "v28", "v29", "v30", "v31", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", \
        "p11", "p12", "p13", "p14", "p15"

/** @brief Streaming SVL byte-element count (SVL/8) via SMSTART SM bracket.
 *  @return Value produced by `CNTB` while streaming mode is enabled. */
NK_INTERNAL nk_size_t nk_sme_cntb_(void) {
    nk_u64_t r;
    __asm__ __volatile__("smstart sm\n\t"
                         "cntb %0\n\t"
                         "smstop sm"
                         : "=r"(r)
                         :
                         : NK_SME_STREAMING_TOGGLE_CLOBBERS_);
    return (nk_size_t)r;
}

/** @brief Streaming SVL half-element count (SVL/16) via SMSTART SM bracket.
 *  @return Value produced by `CNTH` while streaming mode is enabled. */
NK_INTERNAL nk_size_t nk_sme_cnth_(void) {
    nk_u64_t r;
    __asm__ __volatile__("smstart sm\n\t"
                         "cnth %0\n\t"
                         "smstop sm"
                         : "=r"(r)
                         :
                         : NK_SME_STREAMING_TOGGLE_CLOBBERS_);
    return (nk_size_t)r;
}

/** @brief Streaming SVL word-element count (SVL/32) via SMSTART SM bracket.
 *  @return Value produced by `CNTW` while streaming mode is enabled. */
NK_INTERNAL nk_size_t nk_sme_cntw_(void) {
    nk_u64_t r;
    __asm__ __volatile__("smstart sm\n\t"
                         "cntw %0\n\t"
                         "smstop sm"
                         : "=r"(r)
                         :
                         : NK_SME_STREAMING_TOGGLE_CLOBBERS_);
    return (nk_size_t)r;
}

/** @brief Streaming SVL double-element count (SVL/64) via SMSTART SM bracket.
 *  @return Value produced by `CNTD` while streaming mode is enabled. */
NK_INTERNAL nk_size_t nk_sme_cntd_(void) {
    nk_u64_t r;
    __asm__ __volatile__("smstart sm\n\t"
                         "cntd %0\n\t"
                         "smstop sm"
                         : "=r"(r)
                         :
                         : NK_SME_STREAMING_TOGGLE_CLOBBERS_);
    return (nk_size_t)r;
}

#undef NK_SME_STREAMING_TOGGLE_CLOBBERS_
#endif // NK_TARGET_ARM_ && NK_TARGET_SME
|
|
1379
1617
|
|
|
1380
1618
|
#ifdef __cplusplus
|
|
1381
1619
|
} // extern "C"
|