numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1384 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Shared definitions for the NumKong library.
|
|
3
|
+
* @file include/numkong/types.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date October 2, 2023
|
|
6
|
+
*
|
|
7
|
+
* Defines:
|
|
8
|
+
*
|
|
9
|
+
* - Sized aliases for numeric types, like: `nk_i32_t` and `nk_f64_t`.
|
|
10
|
+
* - Macros for internal compiler/hardware checks, like: `NK_TARGET_ARM_`.
|
|
11
|
+
* - Macros for feature controls, like: `NK_TARGET_NEON`
|
|
12
|
+
*
|
|
13
|
+
* @section fp8_types FP8 Numeric Types
|
|
14
|
+
*
|
|
15
|
+
* There are several variants of 8-bit floating point types supported by different industry members
|
|
16
|
+
* with different hardware support. None are part of the IEEE 754 standard, but some are part of the
|
|
17
|
+
* Open Compute Project (OCP) 8-bit Floating Point Specification (OFP8):
|
|
18
|
+
*
|
|
19
|
+
* Format Bias Sign Exp Mant Range Infinity NaN Standard
|
|
20
|
+
* E4M3FN 7 1 4 3 ±448 ❌ No Only 0x7F/0xFF OCP, NVIDIA, ONNX
|
|
21
|
+
* E5M2 15 1 5 2 ±57344 ✅ Yes (0x7C/0xFC) 0x7D-7F, 0xFD-FF OCP, IEEE-like
|
|
22
|
+
* E4M3FNUZ 8 1 4 3 ±240 ❌ No 0x80 only GraphCore, ONNX
|
|
23
|
+
* E5M2FNUZ 16 1 5 2 ±57344 ❌ No 0x80 only GraphCore, ONNX
|
|
24
|
+
*
|
|
25
|
+
* In currently available and upcoming hardware, only two product lines prioritize FNUZ over OCP:
|
|
26
|
+
*
|
|
27
|
+
* - GraphCore IPUs were the original platform proposing FNUZ
|
|
28
|
+
* - AMD MI300 series based on CDNA3 implements FNUZ, but not OCP
|
|
29
|
+
* - AMD MI350+ series based on CDNA4 switch to OCP and remove FNUZ
|
|
30
|
+
* - NVIDIA Hopper and Blackwell only support E4M3FN, E5M2
|
|
31
|
+
* - Intel AVX10.2 defines HF8 (E4M3FN) and BF8 (E5M2) - OCP-aligned
|
|
32
|
+
* - Arm implements E4M3 (meaning E4M3FN) and E5M2 with a shared `__mfp8` type and a `FPMR` format selector
|
|
33
|
+
*
|
|
34
|
+
* For brevity, across NumKong, "E4M3" always refers to "E4M3FN".
|
|
35
|
+
*
|
|
36
|
+
* @see https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
|
|
37
|
+
* @see FP8 Formats for Deep Learning: https://arxiv.org/pdf/2209.05433
|
|
38
|
+
* @see ONNX Float8 Types: https://onnx.ai/onnx/technical/float8.html
|
|
39
|
+
*/
|
|
40
|
+
#ifndef NK_TYPES_H
|
|
41
|
+
#define NK_TYPES_H
|
|
42
|
+
|
|
43
|
+
// On Linux, `_GNU_SOURCE` has to be set before the first system header is included, otherwise
// `syscall` and other GNU extensions stay hidden whenever C extensions are disabled (e.g. `-std=c11`).
#if defined(__linux__)
#if !defined(_GNU_SOURCE)
#define _GNU_SOURCE
#endif // !defined(_GNU_SOURCE)
#endif // defined(__linux__)
|
|
48
|
+
|
|
49
|
+
// Infer the target operating system. The checks are ordered, so at most one of
// `NK_DEFINED_WINDOWS_`, `NK_DEFINED_APPLE_`, `NK_DEFINED_LINUX_`, and `NK_DEFINED_FREEBSD_`
// gets defined (to 1); on any other platform none of them is defined at all.
#if defined(_WIN32) || defined(WIN32) || defined(__WIN32__) || defined(__NT__)
#define NK_DEFINED_WINDOWS_ 1
#elif defined(__MACH__) && defined(__APPLE__)
#define NK_DEFINED_APPLE_ 1
#elif defined(__linux__)
#define NK_DEFINED_LINUX_ 1
#elif defined(__FreeBSD__)
#define NK_DEFINED_FREEBSD_ 1
#endif // target operating system
|
|
59
|
+
|
|
60
|
+
// Annotation for the public API symbols:
//
// - `NK_PUBLIC` is used for functions that are part of the public API.
// - `NK_INTERNAL` is used for internal helper functions with unstable APIs.
// - `NK_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime.
//
// Both `NK_PUBLIC` and `NK_INTERNAL` expand to `static inline`, giving every definition internal
// linkage in each translation unit that includes the header. On GCC/Clang, `NK_PUBLIC` also carries
// `unused` to silence `-Wunused-function` for kernels a translation unit never calls, and
// `NK_INTERNAL` carries `always_inline` to force helpers into their callers.
//
// The storage class is spelled first on purpose: a `static` that follows `inline` is an obsolescent
// form (C11 6.11.5) and GCC reports it under `-Wold-style-declaration`, which `-Wextra` enables.
// Marking with `pure` and `const` isn't possible as outputting to a pointer is a "side effect".
#if defined(__GNUC__) || defined(__clang__)
#define NK_PUBLIC static inline __attribute__((unused))
#define NK_INTERNAL static inline __attribute__((always_inline))
#else
#define NK_PUBLIC static inline
#define NK_INTERNAL static inline
#endif // defined(__GNUC__) || defined(__clang__)
|
|
75
|
+
|
|
76
|
+
// `NK_DYNAMIC` annotates entry points that a runtime-dispatch build exports from the library:
// `dllexport` on Windows/Cygwin, default symbol visibility on GCC/Clang. Without
// `NK_DYNAMIC_DISPATCH` (or on unknown compilers), those functions fall back to `NK_PUBLIC`.
#if !NK_DYNAMIC_DISPATCH
#define NK_DYNAMIC NK_PUBLIC
#elif defined(_WIN32) || defined(__CYGWIN__)
#define NK_DYNAMIC __declspec(dllexport)
#elif defined(__GNUC__) || defined(__clang__)
#define NK_DYNAMIC __attribute__((visibility("default")))
#else
#define NK_DYNAMIC NK_PUBLIC
#endif // NK_DYNAMIC_DISPATCH
|
|
87
|
+
|
|
88
|
+
// `NK_ALLOW_ISA_REDIRECT` lets SIMD kernels hand small inputs over to the serial implementations.
// It defaults to 1 for production builds; tests and benchmarks may predefine it to 0 so the SIMD
// code paths are exercised even on tiny inputs.
#ifndef NK_ALLOW_ISA_REDIRECT
#define NK_ALLOW_ISA_REDIRECT 1
#endif // NK_ALLOW_ISA_REDIRECT
|
|
94
|
+
|
|
95
|
+
// Architecture flags: NK_TARGET_ARM_, NK_TARGET_X86_, NK_TARGET_RISCV_, NK_TARGET_WASM_.
// Each may be predefined by the build system; otherwise it becomes 1 or 0 based on the
// compiler's predefined macros. Only the 64-bit flavors of Arm, x86, and RISC-V are recognized.

// 64-bit Arm: GCC/Clang `__aarch64__` or MSVC `_M_ARM64`.
#ifndef NK_TARGET_ARM_
#if !defined(__aarch64__) && !defined(_M_ARM64)
#define NK_TARGET_ARM_ 0
#else
#define NK_TARGET_ARM_ 1
#endif
#endif // NK_TARGET_ARM_

// 64-bit x86: GCC/Clang `__x86_64__` or MSVC `_M_X64`.
#ifndef NK_TARGET_X86_
#if !defined(__x86_64__) && !defined(_M_X64)
#define NK_TARGET_X86_ 0
#else
#define NK_TARGET_X86_ 1
#endif
#endif // NK_TARGET_X86_

// 64-bit RISC-V: `__riscv` together with a 64-bit `__riscv_xlen`.
#ifndef NK_TARGET_RISCV_
#if !defined(__riscv) || (__riscv_xlen != 64)
#define NK_TARGET_RISCV_ 0
#else
#define NK_TARGET_RISCV_ 1
#endif
#endif // NK_TARGET_RISCV_

// WebAssembly: `__wasm__`, or any Emscripten build.
#ifndef NK_TARGET_WASM_
#if !defined(__wasm__) && !defined(__EMSCRIPTEN__)
#define NK_TARGET_WASM_ 0
#else
#define NK_TARGET_WASM_ 1
#endif
#endif // NK_TARGET_WASM_
|
|
130
|
+
|
|
131
|
+
// `NK_DEFINED_WASI_` selects WASI "hosted" mode. Configuring CMake with `NK_WASI_HOSTED=ON`
// predefines it to 1, and the library then imports its capability probes (`nk_has_v128`,
// `nk_has_relaxed`) from the host. Standalone runtimes such as Wasmer or the Wasmtime CLI cannot
// provide those imports, so plain `__wasi__` builds default to 0 and use compile-time detection.
#ifndef NK_DEFINED_WASI_
#define NK_DEFINED_WASI_ 0
#endif // NK_DEFINED_WASI_
|
|
139
|
+
|
|
140
|
+
// WebAssembly Relaxed SIMD: NK_TARGET_V128RELAXED.
// Needs `-mrelaxed-simd` for the FMA instructions (f32x4.relaxed_madd, f64x2.relaxed_madd).
// A user value is kept unless it claims support while `NK_TARGET_WASM_` is 0.
#if !defined(NK_TARGET_V128RELAXED) || (NK_TARGET_V128RELAXED && !NK_TARGET_WASM_)
#if !defined(__wasm_relaxed_simd__)
#undef NK_TARGET_V128RELAXED
#define NK_TARGET_V128RELAXED 0
#else
#define NK_TARGET_V128RELAXED 1
#endif // __wasm_relaxed_simd__
#endif // NK_TARGET_V128RELAXED
|
|
150
|
+
|
|
151
|
+
// RISC-V Vector extension, ratified v1.0 or newer: NK_TARGET_RVV.
// A user value is kept unless it claims support while `NK_TARGET_RISCV_` is 0.
#if !defined(NK_TARGET_RVV) || (NK_TARGET_RVV && !NK_TARGET_RISCV_)
#if !defined(__riscv_v) || (__riscv_v < 1000000)
#undef NK_TARGET_RVV
#define NK_TARGET_RVV 0
#else
#define NK_TARGET_RVV 1
#endif // __riscv_v >= 1.0
#endif // NK_TARGET_RVV
|
|
160
|
+
|
|
161
|
+
// RISC-V Vector sub-extensions: NK_TARGET_RVVHALF, NK_TARGET_RVVBF16, NK_TARGET_RVVBB.
// Each one builds on the base `NK_TARGET_RVV` target. When RVV is off, whether undetected or
// explicitly set to 0 by the user, the sub-extension is forced to 0 as well, even if the compiler
// advertises the matching `__riscv_*` macro. A user-supplied 1 is kept only while RVV is enabled.

// Compiling for RISC-V Vector with Zvfh (f16): NK_TARGET_RVVHALF
// Requires GCC 14+ or Clang 18+ for full intrinsic support
#if !defined(NK_TARGET_RVVHALF) || (NK_TARGET_RVVHALF && !NK_TARGET_RVV)
#undef NK_TARGET_RVVHALF
#if NK_TARGET_RVV && defined(__riscv_zvfh) && (__riscv_zvfh > 0)
#define NK_TARGET_RVVHALF 1
#else
#define NK_TARGET_RVVHALF 0
#endif // NK_TARGET_RVV && __riscv_zvfh
#endif // !defined(NK_TARGET_RVVHALF) || ...

// Compiling for RISC-V Vector with Zvfbfwma (bf16 widening FMA): NK_TARGET_RVVBF16
// Requires GCC 14+ or Clang 18+ for full intrinsic support
#if !defined(NK_TARGET_RVVBF16) || (NK_TARGET_RVVBF16 && !NK_TARGET_RVV)
#undef NK_TARGET_RVVBF16
#if NK_TARGET_RVV && defined(__riscv_zvfbfwma) && (__riscv_zvfbfwma > 0)
#define NK_TARGET_RVVBF16 1
#else
#define NK_TARGET_RVVBF16 0
#endif // NK_TARGET_RVV && __riscv_zvfbfwma
#endif // !defined(NK_TARGET_RVVBF16) || ...

// Compiling for RISC-V Vector with Zvbb (basic bit-manipulation): NK_TARGET_RVVBB
// Provides vcpop.v (per-element popcount), vclz.v, vctz.v, vbrev.v, vrol.v, vror.v
#if !defined(NK_TARGET_RVVBB) || (NK_TARGET_RVVBB && !NK_TARGET_RVV)
#undef NK_TARGET_RVVBB
#if NK_TARGET_RVV && defined(__riscv_zvbb) && (__riscv_zvbb > 0)
#define NK_TARGET_RVVBB 1
#else
#define NK_TARGET_RVVBB 0
#endif // NK_TARGET_RVV && __riscv_zvbb
#endif // !defined(NK_TARGET_RVVBB) || ...
|
|
193
|
+
|
|
194
|
+
// Arm NEON family: NK_TARGET_NEON, NK_TARGET_NEONSDOT, NK_TARGET_NEONHALF,
// NK_TARGET_NEONFHM (FEAT_FHM - FMLAL/FMLSL widening ops), and NK_TARGET_NEONBFDOT.
// Every member is switched on by `__ARM_NEON` alone; the finer-grained extensions are not
// probed here. A user value is kept unless it claims support while `NK_TARGET_ARM_` is 0,
// in which case it is re-derived from `__ARM_NEON`.

#if !defined(NK_TARGET_NEON) || (NK_TARGET_NEON && !NK_TARGET_ARM_)
#if !defined(__ARM_NEON)
#undef NK_TARGET_NEON
#define NK_TARGET_NEON 0
#else
#define NK_TARGET_NEON 1
#endif
#endif // NK_TARGET_NEON

#if !defined(NK_TARGET_NEONSDOT) || (NK_TARGET_NEONSDOT && !NK_TARGET_ARM_)
#if !defined(__ARM_NEON)
#undef NK_TARGET_NEONSDOT
#define NK_TARGET_NEONSDOT 0
#else
#define NK_TARGET_NEONSDOT 1
#endif
#endif // NK_TARGET_NEONSDOT

#if !defined(NK_TARGET_NEONHALF) || (NK_TARGET_NEONHALF && !NK_TARGET_ARM_)
#if !defined(__ARM_NEON)
#undef NK_TARGET_NEONHALF
#define NK_TARGET_NEONHALF 0
#else
#define NK_TARGET_NEONHALF 1
#endif
#endif // NK_TARGET_NEONHALF

#if !defined(NK_TARGET_NEONFHM) || (NK_TARGET_NEONFHM && !NK_TARGET_ARM_)
#if !defined(__ARM_NEON)
#undef NK_TARGET_NEONFHM
#define NK_TARGET_NEONFHM 0
#else
#define NK_TARGET_NEONFHM 1
#endif
#endif // NK_TARGET_NEONFHM

#if !defined(NK_TARGET_NEONBFDOT) || (NK_TARGET_NEONBFDOT && !NK_TARGET_ARM_)
#if !defined(__ARM_NEON)
#undef NK_TARGET_NEONBFDOT
#define NK_TARGET_NEONBFDOT 0
#else
#define NK_TARGET_NEONBFDOT 1
#endif
#endif // NK_TARGET_NEONBFDOT
|
|
243
|
+
|
|
244
|
+
// Arm SVE family. NK_TARGET_SVE, NK_TARGET_SVESDOT, NK_TARGET_SVEHALF, and NK_TARGET_SVEBFDOT
// follow `__ARM_FEATURE_SVE`; NK_TARGET_SVE2 follows `__ARM_FEATURE_SVE2`; NK_TARGET_SVE2P1 is
// never auto-detected and falls back to 0. A user value is kept unless it claims support while
// `NK_TARGET_ARM_` is 0, in which case it is re-derived as described above.

#if !defined(NK_TARGET_SVE) || (NK_TARGET_SVE && !NK_TARGET_ARM_)
#if !defined(__ARM_FEATURE_SVE)
#undef NK_TARGET_SVE
#define NK_TARGET_SVE 0
#else
#define NK_TARGET_SVE 1
#endif
#endif // NK_TARGET_SVE

#if !defined(NK_TARGET_SVESDOT) || (NK_TARGET_SVESDOT && !NK_TARGET_ARM_)
#if !defined(__ARM_FEATURE_SVE)
#undef NK_TARGET_SVESDOT
#define NK_TARGET_SVESDOT 0
#else
#define NK_TARGET_SVESDOT 1
#endif
#endif // NK_TARGET_SVESDOT

#if !defined(NK_TARGET_SVEHALF) || (NK_TARGET_SVEHALF && !NK_TARGET_ARM_)
#if !defined(__ARM_FEATURE_SVE)
#undef NK_TARGET_SVEHALF
#define NK_TARGET_SVEHALF 0
#else
#define NK_TARGET_SVEHALF 1
#endif
#endif // NK_TARGET_SVEHALF

#if !defined(NK_TARGET_SVEBFDOT) || (NK_TARGET_SVEBFDOT && !NK_TARGET_ARM_)
#if !defined(__ARM_FEATURE_SVE)
#undef NK_TARGET_SVEBFDOT
#define NK_TARGET_SVEBFDOT 0
#else
#define NK_TARGET_SVEBFDOT 1
#endif
#endif // NK_TARGET_SVEBFDOT

#if !defined(NK_TARGET_SVE2) || (NK_TARGET_SVE2 && !NK_TARGET_ARM_)
#if !defined(__ARM_FEATURE_SVE2)
#undef NK_TARGET_SVE2
#define NK_TARGET_SVE2 0
#else
#define NK_TARGET_SVE2 1
#endif
#endif // NK_TARGET_SVE2

#if !defined(NK_TARGET_SVE2P1) || (NK_TARGET_SVE2P1 && !NK_TARGET_ARM_)
#undef NK_TARGET_SVE2P1
#define NK_TARGET_SVE2P1 0
#endif // NK_TARGET_SVE2P1
|
|
299
|
+
|
|
300
|
+
// Arm Scalable Matrix Extension: NK_TARGET_SME follows `__ARM_FEATURE_SME`, NK_TARGET_SME2
// follows `__ARM_FEATURE_SME2`, and NK_TARGET_SME2P1 is never auto-detected (falls back to 0).
// A user value is kept unless it claims support while `NK_TARGET_ARM_` is 0.

#if !defined(NK_TARGET_SME) || (NK_TARGET_SME && !NK_TARGET_ARM_)
#if !defined(__ARM_FEATURE_SME)
#undef NK_TARGET_SME
#define NK_TARGET_SME 0
#else
#define NK_TARGET_SME 1
#endif
#endif // NK_TARGET_SME

#if !defined(NK_TARGET_SME2) || (NK_TARGET_SME2 && !NK_TARGET_ARM_)
#if !defined(__ARM_FEATURE_SME2)
#undef NK_TARGET_SME2
#define NK_TARGET_SME2 0
#else
#define NK_TARGET_SME2 1
#endif
#endif // NK_TARGET_SME2

#if !defined(NK_TARGET_SME2P1) || (NK_TARGET_SME2P1 && !NK_TARGET_ARM_)
#undef NK_TARGET_SME2P1
#define NK_TARGET_SME2P1 0
#endif // NK_TARGET_SME2P1
|
|
323
|
+
|
|
324
|
+
// AppleClang 17 exposes SME sub-features through `arm_sme.h` builtin aliases,
// not dedicated `__ARM_FEATURE_*` predefines for every matrix subtype,
// so the sub-features below are detected by probing for the builtins directly.
//
// `__has_builtin(...)` must never reach an `#if` on a compiler that lacks the operator
// (e.g. GCC < 10): `defined(__has_builtin) && __has_builtin(x)` is still parsed in full, and the
// leftover `__has_builtin(x)` becomes `0(0)`, which is a hard preprocessor error. The local
// `NK_HAS_BUILTIN_` wrapper degrades to 0 instead and is undefined again at the end of this section.
#if defined(__has_builtin)
#define NK_HAS_BUILTIN_(x) __has_builtin(x)
#else
#define NK_HAS_BUILTIN_(x) 0
#endif

// FP64 outer products: NK_TARGET_SMEF64
#if !defined(NK_TARGET_SMEF64) || (NK_TARGET_SMEF64 && !NK_TARGET_ARM_)
#undef NK_TARGET_SMEF64
#if NK_HAS_BUILTIN_(__builtin_sme_svmopa_za64_f64_m)
#define NK_TARGET_SMEF64 1
#else
#define NK_TARGET_SMEF64 0
#endif // NK_HAS_BUILTIN_(__builtin_sme_svmopa_za64_f64_m)
#endif // !defined(NK_TARGET_SMEF64) || ...

// Binary (bitwise) outer products: NK_TARGET_SMEBI32
#if !defined(NK_TARGET_SMEBI32) || (NK_TARGET_SMEBI32 && !NK_TARGET_ARM_)
#undef NK_TARGET_SMEBI32
#if NK_HAS_BUILTIN_(__builtin_sme_svbmopa_za32_u32_m)
#define NK_TARGET_SMEBI32 1
#else
#define NK_TARGET_SMEBI32 0
#endif // NK_HAS_BUILTIN_(__builtin_sme_svbmopa_za32_u32_m)
#endif // !defined(NK_TARGET_SMEBI32) || ...

// FP16 outer products into 32-bit tiles: NK_TARGET_SMEHALF
#if !defined(NK_TARGET_SMEHALF) || (NK_TARGET_SMEHALF && !NK_TARGET_ARM_)
#undef NK_TARGET_SMEHALF
#if NK_HAS_BUILTIN_(__builtin_sme_svmopa_za32_f16_m)
#define NK_TARGET_SMEHALF 1
#else
#define NK_TARGET_SMEHALF 0
#endif // NK_HAS_BUILTIN_(__builtin_sme_svmopa_za32_f16_m)
#endif // !defined(NK_TARGET_SMEHALF) || ...

// BF16 outer products into 32-bit tiles: NK_TARGET_SMEBF16
#if !defined(NK_TARGET_SMEBF16) || (NK_TARGET_SMEBF16 && !NK_TARGET_ARM_)
#undef NK_TARGET_SMEBF16
#if NK_HAS_BUILTIN_(__builtin_sme_svmopa_za32_bf16_m)
#define NK_TARGET_SMEBF16 1
#else
#define NK_TARGET_SMEBF16 0
#endif // NK_HAS_BUILTIN_(__builtin_sme_svmopa_za32_bf16_m)
#endif // !defined(NK_TARGET_SMEBF16) || ...

// 2-bit lookup-table instructions: NK_TARGET_SMELUT2
#if !defined(NK_TARGET_SMELUT2) || (NK_TARGET_SMELUT2 && !NK_TARGET_ARM_)
#undef NK_TARGET_SMELUT2
#if NK_HAS_BUILTIN_(__builtin_sme_svluti2_lane_zt_u8)
#define NK_TARGET_SMELUT2 1
#else
#define NK_TARGET_SMELUT2 0
#endif // NK_HAS_BUILTIN_(__builtin_sme_svluti2_lane_zt_u8)
#endif // !defined(NK_TARGET_SMELUT2) || ...

// NK_TARGET_SMEFA64 is never auto-detected and falls back to 0.
#if !defined(NK_TARGET_SMEFA64) || (NK_TARGET_SMEFA64 && !NK_TARGET_ARM_)
#undef NK_TARGET_SMEFA64
#define NK_TARGET_SMEFA64 0
#endif // !defined(NK_TARGET_SMEFA64) || ...

#undef NK_HAS_BUILTIN_
|
|
375
|
+
|
|
376
|
+
// Compiling for x86: NK_TARGET_HASWELL
//
// Starting with Ivy Bridge, Intel supports the `F16C` extensions for fast half-precision
// to single-precision floating-point conversions. On AMD those instructions
// are supported starting with the Piledriver (2012) and Jaguar (2013) cores.
// Starting with Sandy Bridge, Intel adds basic AVX support in their CPUs and in 2013
// extends it with AVX2 in the Haswell generation. Moreover, Haswell adds FMA support.
//
// On MSVC, most GCC-style ISA macros are unavailable. MSVC defines __AVX__, __AVX2__,
// __AVX512F/BW/CD/DQ/VL__, and __AVX10_VER__, but NOT __AVXVNNI__, __AVX512VNNI__,
// __AVX512BF16__, __AVX512FP16__, __AMX_*__, etc.
// Instead, MSVC makes all intrinsics available once the toolset version supports them,
// without requiring `/arch:AVX512`. We gate on _MSC_VER to auto-enable targets:
// - _MSC_VER >= 1900 (VS 2015+): AVX2/FMA/F16C (Haswell)
// - _MSC_VER >= 1920 (VS 2019+): AVX-512 base (Skylake, Icelake), AVX-VNNI (Alder)
// - _MSC_VER >= 1944 (VS 2022 17.14+): BF16, FP16, VP2INTERSECT, VNNI-INT8 (Sierra), AMX
// Every `_MSC_VER` shortcut is additionally gated on `NK_TARGET_X86_`: MSVC also targets Arm64,
// and there no x86 backend may be switched on merely because of the toolset version.
#if !defined(NK_TARGET_HASWELL) || (NK_TARGET_HASWELL && !NK_TARGET_X86_)
#if (defined(__AVX2__) && defined(__FMA__) && defined(__F16C__)) || \
    (defined(_MSC_VER) && _MSC_VER >= 1900 && NK_TARGET_X86_)
#define NK_TARGET_HASWELL 1
#else
#undef NK_TARGET_HASWELL
#define NK_TARGET_HASWELL 0
#endif // defined(__AVX2__)
#endif // !defined(NK_TARGET_HASWELL) || ...

// Compiling for x86: NK_TARGET_SKYLAKE, NK_TARGET_ICELAKE, NK_TARGET_GENOA,
// NK_TARGET_SAPPHIRE, NK_TARGET_SAPPHIREAMX, NK_TARGET_GRANITEAMX, NK_TARGET_TURIN,
// NK_TARGET_ALDER, NK_TARGET_SIERRA
//
// To list all available macros for x86, take a recent compiler, like GCC 12 and run:
// gcc-12 -march=sapphirerapids -dM -E - < /dev/null | egrep "SSE|AVX" | sort
// On Arm machines you may want to check for other flags:
// gcc-12 -march=native -dM -E - < /dev/null | egrep "NEON|SVE|FP16|FMA" | sort
#if !defined(NK_TARGET_SKYLAKE) || (NK_TARGET_SKYLAKE && !NK_TARGET_X86_)
#if (defined(__AVX512F__) && defined(__AVX512CD__) && defined(__AVX512VL__) && defined(__AVX512DQ__) && \
     defined(__AVX512BW__)) || \
    (defined(_MSC_VER) && _MSC_VER >= 1920 && NK_TARGET_X86_)
#define NK_TARGET_SKYLAKE 1
#else
#undef NK_TARGET_SKYLAKE
#define NK_TARGET_SKYLAKE 0
#endif
#endif // !defined(NK_TARGET_SKYLAKE) || ...

#if !defined(NK_TARGET_ICELAKE) || (NK_TARGET_ICELAKE && !NK_TARGET_X86_)
#if (defined(__AVX512VNNI__) && defined(__AVX512IFMA__) && defined(__AVX512BITALG__) && defined(__AVX512VBMI__) && \
     defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__)) || \
    (defined(_MSC_VER) && _MSC_VER >= 1920 && NK_TARGET_X86_)
#define NK_TARGET_ICELAKE 1
#else
#undef NK_TARGET_ICELAKE
#define NK_TARGET_ICELAKE 0
#endif
#endif // !defined(NK_TARGET_ICELAKE) || ...

#if !defined(NK_TARGET_GENOA) || (NK_TARGET_GENOA && !NK_TARGET_X86_)
#if defined(__AVX512BF16__) || (defined(_MSC_VER) && _MSC_VER >= 1944 && NK_TARGET_X86_)
#define NK_TARGET_GENOA 1
#else
#undef NK_TARGET_GENOA
#define NK_TARGET_GENOA 0
#endif
#endif // !defined(NK_TARGET_GENOA) || ...

#if !defined(NK_TARGET_SAPPHIRE) || (NK_TARGET_SAPPHIRE && !NK_TARGET_X86_)
#if defined(__AVX512FP16__) || (defined(_MSC_VER) && _MSC_VER >= 1944 && NK_TARGET_X86_)
#define NK_TARGET_SAPPHIRE 1
#else
#undef NK_TARGET_SAPPHIRE
#define NK_TARGET_SAPPHIRE 0
#endif
#endif // !defined(NK_TARGET_SAPPHIRE) || ...

#if !defined(NK_TARGET_SAPPHIREAMX) || (NK_TARGET_SAPPHIREAMX && !NK_TARGET_X86_)
#if (defined(__AMX_TILE__) && defined(__AMX_BF16__) && defined(__AMX_INT8__)) || \
    (defined(_MSC_VER) && _MSC_VER >= 1944 && NK_TARGET_X86_)
#define NK_TARGET_SAPPHIREAMX 1
#else
#undef NK_TARGET_SAPPHIREAMX
#define NK_TARGET_SAPPHIREAMX 0
#endif
#endif // !defined(NK_TARGET_SAPPHIREAMX) || ...

#if !defined(NK_TARGET_GRANITEAMX) || (NK_TARGET_GRANITEAMX && !NK_TARGET_X86_)
#if (defined(__AMX_TILE__) && defined(__AMX_FP16__)) || (defined(_MSC_VER) && _MSC_VER >= 1944 && NK_TARGET_X86_)
#define NK_TARGET_GRANITEAMX 1
#else
#undef NK_TARGET_GRANITEAMX
#define NK_TARGET_GRANITEAMX 0
#endif
#endif // !defined(NK_TARGET_GRANITEAMX) || ...

#if !defined(NK_TARGET_TURIN) || (NK_TARGET_TURIN && !NK_TARGET_X86_)
#if defined(__AVX512VP2INTERSECT__) || (defined(_MSC_VER) && _MSC_VER >= 1944 && NK_TARGET_X86_)
#define NK_TARGET_TURIN 1
#else
#undef NK_TARGET_TURIN
#define NK_TARGET_TURIN 0
#endif
#endif // !defined(NK_TARGET_TURIN) || ...

#if !defined(NK_TARGET_ALDER) || (NK_TARGET_ALDER && !NK_TARGET_X86_)
#if defined(__AVXVNNI__) || (defined(_MSC_VER) && _MSC_VER >= 1920 && NK_TARGET_X86_)
#define NK_TARGET_ALDER 1
#else
#undef NK_TARGET_ALDER
#define NK_TARGET_ALDER 0
#endif
#endif // !defined(NK_TARGET_ALDER) || ...

#if !defined(NK_TARGET_SIERRA) || (NK_TARGET_SIERRA && !NK_TARGET_X86_)
#if defined(__AVXVNNIINT8__) || (defined(_MSC_VER) && _MSC_VER >= 1944 && NK_TARGET_X86_)
#define NK_TARGET_SIERRA 1
#else
#undef NK_TARGET_SIERRA
#define NK_TARGET_SIERRA 0
#endif
#endif // !defined(NK_TARGET_SIERRA) || ...
|
|
492
|
+
|
|
493
|
+
// Pull in the intrinsics headers that match the enabled backends.
// MSVC builds include only the umbrella <intrin.h>; every other toolchain gets per-ISA headers,
// chosen by architecture first (Arm, then x86, then RISC-V, then WebAssembly).
#if defined(_MSC_VER)
#include <intrin.h>
#else // Not MSVC: per-architecture headers
#if NK_TARGET_ARM_
#if NK_TARGET_NEON
#include <arm_neon.h>
#endif // NK_TARGET_NEON
#if NK_TARGET_SVE || NK_TARGET_SVE2
#include <arm_sve.h>
#endif // NK_TARGET_SVE || NK_TARGET_SVE2
#if NK_TARGET_SME || NK_TARGET_SME2 || NK_TARGET_SMEBI32
#include <arm_sme.h>
#endif // NK_TARGET_SME || NK_TARGET_SME2 || NK_TARGET_SMEBI32
#elif NK_TARGET_HASWELL || NK_TARGET_SKYLAKE
#include <immintrin.h>
#elif NK_TARGET_RVV
#include <riscv_vector.h>
#elif NK_TARGET_V128RELAXED
#include <wasm_simd128.h>
#endif // NK_TARGET_ARM_
#endif // defined(_MSC_VER)
|
|
513
|
+
|
|
514
|
+
// Per-precision epsilons used around divisions; predefine any of them to override the default.
#ifndef NK_F64_DIVISION_EPSILON
#define NK_F64_DIVISION_EPSILON (1e-15)
#endif // NK_F64_DIVISION_EPSILON

#ifndef NK_F32_DIVISION_EPSILON
#define NK_F32_DIVISION_EPSILON (1e-7)
#endif // NK_F32_DIVISION_EPSILON

#ifndef NK_F16_DIVISION_EPSILON
#define NK_F16_DIVISION_EPSILON (1e-3)
#endif // NK_F16_DIVISION_EPSILON

/**
 *  @brief Compile-time capacity of `nk_tensor_position_t`, i.e. the highest supported tensor rank.
 *  Defaults to 64 to match `PyBUF_MAX_NDIM`; predefine it to override.
 */
#ifndef NK_TENSOR_MAX_RANK
#define NK_TENSOR_MAX_RANK (64)
#endif // NK_TENSOR_MAX_RANK
|
|
533
|
+
|
|
534
|
+
/**
 *  @brief Requests 64-byte alignment for a variable through compiler extensions.
 *  `alignas(64)` only exists from C11 / C++11 onwards, so C99 builds rely on these instead.
 *  Used internally and recommended for external users.
 */
#ifdef _MSC_VER
#define NK_ALIGN64 __declspec(align(64))
#elif defined(__GNUC__) || defined(__clang__)
#define NK_ALIGN64 __attribute__((aligned(64)))
#endif // _MSC_VER
|
|
544
|
+
|
|
545
|
+
/**
 *  Arm streaming-mode function attributes (they need an SME-capable compiler: GCC 14+, Clang 16+).
 *  - NK_STREAMING_ tags functions that must execute in streaming SVE mode (e.g. FCVTLT).
 *  - NK_STREAMING_COMPATIBLE_ tags helpers that may be called in and out of streaming mode.
 *  Unless both NK_TARGET_ARM_ and NK_TARGET_SME are set, the two macros expand to nothing.
 */
#if !(NK_TARGET_ARM_ && NK_TARGET_SME)
#define NK_STREAMING_
#define NK_STREAMING_COMPATIBLE_
#else
#define NK_STREAMING_ __arm_streaming
#define NK_STREAMING_COMPATIBLE_ __arm_streaming_compatible
#endif // !(NK_TARGET_ARM_ && NK_TARGET_SME)
|
|
557
|
+
|
|
558
|
+
/**
 *  @brief Portable conversions between x86 SIMD vector types.
 *  MSVC declares `__m512bh`, `__m512h` and `__m256bh` as aliases of `__m512i` / `__m256i` and rejects
 *  C-style casts between them, so the helpers are identity macros there. GCC and Clang model them
 *  as distinct vector types that require an explicit cast.
 */
#if NK_TARGET_X86_ && defined(_MSC_VER)
#define nk_m512bh_from_m512i_(x) (x)
#define nk_m512h_from_m512i_(x) (x)
#define nk_m512i_from_m512h_(x) (x)
#define nk_m256bh_from_m256i_(x) (x)
#define nk_m256i_from_m256bh_(x) (x)
#elif NK_TARGET_X86_
#define nk_m512bh_from_m512i_(x) ((__m512bh)(x))
#define nk_m512h_from_m512i_(x) ((__m512h)(x))
#define nk_m512i_from_m512h_(x) ((__m512i)(x))
#define nk_m256bh_from_m256i_(x) ((__m256bh)(x))
#define nk_m256i_from_m256bh_(x) ((__m256i)(x))
#endif // NK_TARGET_X86_
|
|
578
|
+
|
|
579
|
+
/**
 *  @brief Copies `count` bytes from `source_ptr` into `destination_ptr`; the ranges must not overlap.
 *  GCC and Clang use `__builtin_memcpy`, every other compiler the `memcpy` from <string.h>.
 *  All arguments are parenthesized so that compound expressions expand safely.
 */
#if defined(__GNUC__) || defined(__clang__)
#define nk_copy_bytes_(destination_ptr, source_ptr, count) __builtin_memcpy((destination_ptr), (source_ptr), (count))
#else
#include <string.h> // `memcpy`
#define nk_copy_bytes_(destination_ptr, source_ptr, count) memcpy((destination_ptr), (source_ptr), (count))
#endif

/** @brief Marks a parameter or variable as intentionally unused (expands to a `(void)` cast). */
#define nk_unused_(x) ((void)(x))
|
|
589
|
+
|
|
590
|
+
/**
 *  @brief Minimum-length annotation for array parameters (C99 `static` array declarator).
 *
 *  Plain C builds (other than MSVC) get `static n`, letting the compiler assume, and possibly warn
 *  about, at least `n` accessible elements. C++ and MSVC do not accept that syntax, so there the
 *  macro expands to nothing.
 *  @see https://lwn.net/Articles/1046840/
 *
 *  Example usage:
 *  @code{.c}
 *  void hash_digest(uint8_t digest[nk_at_least_(32)]);
 *  void lookup(uint8_t const lut[nk_at_least_(256)]);
 *  @endcode
 */
#if !defined(__cplusplus) && !defined(_MSC_VER)
#define nk_at_least_(n) static n
#else
#define nk_at_least_(n)
#endif // !defined(__cplusplus) && !defined(_MSC_VER)
|
|
608
|
+
|
|
609
|
+
#ifdef __cplusplus
|
|
610
|
+
extern "C" {
|
|
611
|
+
#endif
|
|
612
|
+
|
|
613
|
+
/** @brief Packed 8-bit bit-vector (8 booleans in one byte), LSB = dimension 0.
 *  Used for Hamming distance and Jaccard similarity via popcount.
 *  Dimension count must be a multiple of 8; unused bits in the final byte must be zeroed. */
typedef unsigned char nk_u1x8_t;
/** @brief Packed 4-bit signed integer pair (2 × i4 in one byte), [high nibble : low nibble].
 *  Range per element: [−8, +7]. Elements sign-extended to i8 for arithmetic.
 *  Dimension count must be a multiple of 2; unused nibbles in the final byte must be zeroed. */
typedef unsigned char nk_i4x2_t;
/** @brief Packed 4-bit unsigned integer pair (2 × u4 in one byte), [high nibble : low nibble].
 *  Range per element: [0, 15]. Elements zero-extended to u8 for arithmetic.
 *  Dimension count must be a multiple of 2; unused nibbles in the final byte must be zeroed. */
typedef unsigned char nk_u4x2_t;

/** @brief 8-bit E4M3 float (OCP FP8): sign(1) + exponent(4) + mantissa(3), bias=7.
 *  Range: ±448, no infinities (all-ones exponent → NaN at 0x7F/0xFF).
 *  114 of 254 finite values (44.9%) fall in [−1, +1]. */
typedef unsigned char nk_e4m3_t;
/** @brief 8-bit E5M2 float (OCP FP8): sign(1) + exponent(5) + mantissa(2), bias=15.
 *  Range: ±57 344, supports infinities at 0x7C/0xFC.
 *  122 of 248 finite values (49.2%) fall in [−1, +1]. */
typedef unsigned char nk_e5m2_t;
/** @brief 6-bit E2M3 micro-float (OCP MX v1.0): sign(1) + exponent(2) + mantissa(3), bias=1.
 *  Range: ±7.5, no infinities or NaN. Only 64 total codes; 18 (28.1%) fall in [−1, +1].
 *  Stored in the low 6 bits of a byte, with the sign at bit 5. */
typedef unsigned char nk_e2m3_t;
/** @brief 6-bit E3M2 micro-float (OCP MX v1.0): sign(1) + exponent(3) + mantissa(2), bias=3.
 *  Range: ±28, no infinities or NaN: the all-ones exponent still encodes finite values (0x1F = +28).
 *  Only 64 total codes; 26 (40.6%) fall in [−1, +1]. Stored in the low 6 bits of a byte. */
typedef unsigned char nk_e3m2_t;

/** @brief Signed 8-bit integer. Range: [−128, +127]. */
typedef signed char nk_i8_t;
/** @brief Unsigned 8-bit integer. Range: [0, 255]. */
typedef unsigned char nk_u8_t;
/** @brief Signed 16-bit integer. Range: [−32 768, +32 767]. */
typedef signed short nk_i16_t;
/** @brief Unsigned 16-bit integer. Range: [0, 65 535]. */
typedef unsigned short nk_u16_t;
/** @brief Signed 32-bit integer. Range: [−2³¹, +2³¹−1]. */
typedef signed int nk_i32_t;
/** @brief Unsigned 32-bit integer. Range: [0, 2³²−1]. */
typedef unsigned int nk_u32_t;
/* On LP64 targets (Linux ARM64, RISC-V 64), `long` and `long long` are both 64-bit but distinct types.
 * NEON/RVV intrinsics on Linux expect `long*`, while Apple's NEON intrinsics expect `long long*`.
 * Windows uses LLP64 where `long` is 32-bit, so it must use `long long` for 64-bit types. */
#if ((NK_TARGET_ARM_ && !defined(NK_DEFINED_APPLE_)) || NK_TARGET_RISCV_) && !defined(NK_DEFINED_WINDOWS_)
/** @brief Signed 64-bit integer. Range: [−2⁶³, +2⁶³−1]. */
typedef signed long nk_i64_t;
/** @brief Unsigned 64-bit integer. Range: [0, 2⁶⁴−1]. */
typedef unsigned long nk_u64_t;
#else
/** @brief Signed 64-bit integer. Range: [−2⁶³, +2⁶³−1]. */
typedef signed long long nk_i64_t;
/** @brief Unsigned 64-bit integer. Range: [0, 2⁶⁴−1]. */
typedef unsigned long long nk_u64_t;
#endif

/** @brief Single-precision (32-bit) IEEE 754 float. sign(1) + exponent(8) + mantissa(23), bias=127. */
typedef float nk_f32_t;
/** @brief Double-precision (64-bit) IEEE 754 float. sign(1) + exponent(11) + mantissa(52), bias=1023. */
typedef double nk_f64_t;

/* Size types are 64-bit on the x86, Arm and RISC-V targets recognized by this header, 32-bit otherwise. */
#if NK_TARGET_X86_ || NK_TARGET_ARM_ || NK_TARGET_RISCV_
#define NK_IS_64BIT_ 1
#else
#define NK_IS_64BIT_ 0
#endif

#if NK_IS_64BIT_
typedef nk_u64_t nk_size_t;
typedef nk_i64_t nk_ssize_t;
#else
typedef nk_u32_t nk_size_t;
typedef nk_i32_t nk_ssize_t;
#endif
typedef nk_f64_t nk_fmax_t;
|
|
687
|
+
|
|
688
|
+
/** @brief Largest value representable by `nk_size_t` (all bits set). */
#define NK_SIZE_MAX ((nk_size_t)-1)

/* Largest finite magnitudes of IEEE 754 binary64 and binary32; each MIN is the negated MAX. */
#define NK_F64_MAX 1.7976931348623157e+308
#define NK_F64_MIN (-NK_F64_MAX)
#define NK_F32_MAX 3.402823466e+38f
#define NK_F32_MIN (-NK_F32_MAX)

/* Integer limits. Each signed minimum is spelled `-MAX - 1`, because the positive literal
 * `MAX + 1` does not fit in the signed type. */
#define NK_I64_MAX 9223372036854775807LL
#define NK_I64_MIN (-NK_I64_MAX - 1LL)
#define NK_U64_MAX 18446744073709551615ULL
#define NK_U64_MIN 0x0ULL

#define NK_I32_MAX 2147483647
#define NK_I32_MIN (-NK_I32_MAX - 1)
#define NK_U32_MAX 4294967295U
#define NK_U32_MIN 0x0U

#define NK_I16_MAX 32767
#define NK_I16_MIN (-NK_I16_MAX - 1)
#define NK_U16_MAX 65535U
#define NK_U16_MIN 0x0U

#define NK_I8_MAX 127
#define NK_I8_MIN (-NK_I8_MAX - 1)
#define NK_U8_MAX 255U
#define NK_U8_MIN 0x0U

/* Bit patterns of the largest finite magnitudes of the narrow float formats.
 * Every MIN is its MAX with the format's sign bit set. */
#define NK_F16_MAX 0x7BFF // IEEE 754 binary16: +65504.0
#define NK_F16_MIN 0xFBFF // IEEE 754 binary16: -65504.0

#define NK_BF16_MAX 0x7F7F // BFloat16: ~+3.39e38
#define NK_BF16_MIN 0xFF7F // BFloat16: ~-3.39e38

#define NK_E4M3_MAX 0x7E // FP8 E4M3: +448.0
#define NK_E4M3_MIN 0xFE // FP8 E4M3: -448.0

#define NK_E5M2_MAX 0x7B // FP8 E5M2: +57344.0
#define NK_E5M2_MIN 0xFB // FP8 E5M2: -57344.0

#define NK_E2M3_MAX 0x1F // FP6 E2M3: +7.5
#define NK_E2M3_MIN 0x3F // FP6 E2M3: -7.5

#define NK_E3M2_MAX 0x1F // FP6 E3M2: +28.0
#define NK_E3M2_MIN 0x3F // FP6 E3M2: -28.0

#define NK_BITS_PER_BYTE 8
|
|
734
|
+
|
|
735
|
+
/**
 *  @brief Enumeration of supported scalar data types.
 *
 *  Each enumerator occupies its own bit (bit 0 is unused). Complex descriptors map onto their real
 *  counterparts inside C kernels, but keep separate flags so that language bindings can pass the
 *  extra metadata around.
 */
typedef enum {
    nk_dtype_unknown_k = 0, ///< Unknown data type

    /* Packed sub-byte unsigned storage */
    nk_u1_k = 1 << 1, ///< Single-bit values packed into 8-bit words

    /* Signed integers */
    nk_i8_k = 1 << 2,  ///< 8-bit signed integer
    nk_i16_k = 1 << 3, ///< 16-bit signed integer
    nk_i32_k = 1 << 4, ///< 32-bit signed integer
    nk_i64_k = 1 << 5, ///< 64-bit signed integer

    /* Unsigned integers */
    nk_u8_k = 1 << 6,  ///< 8-bit unsigned integer
    nk_u16_k = 1 << 7, ///< 16-bit unsigned integer
    nk_u32_k = 1 << 8, ///< 32-bit unsigned integer
    nk_u64_k = 1 << 9, ///< 64-bit unsigned integer

    /* IEEE 754 and brain floating point */
    nk_f64_k = 1 << 10,  ///< Double precision floating point
    nk_f32_k = 1 << 11,  ///< Single precision floating point
    nk_f16_k = 1 << 12,  ///< Half precision floating point
    nk_bf16_k = 1 << 13, ///< Brain floating point

    /* Narrow formats: FP8, packed nibbles, FP6 */
    nk_e4m3_k = 1 << 14, ///< FP8 E4M3 floating point
    nk_e5m2_k = 1 << 15, ///< FP8 E5M2 floating point
    nk_i4_k = 1 << 16,   ///< 4-bit signed integers packed into 8-bit words
    nk_u4_k = 1 << 17,   ///< 4-bit unsigned integers packed into 8-bit words
    nk_e2m3_k = 1 << 18, ///< FP6 E2M3 floating point
    nk_e3m2_k = 1 << 19, ///< FP6 E3M2 floating point

    /* Complex numbers built from two real components */
    nk_f64c_k = 1 << 20,  ///< Complex double precision floating point
    nk_f32c_k = 1 << 21,  ///< Complex single precision floating point
    nk_f16c_k = 1 << 22,  ///< Complex half precision floating point
    nk_bf16c_k = 1 << 23, ///< Complex brain floating point
} nk_dtype_t;

/** @brief Coarse classification of `nk_dtype_t` values, as returned by `nk_dtype_family`. */
typedef enum {
    nk_dtype_family_unknown_k = 0,
    nk_dtype_family_float_k = 1,
    nk_dtype_family_complex_float_k = 2,
    nk_dtype_family_int_k = 3,
    nk_dtype_family_uint_k = 4,
} nk_dtype_family_t;
|
|
781
|
+
|
|
782
|
+
/** @brief Classifies the family of the dtype. */
|
|
783
|
+
NK_PUBLIC nk_dtype_family_t nk_dtype_family(nk_dtype_t dtype) {
|
|
784
|
+
switch (dtype) {
|
|
785
|
+
case nk_f64_k: return nk_dtype_family_float_k;
|
|
786
|
+
case nk_f32_k: return nk_dtype_family_float_k;
|
|
787
|
+
case nk_f16_k: return nk_dtype_family_float_k;
|
|
788
|
+
case nk_bf16_k: return nk_dtype_family_float_k;
|
|
789
|
+
case nk_e4m3_k: return nk_dtype_family_float_k;
|
|
790
|
+
case nk_e5m2_k: return nk_dtype_family_float_k;
|
|
791
|
+
case nk_e2m3_k: return nk_dtype_family_float_k;
|
|
792
|
+
case nk_e3m2_k: return nk_dtype_family_float_k;
|
|
793
|
+
case nk_f64c_k: return nk_dtype_family_complex_float_k;
|
|
794
|
+
case nk_f32c_k: return nk_dtype_family_complex_float_k;
|
|
795
|
+
case nk_f16c_k: return nk_dtype_family_complex_float_k;
|
|
796
|
+
case nk_bf16c_k: return nk_dtype_family_complex_float_k;
|
|
797
|
+
case nk_u1_k: return nk_dtype_family_uint_k;
|
|
798
|
+
case nk_u4_k: return nk_dtype_family_uint_k;
|
|
799
|
+
case nk_u8_k: return nk_dtype_family_uint_k;
|
|
800
|
+
case nk_u16_k: return nk_dtype_family_uint_k;
|
|
801
|
+
case nk_u32_k: return nk_dtype_family_uint_k;
|
|
802
|
+
case nk_u64_k: return nk_dtype_family_uint_k;
|
|
803
|
+
case nk_i4_k: return nk_dtype_family_int_k;
|
|
804
|
+
case nk_i8_k: return nk_dtype_family_int_k;
|
|
805
|
+
case nk_i16_k: return nk_dtype_family_int_k;
|
|
806
|
+
case nk_i32_k: return nk_dtype_family_int_k;
|
|
807
|
+
case nk_i64_k: return nk_dtype_family_int_k;
|
|
808
|
+
default: return nk_dtype_family_unknown_k;
|
|
809
|
+
}
|
|
810
|
+
}
|
|
811
|
+
|
|
812
|
+
/** @brief Returns the number of bits in a single scalar of a given type. */
|
|
813
|
+
NK_PUBLIC nk_size_t nk_dtype_bits(nk_dtype_t dtype) {
|
|
814
|
+
switch (dtype) {
|
|
815
|
+
case nk_f64_k: return 64;
|
|
816
|
+
case nk_f32_k: return 32;
|
|
817
|
+
case nk_f16_k: return 16;
|
|
818
|
+
case nk_bf16_k: return 16;
|
|
819
|
+
case nk_e4m3_k: return 8;
|
|
820
|
+
case nk_e5m2_k: return 8;
|
|
821
|
+
case nk_e2m3_k: return 8;
|
|
822
|
+
case nk_e3m2_k: return 8;
|
|
823
|
+
case nk_f64c_k: return 128;
|
|
824
|
+
case nk_f32c_k: return 64;
|
|
825
|
+
case nk_f16c_k: return 32;
|
|
826
|
+
case nk_bf16c_k: return 32;
|
|
827
|
+
case nk_u1_k: return 1;
|
|
828
|
+
case nk_u4_k: return 4;
|
|
829
|
+
case nk_u8_k: return 8;
|
|
830
|
+
case nk_u16_k: return 16;
|
|
831
|
+
case nk_u32_k: return 32;
|
|
832
|
+
case nk_u64_k: return 64;
|
|
833
|
+
case nk_i4_k: return 4;
|
|
834
|
+
case nk_i8_k: return 8;
|
|
835
|
+
case nk_i16_k: return 16;
|
|
836
|
+
case nk_i32_k: return 32;
|
|
837
|
+
case nk_i64_k: return 64;
|
|
838
|
+
default: return 0;
|
|
839
|
+
}
|
|
840
|
+
}
|
|
841
|
+
|
|
842
|
+
/** @brief Returns how many logical dimensions are packed into one storage value.
|
|
843
|
+
* For sub-byte types multiple dimensions share a single byte container.
|
|
844
|
+
* For byte-or-larger types this is always 1. */
|
|
845
|
+
NK_PUBLIC nk_size_t nk_dtype_dimensions_per_value(nk_dtype_t dtype) {
|
|
846
|
+
switch (dtype) {
|
|
847
|
+
case nk_u1_k: return 8;
|
|
848
|
+
case nk_i4_k: return 2;
|
|
849
|
+
case nk_u4_k: return 2;
|
|
850
|
+
default: return 1;
|
|
851
|
+
}
|
|
852
|
+
}
|
|
853
|
+
|
|
854
|
+
/** @brief Half-precision (16-bit) IEEE 754 float.
 *
 *  Layout: sign(1) + exponent(5) + mantissa(10), bias=15.
 *  Range: ±65 504, epsilon at 1.0 ≈ 9.77×10⁻⁴. 30 722 of 63 488 finite values (48.4%) in [−1, +1].
 *
 *  The underlying type is chosen at compile time, unless `NK_NATIVE_F16` is predefined to 0:
 *  - GCC or Clang on Arm with `__ARM_FP16_FORMAT_IEEE` defined: `__fp16`, may require `-mfp16-format`.
 *  - GCC or Clang on x86 (`__x86_64__` or `__i386__`) with `__AVX512FP16__` defined: `_Float16`.
 *  - Otherwise: `unsigned short`, holding the raw binary16 bit pattern.
 *  `NK_NATIVE_F16` is left defined as 1 or 0 to report which of those choices was taken.
 */
#if !defined(NK_NATIVE_F16) || NK_NATIVE_F16
#if (defined(__GNUC__) || defined(__clang__)) && (defined(__ARM_ARCH) || defined(__aarch64__)) && \
    (defined(__ARM_FP16_FORMAT_IEEE))
#undef NK_NATIVE_F16
#define NK_NATIVE_F16 1
typedef __fp16 nk_f16_t;
#elif ((defined(__GNUC__) || defined(__clang__)) && (defined(__x86_64__) || defined(__i386__)) && \
       (defined(__AVX512FP16__)))
typedef _Float16 nk_f16_t;
#undef NK_NATIVE_F16
#define NK_NATIVE_F16 1
#else // Unknown compiler or architecture
#undef NK_NATIVE_F16
#define NK_NATIVE_F16 0
#endif // Unknown compiler or architecture
#endif // !NK_NATIVE_F16

#if !NK_NATIVE_F16
typedef unsigned short nk_f16_t;
#endif
|
|
883
|
+
|
|
884
|
+
/** @brief BFloat16 (16-bit) float — truncated IEEE 754 single-precision.
 *
 *  Layout: sign(1) + exponent(8) + mantissa(7), bias=127.
 *  Same dynamic range as f32, epsilon ≈ 7.81×10⁻³.
 *  32 514 of 65 280 finite values (49.8%) in [−1, +1]. Wider range than f16 but lower precision.
 *
 *  The underlying type is chosen at compile time, unless `NK_NATIVE_BF16` is predefined to 0:
 *  - GCC or Clang with `__ARM_BF16_FORMAT_ALTERNATIVE` or `__AVX512BF16__` defined: `__bf16`.
 *  - Otherwise: `unsigned short`.
 *  `NK_NATIVE_BF16` is left defined as 1 or 0 to report which of those choices was taken.
 *
 *  The compilers have added `__bf16` support in compliance with the x86-64 psABI spec.
 *  The motivation for this new special type is summed up as:
 *
 *      Currently `__bfloat16` is a typedef of short, which creates a problem where the
 *      compiler does not raise any alarms if it is used to add, subtract, multiply or
 *      divide, but the result of the calculation is actually meaningless.
 *      To solve this problem, a real scalar type `__Bfloat16` needs to be introduced.
 *      It is mainly used for intrinsics, not available for C standard operators.
 *      `__Bfloat16` will also be used for movement like passing parameter, load and store,
 *      vector initialization, vector shuffle, and etc. It creates a need for a
 *      corresponding psABI.
 *
 *  @warning Apple Clang has a hard time with bf16.
 *  https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
 *  https://forums.developer.apple.com/forums/thread/726201
 *  https://www.phoronix.com/news/GCC-LLVM-bf16-BFloat16-Type
 */
#if !defined(NK_NATIVE_BF16) || NK_NATIVE_BF16
#if (defined(__GNUC__) || defined(__clang__)) && ((defined(__ARM_BF16_FORMAT_ALTERNATIVE)) || (defined(__AVX512BF16__)))
#undef NK_NATIVE_BF16
#define NK_NATIVE_BF16 1
typedef __bf16 nk_bf16_t;
#else // Unknown compiler or architecture
#undef NK_NATIVE_BF16
#define NK_NATIVE_BF16 0
#endif // Unknown compiler or architecture
#endif // !NK_NATIVE_BF16

#if !NK_NATIVE_BF16
typedef unsigned short nk_bf16_t;
#endif
|
|
924
|
+
|
|
925
|
+
/**
 *  @brief Aliases for the half-precision and bfloat16 element types used by Arm SIMD intrinsics.
 *
 *  Clang and GCC bring the `float16_t` and `bfloat16_t` symbols when you compile for AArch64.
 *  MSVC lacks them, and its `vld1_f16`-like intrinsics are in reality macros
 *  that cast to 16-bit integers internally, instead of using floats.
 *  Some of those are defined as aliases, so we use `#define` preprocessor
 *  directives instead of `typedef` to avoid errors.
 */
#if NK_TARGET_ARM_
#if defined(_MSC_VER)
#define nk_f16_for_arm_simd_t nk_f16_t
#define nk_bf16_for_arm_simd_t nk_bf16_t
#else
#define nk_f16_for_arm_simd_t float16_t
#define nk_bf16_for_arm_simd_t bfloat16_t
#endif
#endif

/**
 *  RISC-V Vector (RVV) intrinsics use `_Float16` for half-precision floats.
 *  That type comes from ISO/IEC TS 18661-3 (adopted as C23 Annex H) and is provided by
 *  GCC/Clang when targeting RVV.
 */
#if NK_TARGET_RISCV_
#define nk_f16_for_rvv_intrinsics_t _Float16
#endif
|
|
951
|
+
|
|
952
|
+
/*
|
|
953
|
+
* Let's make sure the sizes of the types are as expected.
|
|
954
|
+
* In C the `_Static_assert` is only available with C11 and later.
|
|
955
|
+
*/
|
|
956
|
+
#define NK_STATIC_ASSERT(cond, msg) typedef char static_assertion_##msg[(cond) ? 1 : -1]
|
|
957
|
+
NK_STATIC_ASSERT(sizeof(nk_u1x8_t) == 1, nk_u1x8_t_must_be_1_byte);
|
|
958
|
+
NK_STATIC_ASSERT(sizeof(nk_i4x2_t) == 1, nk_i4_t_must_be_1_byte);
|
|
959
|
+
NK_STATIC_ASSERT(sizeof(nk_u4x2_t) == 1, nk_u4_t_must_be_1_byte);
|
|
960
|
+
NK_STATIC_ASSERT(sizeof(nk_e4m3_t) == 1, nk_e4m3_t_must_be_1_byte);
|
|
961
|
+
NK_STATIC_ASSERT(sizeof(nk_e5m2_t) == 1, nk_e5m2_t_must_be_1_byte);
|
|
962
|
+
NK_STATIC_ASSERT(sizeof(nk_i8_t) == 1, nk_i8_t_must_be_1_byte);
|
|
963
|
+
NK_STATIC_ASSERT(sizeof(nk_u8_t) == 1, nk_u8_t_must_be_1_byte);
|
|
964
|
+
NK_STATIC_ASSERT(sizeof(nk_i16_t) == 2, nk_i16_t_must_be_2_bytes);
|
|
965
|
+
NK_STATIC_ASSERT(sizeof(nk_u16_t) == 2, nk_u16_t_must_be_2_bytes);
|
|
966
|
+
NK_STATIC_ASSERT(sizeof(nk_i32_t) == 4, nk_i32_t_must_be_4_bytes);
|
|
967
|
+
NK_STATIC_ASSERT(sizeof(nk_u32_t) == 4, nk_u32_t_must_be_4_bytes);
|
|
968
|
+
NK_STATIC_ASSERT(sizeof(nk_i64_t) == 8, nk_i64_t_must_be_8_bytes);
|
|
969
|
+
NK_STATIC_ASSERT(sizeof(nk_u64_t) == 8, nk_u64_t_must_be_8_bytes);
|
|
970
|
+
NK_STATIC_ASSERT(sizeof(nk_f32_t) == 4, nk_f32_t_must_be_4_bytes);
|
|
971
|
+
NK_STATIC_ASSERT(sizeof(nk_f64_t) == 8, nk_f64_t_must_be_8_bytes);
|
|
972
|
+
NK_STATIC_ASSERT(sizeof(nk_f16_t) == 2, nk_f16_t_must_be_2_bytes);
|
|
973
|
+
NK_STATIC_ASSERT(sizeof(nk_bf16_t) == 2, nk_bf16_t_must_be_2_bytes);
|
|
974
|
+
|
|
975
|
+
#define nk_assign_from_to_(src, dest) (*(dest) = *(src))
|
|
976
|
+
|
|
977
|
+
/** @brief 16-bit union for f16/bf16/u16/i16 bit manipulation. */
|
|
978
|
+
typedef union {
|
|
979
|
+
nk_u16_t u;
|
|
980
|
+
nk_i16_t i;
|
|
981
|
+
nk_f16_t f;
|
|
982
|
+
nk_bf16_t bf;
|
|
983
|
+
} nk_fui16_t;
|
|
984
|
+
|
|
985
|
+
/** @brief 32-bit union for f32/u32/i32 bit manipulation. */
|
|
986
|
+
typedef union {
|
|
987
|
+
nk_u32_t u;
|
|
988
|
+
nk_i32_t i;
|
|
989
|
+
nk_f32_t f;
|
|
990
|
+
} nk_fui32_t;
|
|
991
|
+
|
|
992
|
+
/** @brief 64-bit union for f64/u64/i64 bit manipulation. */
|
|
993
|
+
typedef union {
|
|
994
|
+
nk_u64_t u;
|
|
995
|
+
nk_i64_t i;
|
|
996
|
+
nk_f64_t f;
|
|
997
|
+
} nk_fui64_t;
|
|
998
|
+
|
|
999
|
+
/** @brief Half-precision (32-bit) complex number — {real: f16, imag: f16}. Kernel outputs widened to f32c. */
|
|
1000
|
+
typedef struct {
|
|
1001
|
+
nk_f16_t real;
|
|
1002
|
+
nk_f16_t imag;
|
|
1003
|
+
} nk_f16c_t;
|
|
1004
|
+
|
|
1005
|
+
/** @brief BFloat16 (32-bit) complex number — {real: bf16, imag: bf16}. Kernel outputs widened to f32c. */
|
|
1006
|
+
typedef struct {
|
|
1007
|
+
nk_bf16_t real;
|
|
1008
|
+
nk_bf16_t imag;
|
|
1009
|
+
} nk_bf16c_t;
|
|
1010
|
+
|
|
1011
|
+
/** @brief Single-precision (64-bit) complex number — {real: f32, imag: f32}. */
|
|
1012
|
+
typedef struct {
|
|
1013
|
+
nk_f32_t real;
|
|
1014
|
+
nk_f32_t imag;
|
|
1015
|
+
} nk_f32c_t;
|
|
1016
|
+
|
|
1017
|
+
/** @brief Double-precision (128-bit) complex number — {real: f64, imag: f64}. */
|
|
1018
|
+
typedef struct {
|
|
1019
|
+
nk_f64_t real;
|
|
1020
|
+
nk_f64_t imag;
|
|
1021
|
+
} nk_f64c_t;
|
|
1022
|
+
|
|
1023
|
+
/** @brief Small 4-byte memory slice viewable as different types. */
|
|
1024
|
+
typedef union nk_b32_vec_t {
|
|
1025
|
+
nk_u32_t u32;
|
|
1026
|
+
nk_i32_t i32;
|
|
1027
|
+
nk_f32_t f32;
|
|
1028
|
+
nk_u8_t u8s[4];
|
|
1029
|
+
nk_i8_t i8s[4];
|
|
1030
|
+
nk_u16_t u16s[2];
|
|
1031
|
+
nk_i16_t i16s[2];
|
|
1032
|
+
nk_e4m3_t e4m3s[4];
|
|
1033
|
+
nk_e5m2_t e5m2s[4];
|
|
1034
|
+
} nk_b32_vec_t;
|
|
1035
|
+
|
|
1036
|
+
/**
 * @brief Small 8-byte memory slice viewable as different types.
 *
 * All members overlay the same storage. When `NK_TARGET_NEON` is enabled, 64-bit NEON register views are
 * included and the first member (the target of brace-initialization) is `u8x8`; otherwise it is `u8s`.
 */
typedef union nk_b64_vec_t {
    // NEON register views, present only when the corresponding target is compiled in.
#if NK_TARGET_NEON
    uint8x8_t u8x8;
    uint16x4_t u16x4;
    uint32x2_t u32x2;
    int8x8_t i8x8;
    int16x4_t i16x4;
    int32x2_t i32x2;
    float32x2_t f32x2;
#endif
#if NK_TARGET_NEONHALF
    float16x4_t f16x4;
#endif
    // Portable scalar and array views, always available.
    nk_u8_t u8s[8];
    nk_u16_t u16s[4];
    nk_u32_t u32s[2];
    nk_u64_t u64;
    nk_i8_t i8s[8];
    nk_i16_t i16s[4];
    nk_i32_t i32s[2];
    nk_i64_t i64;
    nk_f16_t f16s[4];
    nk_bf16_t bf16s[4];
    nk_f32_t f32s[2];
} nk_b64_vec_t;
|
|
1062
|
+
|
|
1063
|
+
/**
 * @brief Small 16-byte memory slice viewable as different types.
 *
 * All members overlay the same storage. ISA-specific register views (x86 `__m128*`, WASM `v128_t`,
 * NEON `*x*_t`) are compiled in only for the enabled targets, so the first member, the one brace-initialization
 * targets, depends on the build configuration. The portable `u8s`...`f64s` arrays are always present.
 */
typedef union nk_b128_vec_t {
#if NK_TARGET_HASWELL
    __m128i xmm;
    __m128d xmm_pd;
    __m128 xmm_ps;
#endif
#if NK_TARGET_V128RELAXED
    v128_t v128;
#endif
#if NK_TARGET_NEON
    uint8x16_t u8x16;
    uint16x8_t u16x8;
    uint32x4_t u32x4;
    uint64x2_t u64x2;
    int8x16_t i8x16;
    int16x8_t i16x8;
    int32x4_t i32x4;
    int64x2_t i64x2;
    float32x4_t f32x4;
    float64x2_t f64x2;
#endif
    // Portable lane views, always available.
    nk_u8_t u8s[16];
    nk_u16_t u16s[8];
    nk_u32_t u32s[4];
    nk_u64_t u64s[2];
    nk_i8_t i8s[16];
    nk_i16_t i16s[8];
    nk_i32_t i32s[4];
    nk_i64_t i64s[2];
    nk_f16_t f16s[8];
    nk_bf16_t bf16s[8];
    nk_e4m3_t e4m3s[16];
    nk_e5m2_t e5m2s[16];
    nk_e2m3_t e2m3s[16];
    nk_e3m2_t e3m2s[16];
    nk_f32_t f32s[4];
    nk_f64_t f64s[2];
} nk_b128_vec_t;
|
|
1102
|
+
|
|
1103
|
+
/**
 * @brief Small 32-byte memory slice viewable as different types.
 *
 * All members overlay the same storage. On Haswell it is viewable as one `__m256*` register or two `__m128i`
 * halves; on WASM and NEON, which lack 256-bit registers, as pairs of 128-bit registers. ISA views are
 * compiled in only for the enabled targets, so the first member (targeted by brace-initialization) depends on
 * the build configuration. The portable `u8s`...`f64s` arrays are always present.
 */
typedef union nk_b256_vec_t {
#if NK_TARGET_HASWELL
    __m256i ymm;
    __m256d ymm_pd;
    __m256 ymm_ps;
    __m128i xmms[2];
#endif
#if NK_TARGET_V128RELAXED
    v128_t v128s[2];
#endif
#if NK_TARGET_NEON
    uint8x16_t u8x16s[2];
    uint16x8_t u16x8s[2];
    uint32x4_t u32x4s[2];
    uint64x2_t u64x2s[2];
    int8x16_t i8x16s[2];
    int16x8_t i16x8s[2];
    int32x4_t i32x4s[2];
    int64x2_t i64x2s[2];
    float32x4_t f32x4s[2];
    float64x2_t f64x2s[2];
#endif
    // Portable lane views, always available.
    nk_u8_t u8s[32];
    nk_u16_t u16s[16];
    nk_u32_t u32s[8];
    nk_u64_t u64s[4];
    nk_i8_t i8s[32];
    nk_i16_t i16s[16];
    nk_i32_t i32s[8];
    nk_i64_t i64s[4];
    nk_f16_t f16s[16];
    nk_bf16_t bf16s[16];
    nk_e4m3_t e4m3s[32];
    nk_e5m2_t e5m2s[32];
    nk_e2m3_t e2m3s[32];
    nk_e3m2_t e3m2s[32];
    nk_f32_t f32s[8];
    nk_f64_t f64s[4];
} nk_b256_vec_t;
|
|
1143
|
+
|
|
1144
|
+
/** @brief Small 64-byte memory slice viewable as different types.
 *
 * All members overlay the same storage. On Skylake it is viewable as one `__m512*` register; on Haswell as
 * two 256-bit or four 128-bit registers; on NEON as four 128-bit registers. ISA views are compiled in only for
 * the enabled targets, so the first member depends on the build configuration.
 *
 * TODO: On GCC and Clang consider marking this union with the `__transparent_union__` attribute (it is not
 * applied at present) to allow implicit conversions between the different vector types when passing them
 * as function arguments. The most important side-effect of that attribute is that an argument of such type
 * is passed to functions using the calling convention of the first member of the union, which in our case
 * would be a register-based calling convention for SIMD types.
 */
typedef union nk_b512_vec_t {
#if NK_TARGET_SKYLAKE
    __m512i zmm;
    __m512d zmm_pd;
    __m512 zmm_ps;
#endif
#if NK_TARGET_HASWELL
    __m256i ymms[2];
    __m256d ymms_pd[2];
    __m256 ymms_ps[2];
    __m128i xmms[4];
    __m128d xmms_pd[4];
    __m128 xmms_ps[4];
#endif
#if NK_TARGET_NEON
    uint8x16_t u8x16s[4];
    uint16x8_t u16x8s[4];
    uint32x4_t u32x4s[4];
    uint64x2_t u64x2s[4];
#endif
    // Portable lane views, always available.
    nk_u8_t u8s[64];
    nk_u16_t u16s[32];
    nk_u32_t u32s[16];
    nk_u64_t u64s[8];
    nk_i8_t i8s[64];
    nk_i16_t i16s[32];
    nk_i32_t i32s[16];
    nk_i64_t i64s[8];
    nk_f16_t f16s[32];
    nk_bf16_t bf16s[32];
    nk_f32_t f32s[16];
    nk_f64_t f64s[8];
    nk_e4m3_t e4m3s[64];
    nk_e5m2_t e5m2s[64];
    nk_e2m3_t e2m3s[64];
    nk_e3m2_t e3m2s[64];
} nk_b512_vec_t;
|
|
1188
|
+
|
|
1189
|
+
/**
|
|
1190
|
+
* @brief Advances the Multi-Dimensional iterator to the next set of indicies.
|
|
1191
|
+
* @param[in] extents The extents of the tensor, defined by an array of at least `rank` scalars.
|
|
1192
|
+
* @param[in] strides The @b signed strides of the tensor in bytes, defined by an array of at least `rank` scalars.
|
|
1193
|
+
* @param[in] rank The number of dimensions in the tensor (its rank).
|
|
1194
|
+
* @param[inout] coordinates The array of offsets along each of `rank` dimensions, which will be updated.
|
|
1195
|
+
* @param[inout] byte_offset The @b signed byte offset of the current element, which will be advanced.
|
|
1196
|
+
* @return 1 if the iterator was successfully advanced, 0 if the end of iteration was reached.
|
|
1197
|
+
*
|
|
1198
|
+
* For flexibility, the API is decoupled from from the `nk_tensor_position_t` structure, and
|
|
1199
|
+
* can be used on any-rank tensors, independent of the `NK_TENSOR_MAX_RANK` constant.
|
|
1200
|
+
*/
|
|
1201
|
+
NK_PUBLIC int nk_tensor_position_next( //
|
|
1202
|
+
nk_size_t const *extents, nk_ssize_t const *strides, nk_size_t rank, //
|
|
1203
|
+
nk_size_t *coordinates, nk_ssize_t *byte_offset) {
|
|
1204
|
+
// Start from last dimension and move backward
|
|
1205
|
+
for (nk_size_t i = rank; i-- > 0;) {
|
|
1206
|
+
coordinates[i]++;
|
|
1207
|
+
*byte_offset += strides[i];
|
|
1208
|
+
if (coordinates[i] < extents[i]) return 1; // Successfully moved to the next index
|
|
1209
|
+
coordinates[i] = 0; // Reset this dimension counter
|
|
1210
|
+
*byte_offset -= strides[i] * extents[i]; // Discard the running progress along this dimension
|
|
1211
|
+
}
|
|
1212
|
+
// If we reach here, we've iterated over all elements
|
|
1213
|
+
return 0; // End of iteration
|
|
1214
|
+
}
|
|
1215
|
+
|
|
1216
|
+
/**
|
|
1217
|
+
* @brief Advances the Multi-Dimensional iterator to the provided coordinates, updating the byte offset.
|
|
1218
|
+
* @param[in] extents The extents of the tensor, defined by an array of at least `rank` scalars.
|
|
1219
|
+
* @param[in] strides The @b signed strides of the tensor in bytes, defined by an array of at least `rank` scalars.
|
|
1220
|
+
* @param[in] rank The number of dimensions in the tensor (its rank).
|
|
1221
|
+
* @param[in] coordinates The array of offsets along each of `rank` dimensions, which will be updated.
|
|
1222
|
+
* @param[out] byte_offset The byte offset of the current element, which will be advanced.
|
|
1223
|
+
* @return 1 if the offset was successfully advanced, 0 if the end of iteration was reached.
|
|
1224
|
+
*/
|
|
1225
|
+
NK_PUBLIC int nk_tensor_position_linearize( //
|
|
1226
|
+
nk_size_t const *extents, nk_ssize_t const *strides, nk_size_t rank, //
|
|
1227
|
+
nk_size_t const *coordinates, nk_ssize_t *byte_offset) {
|
|
1228
|
+
|
|
1229
|
+
nk_ssize_t result = 0;
|
|
1230
|
+
for (nk_size_t i = 0; i < rank; i++) {
|
|
1231
|
+
// Ensure the coordinates is within bounds for the given dimension
|
|
1232
|
+
if (coordinates[i] >= extents[i]) return 0; // Invalid coordinates, out of bounds
|
|
1233
|
+
// Update the byte offset by multiplying the coordinates by the stride
|
|
1234
|
+
result += coordinates[i] * strides[i];
|
|
1235
|
+
}
|
|
1236
|
+
*byte_offset = result;
|
|
1237
|
+
return 1; // Successfully calculated global and byte offsets
|
|
1238
|
+
}
|
|
1239
|
+
|
|
1240
|
+
/**
 * @brief A @b beefy structure to iterate through Multi-Dimensional arrays.
 * Occupies 512 + 8 = 520 bytes on a 64-bit machine, or @b 9 cache-lines, by default.
 *
 * When advancing through a structure, its overall size and strides should be stored somewhere else:
 * pass `coordinates` and `byte_offset` to `nk_tensor_position_next` along with those extents and strides.
 * After `nk_tensor_position_init`, the `byte_offset` starts at zero. It grows monotonically during iteration
 * only if the strides are positive and row-major ordered (e.g. a C-contiguous layout); other positive-stride
 * layouts, such as transposed views, move back and forth.
 */
typedef struct nk_tensor_position_t {
    nk_size_t coordinates[NK_TENSOR_MAX_RANK]; // Coordinate offsets along each dimension
    nk_ssize_t byte_offset;                    // Byte offset of the current element
} nk_tensor_position_t;
|
|
1251
|
+
|
|
1252
|
+
/** @brief Resets a tensor position to the origin: every coordinate and the byte offset become zero. */
NK_PUBLIC void nk_tensor_position_init(nk_tensor_position_t *tensor_position) {
    tensor_position->byte_offset = 0;
    nk_size_t dimension = NK_TENSOR_MAX_RANK;
    while (dimension--) tensor_position->coordinates[dimension] = 0;
}
|
|
1256
|
+
|
|
1257
|
+
/**
 * @brief A @b beefy structure describing the shape and memory layout of a Multi-Dimensional array.
 * Similar to `std::mdspan` in C++23 and `numpy.ndarray` in Python, but with a focus on compatibility.
 * Occupies 512 + 512 + 8 = 1032 bytes on a 64-bit machine, or @b 17 cache-lines, by default.
 *
 * Unlike NumPy and the CPython "Buffer Protocol", we don't use `suboffsets` for pointer indirection.
 * The logic is that such layouts aren't friendly to conventional SIMD operations and dense matrix algorithms.
 * If the tensor is sparse, consider using a different data structure or a different memory layout.
 *
 * Most NumKong algorithms don't work with the entire structure, but expect the fields to be passed separately.
 * The @b start-pointer and the @b dtype/item-size must also be kept separately, as neither is
 * stored inside the structure.
 */
typedef struct nk_tensor_shape_t {
    nk_size_t extents[NK_TENSOR_MAX_RANK];  ///< Number of elements along each dimension
    nk_ssize_t strides[NK_TENSOR_MAX_RANK]; ///< Strides of the tensor in bytes
    nk_size_t rank;                         ///< Number of dimensions in the tensor
} nk_tensor_shape_t;
|
|
1275
|
+
|
|
1276
|
+
/** @brief Resets a tensor shape to rank zero, with every extent and stride slot cleared. */
NK_PUBLIC void nk_tensor_shape_init(nk_tensor_shape_t *tensor_shape) {
    tensor_shape->rank = 0;
    for (nk_size_t dimension = 0; dimension != NK_TENSOR_MAX_RANK; ++dimension) {
        tensor_shape->extents[dimension] = 0;
        tensor_shape->strides[dimension] = 0;
    }
}
|
|
1280
|
+
|
|
1281
|
+
NK_INTERNAL nk_u32_t nk_u32_rol(nk_u32_t x, int n) { return (x << n) | (x >> (32 - n)); }
|
|
1282
|
+
NK_INTERNAL nk_u16_t nk_u16_rol(nk_u16_t x, int n) { return (x << n) | (x >> (16 - n)); }
|
|
1283
|
+
NK_INTERNAL nk_u8_t nk_u8_rol(nk_u8_t x, int n) { return (x << n) | (x >> (8 - n)); }
|
|
1284
|
+
NK_INTERNAL nk_u32_t nk_u32_ror(nk_u32_t x, int n) { return (x >> n) | (x << (32 - n)); }
|
|
1285
|
+
NK_INTERNAL nk_u16_t nk_u16_ror(nk_u16_t x, int n) { return (x >> n) | (x << (16 - n)); }
|
|
1286
|
+
NK_INTERNAL nk_u8_t nk_u8_ror(nk_u8_t x, int n) { return (x >> n) | (x << (8 - n)); }
|
|
1287
|
+
|
|
1288
|
+
/**
|
|
1289
|
+
* @brief SWAR population count for 64-bit integers.
|
|
1290
|
+
*
|
|
1291
|
+
* Classic algorithm from Hacker's Delight using parallel bit summation:
|
|
1292
|
+
* - Step 1: Count bits in pairs (2-bit sums)
|
|
1293
|
+
* - Step 2: Count bits in nibbles (4-bit sums)
|
|
1294
|
+
* - Step 3: Count bits in bytes (8-bit sums)
|
|
1295
|
+
* - Step 4: Horizontal sum via multiply - each byte contributes to bits 56-63
|
|
1296
|
+
*
|
|
1297
|
+
* Cost: ~12 ALU ops, zero memory access (vs 8 table lookups for byte-wise).
|
|
1298
|
+
*/
|
|
1299
|
+
NK_INTERNAL nk_u64_t nk_u64_popcount_(nk_u64_t x) {
|
|
1300
|
+
x = x - ((x >> 1) & 0x5555555555555555ull);
|
|
1301
|
+
x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull);
|
|
1302
|
+
x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full;
|
|
1303
|
+
return (x * 0x0101010101010101ull) >> 56;
|
|
1304
|
+
}
|
|
1305
|
+
|
|
1306
|
+
/**
 * @brief Population count of a packed 8-bit mask via a 256-entry lookup table.
 *
 * The table is `static const`: it is never written, so it can live in read-only memory. A modifiable
 * `static` object is also not permitted inside a C inline function with external linkage.
 */
NK_INTERNAL unsigned char nk_u1x8_popcount_(nk_u1x8_t x) {
    static const unsigned char lookup_table[256] = {
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, //
        1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
        1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
        2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
        1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
        2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
        2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
        3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};
    return lookup_table[x];
}
|
|
1318
|
+
|
|
1319
|
+
/** @brief Divides the number rounding up to the next multiple of the given divisor. */
|
|
1320
|
+
NK_PUBLIC nk_size_t nk_size_divide_round_up_(nk_size_t number, nk_size_t divisor) NK_STREAMING_COMPATIBLE_ {
|
|
1321
|
+
return (number + divisor - 1) / divisor;
|
|
1322
|
+
}
|
|
1323
|
+
|
|
1324
|
+
/** @brief Rounds up the number to the next multiple of the given divisor. */
|
|
1325
|
+
NK_PUBLIC nk_size_t nk_size_round_up_to_multiple_(nk_size_t number, nk_size_t divisor) NK_STREAMING_COMPATIBLE_ {
|
|
1326
|
+
return nk_size_divide_round_up_(number, divisor) * divisor;
|
|
1327
|
+
}
|
|
1328
|
+
|
|
1329
|
+
NK_INTERNAL nk_f32_t nk_f32_abs_(nk_f32_t x) { return x < 0 ? -x : x; }
|
|
1330
|
+
NK_INTERNAL nk_f64_t nk_f64_abs_(nk_f64_t x) { return x < 0 ? -x : x; }
|
|
1331
|
+
NK_INTERNAL nk_i64_t nk_i64_abs_(nk_i64_t x) { return x < 0 ? -x : x; }
|
|
1332
|
+
NK_INTERNAL nk_u64_t nk_u64_abs_(nk_u64_t x) { return x; }
|
|
1333
|
+
NK_INTERNAL nk_i64_t nk_i32_abs_(nk_i32_t x) { return x < 0 ? -x : x; }
|
|
1334
|
+
NK_INTERNAL nk_u32_t nk_u32_abs_(nk_u32_t x) { return x; }
|
|
1335
|
+
|
|
1336
|
+
/** @brief Extract low (bits 0-3) unsigned nibble from packed u4x2 byte. */
|
|
1337
|
+
NK_INTERNAL nk_u8_t nk_u4x2_low_(nk_u4x2_t byte_val) { return byte_val & 0x0F; }
|
|
1338
|
+
/** @brief Extract high (bits 4-7) unsigned nibble from packed u4x2 byte. */
|
|
1339
|
+
NK_INTERNAL nk_u8_t nk_u4x2_high_(nk_u4x2_t byte_val) { return (byte_val >> 4) & 0x0F; }
|
|
1340
|
+
|
|
1341
|
+
/** @brief Extract low (bits 0-3) signed nibble from packed i4x2 byte as i8. */
|
|
1342
|
+
NK_INTERNAL nk_i8_t nk_i4x2_low_(nk_i4x2_t byte_val) { return (nk_i8_t)(((byte_val & 0x0F) ^ 8) - 8); }
|
|
1343
|
+
/** @brief Extract high (bits 4-7) signed nibble from packed i4x2 byte as i8. */
|
|
1344
|
+
NK_INTERNAL nk_i8_t nk_i4x2_high_(nk_i4x2_t byte_val) { return (nk_i8_t)((((byte_val >> 4) & 0x0F) ^ 8) - 8); }
|
|
1345
|
+
|
|
1346
|
+
/** @brief Extract n-th nibble (n=0: low, n=1: high) — branchless. */
|
|
1347
|
+
NK_INTERNAL nk_u8_t nk_u4x2_get_(nk_u4x2_t byte_val, int n) { return (byte_val >> ((n & 1) * 4)) & 0x0F; }
|
|
1348
|
+
NK_INTERNAL nk_i8_t nk_i4x2_get_(nk_i4x2_t byte_val, int n) {
|
|
1349
|
+
nk_u8_t nibble = (byte_val >> ((n & 1) * 4)) & 0x0F;
|
|
1350
|
+
return (nk_i8_t)((nibble ^ 8) - 8);
|
|
1351
|
+
}
|
|
1352
|
+
|
|
1353
|
+
/** @brief Extract bit at position n (0-7) from packed u1x8 byte. */
|
|
1354
|
+
NK_INTERNAL nk_u8_t nk_u1x8_get_(nk_u1x8_t byte_val, int n) { return (byte_val >> (n & 7)) & 1; }
|
|
1355
|
+
|
|
1356
|
+
NK_INTERNAL nk_f16_t nk_f16_from_u16_(nk_u16_t bits) {
|
|
1357
|
+
nk_fui16_t c;
|
|
1358
|
+
c.u = bits;
|
|
1359
|
+
return c.f;
|
|
1360
|
+
}
|
|
1361
|
+
NK_INTERNAL nk_bf16_t nk_bf16_from_u16_(nk_u16_t bits) {
|
|
1362
|
+
nk_fui16_t c;
|
|
1363
|
+
c.u = bits;
|
|
1364
|
+
return c.bf;
|
|
1365
|
+
}
|
|
1366
|
+
|
|
1367
|
+
/** @brief E4M3: NaN when (raw & 0x7F) == 0x7F (two NaN values: 0x7F, 0xFF). */
|
|
1368
|
+
NK_INTERNAL int nk_e4m3_is_nan_(nk_e4m3_t x) { return (x & 0x7F) == 0x7F; }
|
|
1369
|
+
|
|
1370
|
+
/** @brief E5M2: NaN when exponent=31 and mantissa!=0, i.e. (raw & 0x7F) > 0x7C.
|
|
1371
|
+
* Values: 0x7D-0x7F (positive), 0xFD-0xFF (negative). Infinity = 0x7C/0xFC is NOT NaN. */
|
|
1372
|
+
NK_INTERNAL int nk_e5m2_is_nan_(nk_e5m2_t x) { return (x & 0x7F) > 0x7C; }
|
|
1373
|
+
|
|
1374
|
+
/** @brief F16: NaN when (raw & 0x7FFF) > 0x7C00. */
|
|
1375
|
+
NK_INTERNAL int nk_f16_is_nan_(nk_u16_t x) { return (x & 0x7FFF) > 0x7C00; }
|
|
1376
|
+
|
|
1377
|
+
/** @brief BF16: NaN when (raw & 0x7FFF) > 0x7F80. */
|
|
1378
|
+
NK_INTERNAL int nk_bf16_is_nan_(nk_u16_t x) { return (x & 0x7FFF) > 0x7F80; }
|
|
1379
|
+
|
|
1380
|
+
#ifdef __cplusplus
|
|
1381
|
+
} // extern "C"
|
|
1382
|
+
#endif
|
|
1383
|
+
|
|
1384
|
+
#endif // NK_TYPES_H
|