numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1384 @@
1
+ /**
2
+ * @brief Shared definitions for the NumKong library.
3
+ * @file include/numkong/types.h
4
+ * @author Ash Vardanian
5
+ * @date October 2, 2023
6
+ *
7
+ * Defines:
8
+ *
9
+ * - Sized aliases for numeric types, like: `nk_i32_t` and `nk_f64_t`.
10
+ * - Macros for internal compiler/hardware checks, like: `NK_TARGET_ARM_`.
11
+ * - Macros for feature controls, like: `NK_TARGET_NEON`
12
+ *
13
+ * @section fp8_types FP8 Numeric Types
14
+ *
15
+ * There are several variants of 8-bit floating point types supported by different industry members
16
+ * with different hardware support. None are part of the IEEE 754 standard, but some are part of the
17
+ * Open Compute Project (OCP) 8-bit Floating Point Specification (OFP8):
18
+ *
19
+ * Format Bias Sign Exp Mant Range Infinity NaN Standard
20
+ * E4M3FN 7 1 4 3 ±448 ❌ No Only 0x7F/0xFF OCP, NVIDIA, ONNX
21
+ * E5M2 15 1 5 2 ±57344 ✅ Yes (0x7C/0xFC) 0x7D-7F, 0xFD-FF OCP, IEEE-like
22
+ * E4M3FNUZ 8 1 4 3 ±240 ❌ No 0x80 only GraphCore, ONNX
23
+ * E5M2FNUZ 16 1 5 2 ±57344 ❌ No 0x80 only GraphCore, ONNX
24
+ *
25
+ * In currently available and upcoming hardware, only two series of models prioritize FNUZ over OCP:
26
+ *
27
+ * - GraphCore IPUs were the original platform proposing FNUZ
28
+ * - AMD MI300 series based on CDNA3 implements FNUZ, but not OCP
29
+ * - AMD MI350+ series based on CDNA4 switch to OCP and remove FNUZ
30
+ * - NVIDIA Hopper and Blackwell only support E4M3FN, E5M2
31
+ * - Intel AVX10.2 defines HF8 (E4M3FN) and BF8 (E5M2) - OCP-aligned
32
+ * - Arm implements E4M3 (meaning E4M3FN) and E5M2 with a shared `__mfp8` type and a `FPMR` format selector
33
+ *
34
+ * For brevity, across NumKong, "E4M3" implies "E4M3FN".
35
+ *
36
+ * @see https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
37
+ * @see FP8 Formats for Deep Learning: https://arxiv.org/pdf/2209.05433
38
+ * @see ONNX Float8 Types: https://onnx.ai/onnx/technical/float8.html
39
+ */
40
+ #ifndef NK_TYPES_H
41
+ #define NK_TYPES_H
42
+
43
+ // On Linux, `_GNU_SOURCE` must be defined before any system headers
44
+ // to expose `syscall` and other GNU extensions when C extensions are disabled.
45
+ #if defined(__linux__) && !defined(_GNU_SOURCE)
46
+ #define _GNU_SOURCE
47
+ #endif
48
+
49
+ // Inferring target OS: Windows, macOS, Linux, or FreeBSD
50
+ #if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__)
51
+ #define NK_DEFINED_WINDOWS_ 1
52
+ #elif defined(__APPLE__) && defined(__MACH__)
53
+ #define NK_DEFINED_APPLE_ 1
54
+ #elif defined(__linux__)
55
+ #define NK_DEFINED_LINUX_ 1
56
+ #elif defined(__FreeBSD__)
57
+ #define NK_DEFINED_FREEBSD_ 1
58
+ #endif
59
+
60
+ // Annotation for the public API symbols:
61
+ //
62
+ // - `NK_PUBLIC` is used for functions that are part of the public API.
63
+ // - `NK_INTERNAL` is used for internal helper functions with unstable APIs.
64
+ // - `NK_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime.
65
+ //
66
+ // On GCC and Clang, `NK_PUBLIC` adds `unused` to silence unused-function warnings; `NK_INTERNAL` adds `always_inline`.
67
+ // Marking with `pure` and `const` isn't possible as outputting to a pointer is a "side effect".
68
+ #if defined(__GNUC__) || defined(__clang__)
69
+ #define NK_PUBLIC __attribute__((unused)) inline static
70
+ #define NK_INTERNAL __attribute__((always_inline)) inline static
71
+ #else
72
+ #define NK_PUBLIC inline static
73
+ #define NK_INTERNAL inline static
74
+ #endif // defined(__GNUC__) || defined(__clang__)
75
+
76
+ #if NK_DYNAMIC_DISPATCH
77
+ #if defined(_WIN32) || defined(__CYGWIN__)
78
+ #define NK_DYNAMIC __declspec(dllexport)
79
+ #elif defined(__GNUC__) || defined(__clang__)
80
+ #define NK_DYNAMIC __attribute__((visibility("default")))
81
+ #else
82
+ #define NK_DYNAMIC NK_PUBLIC
83
+ #endif
84
+ #else
85
+ #define NK_DYNAMIC NK_PUBLIC
86
+ #endif // NK_DYNAMIC_DISPATCH
87
+
88
+ // Allow SIMD kernels to redirect small inputs to serial implementations.
89
+ // Enabled by default for production use. Tests and benchmarks may disable
90
+ // this to isolate SIMD path behavior on small inputs.
91
+ #if !defined(NK_ALLOW_ISA_REDIRECT)
92
+ #define NK_ALLOW_ISA_REDIRECT 1
93
+ #endif
94
+
95
+ // Compiling for Arm: NK_TARGET_ARM_
96
+ #if !defined(NK_TARGET_ARM_)
97
+ #if defined(__aarch64__) || defined(_M_ARM64)
98
+ #define NK_TARGET_ARM_ 1
99
+ #else
100
+ #define NK_TARGET_ARM_ 0
101
+ #endif // defined(__aarch64__) || defined(_M_ARM64)
102
+ #endif // !defined(NK_TARGET_ARM_)
103
+
104
+ // Compiling for x86: NK_TARGET_X86_
105
+ #if !defined(NK_TARGET_X86_)
106
+ #if defined(__x86_64__) || defined(_M_X64)
107
+ #define NK_TARGET_X86_ 1
108
+ #else
109
+ #define NK_TARGET_X86_ 0
110
+ #endif // defined(__x86_64__) || defined(_M_X64)
111
+ #endif // !defined(NK_TARGET_X86_)
112
+
113
+ // Compiling for RISC-V: NK_TARGET_RISCV_
114
+ #if !defined(NK_TARGET_RISCV_)
115
+ #if defined(__riscv) && (__riscv_xlen == 64)
116
+ #define NK_TARGET_RISCV_ 1
117
+ #else
118
+ #define NK_TARGET_RISCV_ 0
119
+ #endif // defined(__riscv) && (__riscv_xlen == 64)
120
+ #endif // !defined(NK_TARGET_RISCV_)
121
+
122
+ // Compiling for WASM: NK_TARGET_WASM_
123
+ #if !defined(NK_TARGET_WASM_)
124
+ #if defined(__wasm__) || defined(__EMSCRIPTEN__)
125
+ #define NK_TARGET_WASM_ 1
126
+ #else
127
+ #define NK_TARGET_WASM_ 0
128
+ #endif // defined(__wasm__) || defined(__EMSCRIPTEN__)
129
+ #endif // !defined(NK_TARGET_WASM_)
130
+
131
+ // WASI hosted mode: NK_DEFINED_WASI_
132
+ // When NK_WASI_HOSTED=ON in CMake, this is predefined to 1 so the library
133
+ // imports capability probes (nk_has_v128, nk_has_relaxed) from the host.
134
+ // Standalone runtimes (Wasmer, Wasmtime CLI) cannot supply those imports,
135
+ // so the default for plain __wasi__ builds is 0 (compile-time detection).
136
+ #if !defined(NK_DEFINED_WASI_)
137
+ #define NK_DEFINED_WASI_ 0
138
+ #endif // !defined(NK_DEFINED_WASI_)
139
+
140
+ // Compiling for WASM with Relaxed SIMD: NK_TARGET_V128RELAXED
141
+ // Requires -mrelaxed-simd for FMA instructions (f32x4.relaxed_madd, f64x2.relaxed_madd)
142
+ #if !defined(NK_TARGET_V128RELAXED) || (NK_TARGET_V128RELAXED && !NK_TARGET_WASM_)
143
+ #if defined(__wasm_relaxed_simd__)
144
+ #define NK_TARGET_V128RELAXED 1
145
+ #else
146
+ #undef NK_TARGET_V128RELAXED
147
+ #define NK_TARGET_V128RELAXED 0
148
+ #endif // defined(__wasm_relaxed_simd__)
149
+ #endif // !defined(NK_TARGET_V128RELAXED) || ...
150
+
151
+ // Compiling for RISC-V Vector: NK_TARGET_RVV
152
+ #if !defined(NK_TARGET_RVV) || (NK_TARGET_RVV && !NK_TARGET_RISCV_)
153
+ #if defined(__riscv_v) && (__riscv_v >= 1000000)
154
+ #define NK_TARGET_RVV 1
155
+ #else
156
+ #undef NK_TARGET_RVV
157
+ #define NK_TARGET_RVV 0
158
+ #endif // defined(__riscv_v) && (__riscv_v >= 1000000)
159
+ #endif // !defined(NK_TARGET_RVV) || ...
160
+
161
+ // Compiling for RISC-V Vector with Zvfh (f16): NK_TARGET_RVVHALF
162
+ // Requires GCC 14+ or Clang 18+ for full intrinsic support
163
+ #if !defined(NK_TARGET_RVVHALF) || (NK_TARGET_RVVHALF && !NK_TARGET_RVV)
164
+ #if defined(__riscv_zvfh) && (__riscv_zvfh > 0)
165
+ #define NK_TARGET_RVVHALF 1
166
+ #else
167
+ #undef NK_TARGET_RVVHALF
168
+ #define NK_TARGET_RVVHALF 0
169
+ #endif // defined(__riscv_zvfh) && (__riscv_zvfh > 0)
170
+ #endif // !defined(NK_TARGET_RVVHALF) || ...
171
+
172
+ // Compiling for RISC-V Vector with Zvfbfwma (bf16 widening FMA): NK_TARGET_RVVBF16
173
+ // Requires GCC 14+ or Clang 18+ for full intrinsic support
174
+ #if !defined(NK_TARGET_RVVBF16) || (NK_TARGET_RVVBF16 && !NK_TARGET_RVV)
175
+ #if defined(__riscv_zvfbfwma) && (__riscv_zvfbfwma > 0)
176
+ #define NK_TARGET_RVVBF16 1
177
+ #else
178
+ #undef NK_TARGET_RVVBF16
179
+ #define NK_TARGET_RVVBF16 0
180
+ #endif // defined(__riscv_zvfbfwma) && (__riscv_zvfbfwma > 0)
181
+ #endif // !defined(NK_TARGET_RVVBF16) || ...
182
+
183
+ // Compiling for RISC-V Vector with Zvbb (basic bit-manipulation): NK_TARGET_RVVBB
184
+ // Provides vcpop.v (per-element popcount), vclz.v, vctz.v, vbrev.v, vrol.v, vror.v
185
+ #if !defined(NK_TARGET_RVVBB) || (NK_TARGET_RVVBB && !NK_TARGET_RVV)
186
+ #if defined(__riscv_zvbb) && (__riscv_zvbb > 0)
187
+ #define NK_TARGET_RVVBB 1
188
+ #else
189
+ #undef NK_TARGET_RVVBB
190
+ #define NK_TARGET_RVVBB 0
191
+ #endif // defined(__riscv_zvbb) && (__riscv_zvbb > 0)
192
+ #endif // !defined(NK_TARGET_RVVBB) || ...
193
+
194
+ // Compiling for Arm: NK_TARGET_NEON
195
+ #if !defined(NK_TARGET_NEON) || (NK_TARGET_NEON && !NK_TARGET_ARM_)
196
+ #if defined(__ARM_NEON)
197
+ #define NK_TARGET_NEON 1
198
+ #else
199
+ #undef NK_TARGET_NEON
200
+ #define NK_TARGET_NEON 0
201
+ #endif // defined(__ARM_NEON)
202
+ #endif // !defined(NK_TARGET_NEON) || ...
203
+
204
+ // Compiling for Arm: NK_TARGET_NEONSDOT
205
+ #if !defined(NK_TARGET_NEONSDOT) || (NK_TARGET_NEONSDOT && !NK_TARGET_ARM_)
206
+ #if defined(__ARM_NEON)
207
+ #define NK_TARGET_NEONSDOT 1
208
+ #else
209
+ #undef NK_TARGET_NEONSDOT
210
+ #define NK_TARGET_NEONSDOT 0
211
+ #endif // defined(__ARM_NEON)
212
+ #endif // !defined(NK_TARGET_NEONSDOT) || ...
213
+
214
+ // Compiling for Arm: NK_TARGET_NEONHALF
215
+ #if !defined(NK_TARGET_NEONHALF) || (NK_TARGET_NEONHALF && !NK_TARGET_ARM_)
216
+ #if defined(__ARM_NEON)
217
+ #define NK_TARGET_NEONHALF 1
218
+ #else
219
+ #undef NK_TARGET_NEONHALF
220
+ #define NK_TARGET_NEONHALF 0
221
+ #endif // defined(__ARM_NEON)
222
+ #endif // !defined(NK_TARGET_NEONHALF) || ...
223
+
224
+ // Compiling for Arm: NK_TARGET_NEONFHM (FEAT_FHM - FMLAL/FMLSL widening ops)
225
+ #if !defined(NK_TARGET_NEONFHM) || (NK_TARGET_NEONFHM && !NK_TARGET_ARM_)
226
+ #if defined(__ARM_NEON)
227
+ #define NK_TARGET_NEONFHM 1
228
+ #else
229
+ #undef NK_TARGET_NEONFHM
230
+ #define NK_TARGET_NEONFHM 0
231
+ #endif // defined(__ARM_NEON)
232
+ #endif // !defined(NK_TARGET_NEONFHM) || ...
233
+
234
+ // Compiling for Arm: NK_TARGET_NEONBFDOT
235
+ #if !defined(NK_TARGET_NEONBFDOT) || (NK_TARGET_NEONBFDOT && !NK_TARGET_ARM_)
236
+ #if defined(__ARM_NEON)
237
+ #define NK_TARGET_NEONBFDOT 1
238
+ #else
239
+ #undef NK_TARGET_NEONBFDOT
240
+ #define NK_TARGET_NEONBFDOT 0
241
+ #endif // defined(__ARM_NEON)
242
+ #endif // !defined(NK_TARGET_NEONBFDOT) || ...
243
+
244
+ // Compiling for Arm: NK_TARGET_SVE
245
+ #if !defined(NK_TARGET_SVE) || (NK_TARGET_SVE && !NK_TARGET_ARM_)
246
+ #if defined(__ARM_FEATURE_SVE)
247
+ #define NK_TARGET_SVE 1
248
+ #else
249
+ #undef NK_TARGET_SVE
250
+ #define NK_TARGET_SVE 0
251
+ #endif // defined(__ARM_FEATURE_SVE)
252
+ #endif // !defined(NK_TARGET_SVE) || ...
253
+
254
+ // Compiling for Arm: NK_TARGET_SVESDOT
255
+ #if !defined(NK_TARGET_SVESDOT) || (NK_TARGET_SVESDOT && !NK_TARGET_ARM_)
256
+ #if defined(__ARM_FEATURE_SVE)
257
+ #define NK_TARGET_SVESDOT 1
258
+ #else
259
+ #undef NK_TARGET_SVESDOT
260
+ #define NK_TARGET_SVESDOT 0
261
+ #endif // defined(__ARM_FEATURE_SVE)
262
+ #endif // !defined(NK_TARGET_SVESDOT) || ...
263
+
264
+ // Compiling for Arm: NK_TARGET_SVEHALF
265
+ #if !defined(NK_TARGET_SVEHALF) || (NK_TARGET_SVEHALF && !NK_TARGET_ARM_)
266
+ #if defined(__ARM_FEATURE_SVE)
267
+ #define NK_TARGET_SVEHALF 1
268
+ #else
269
+ #undef NK_TARGET_SVEHALF
270
+ #define NK_TARGET_SVEHALF 0
271
+ #endif // defined(__ARM_FEATURE_SVE)
272
+ #endif // !defined(NK_TARGET_SVEHALF) || ...
273
+
274
+ // Compiling for Arm: NK_TARGET_SVEBFDOT
275
+ #if !defined(NK_TARGET_SVEBFDOT) || (NK_TARGET_SVEBFDOT && !NK_TARGET_ARM_)
276
+ #if defined(__ARM_FEATURE_SVE)
277
+ #define NK_TARGET_SVEBFDOT 1
278
+ #else
279
+ #undef NK_TARGET_SVEBFDOT
280
+ #define NK_TARGET_SVEBFDOT 0
281
+ #endif // defined(__ARM_FEATURE_SVE)
282
+ #endif // !defined(NK_TARGET_SVEBFDOT) || ...
283
+
284
+ // Compiling for Arm: NK_TARGET_SVE2
285
+ #if !defined(NK_TARGET_SVE2) || (NK_TARGET_SVE2 && !NK_TARGET_ARM_)
286
+ #if defined(__ARM_FEATURE_SVE2)
287
+ #define NK_TARGET_SVE2 1
288
+ #else
289
+ #undef NK_TARGET_SVE2
290
+ #define NK_TARGET_SVE2 0
291
+ #endif // defined(__ARM_FEATURE_SVE2)
292
+ #endif // !defined(NK_TARGET_SVE2) || ...
293
+
294
+ // Compiling for Arm: NK_TARGET_SVE2P1 (no compiler macro is probed; 0 unless pre-defined to 1 with NK_TARGET_ARM_ set)
295
+ #if !defined(NK_TARGET_SVE2P1) || (NK_TARGET_SVE2P1 && !NK_TARGET_ARM_)
296
+ #undef NK_TARGET_SVE2P1
297
+ #define NK_TARGET_SVE2P1 0
298
+ #endif // !defined(NK_TARGET_SVE2P1) || ...
299
+
300
+ // Compiling for Arm: NK_TARGET_SME (Scalable Matrix Extension)
301
+ #if !defined(NK_TARGET_SME) || (NK_TARGET_SME && !NK_TARGET_ARM_)
302
+ #if defined(__ARM_FEATURE_SME)
303
+ #define NK_TARGET_SME 1
304
+ #else
305
+ #undef NK_TARGET_SME
306
+ #define NK_TARGET_SME 0
307
+ #endif // defined(__ARM_FEATURE_SME)
308
+ #endif // !defined(NK_TARGET_SME) || ...
309
+
310
+ #if !defined(NK_TARGET_SME2) || (NK_TARGET_SME2 && !NK_TARGET_ARM_)
311
+ #if defined(__ARM_FEATURE_SME2)
312
+ #define NK_TARGET_SME2 1
313
+ #else
314
+ #undef NK_TARGET_SME2
315
+ #define NK_TARGET_SME2 0
316
+ #endif // defined(__ARM_FEATURE_SME2)
317
+ #endif // !defined(NK_TARGET_SME2) || ...
318
+
319
+ #if !defined(NK_TARGET_SME2P1) || (NK_TARGET_SME2P1 && !NK_TARGET_ARM_)
320
+ #undef NK_TARGET_SME2P1
321
+ #define NK_TARGET_SME2P1 0 // No compiler macro is probed: 0 unless pre-defined to 1 with NK_TARGET_ARM_ set
322
+ #endif // !defined(NK_TARGET_SME2P1) || ...
323
+
324
+ // AppleClang 17 exposes SME sub-features through `arm_sme.h` builtin aliases,
325
+ // not dedicated `__ARM_FEATURE_*` predefines for every matrix subtype.
326
+ #if !defined(NK_TARGET_SMEF64) || (NK_TARGET_SMEF64 && !NK_TARGET_ARM_)
327
+ #if defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za64_f64_m)
328
+ #define NK_TARGET_SMEF64 1
329
+ #else
330
+ #undef NK_TARGET_SMEF64
331
+ #define NK_TARGET_SMEF64 0
332
+ #endif // defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za64_f64_m)
333
+ #endif // !defined(NK_TARGET_SMEF64) || ...
334
+
335
+ #if !defined(NK_TARGET_SMEBI32) || (NK_TARGET_SMEBI32 && !NK_TARGET_ARM_)
336
+ #if defined(__has_builtin) && __has_builtin(__builtin_sme_svbmopa_za32_u32_m)
337
+ #define NK_TARGET_SMEBI32 1
338
+ #else
339
+ #undef NK_TARGET_SMEBI32
340
+ #define NK_TARGET_SMEBI32 0
341
+ #endif // defined(__has_builtin) && __has_builtin(__builtin_sme_svbmopa_za32_u32_m)
342
+ #endif // !defined(NK_TARGET_SMEBI32) || ...
343
+
344
+ #if !defined(NK_TARGET_SMEHALF) || (NK_TARGET_SMEHALF && !NK_TARGET_ARM_)
345
+ #if defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za32_f16_m)
346
+ #define NK_TARGET_SMEHALF 1
347
+ #else
348
+ #undef NK_TARGET_SMEHALF
349
+ #define NK_TARGET_SMEHALF 0
350
+ #endif // defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za32_f16_m)
351
+ #endif // !defined(NK_TARGET_SMEHALF) || ...
352
+
353
+ #if !defined(NK_TARGET_SMEBF16) || (NK_TARGET_SMEBF16 && !NK_TARGET_ARM_)
354
+ #if defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za32_bf16_m)
355
+ #define NK_TARGET_SMEBF16 1
356
+ #else
357
+ #undef NK_TARGET_SMEBF16
358
+ #define NK_TARGET_SMEBF16 0
359
+ #endif // defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za32_bf16_m)
360
+ #endif // !defined(NK_TARGET_SMEBF16) || ...
361
+
362
+ #if !defined(NK_TARGET_SMELUT2) || (NK_TARGET_SMELUT2 && !NK_TARGET_ARM_)
363
+ #if defined(__has_builtin) && __has_builtin(__builtin_sme_svluti2_lane_zt_u8)
364
+ #define NK_TARGET_SMELUT2 1
365
+ #else
366
+ #undef NK_TARGET_SMELUT2
367
+ #define NK_TARGET_SMELUT2 0
368
+ #endif // defined(__has_builtin) && __has_builtin(__builtin_sme_svluti2_lane_zt_u8)
369
+ #endif // !defined(NK_TARGET_SMELUT2) || ...
370
+
371
+ #if !defined(NK_TARGET_SMEFA64) || (NK_TARGET_SMEFA64 && !NK_TARGET_ARM_)
372
+ #undef NK_TARGET_SMEFA64
373
+ #define NK_TARGET_SMEFA64 0 // No compiler macro is probed: 0 unless pre-defined to 1 with NK_TARGET_ARM_ set
374
+ #endif // !defined(NK_TARGET_SMEFA64) || ...
375
+
376
+ // Compiling for x86: NK_TARGET_HASWELL
377
+ //
378
+ // Starting with Ivy Bridge, Intel supports the `F16C` extensions for fast half-precision
379
+ // to single-precision floating-point conversions. On AMD those instructions
380
+ // are supported on all CPUs starting with Jaguar (2013).
381
+ // Starting with Sandy Bridge, Intel adds basic AVX support in their CPUs and in 2013
382
+ // extends it with AVX2 in the Haswell generation. Moreover, Haswell adds FMA support.
383
+ //
384
+ // On MSVC, most GCC-style ISA macros are unavailable. MSVC defines __AVX__, __AVX2__,
385
+ // __AVX512F/BW/CD/DQ/VL__, and __AVX10_VER__, but NOT __AVXVNNI__, __AVX512VNNI__,
386
+ // __AVX512BF16__, __AVX512FP16__, __AMX_*__, etc.
387
+ // Instead, MSVC makes all intrinsics available once the toolset version supports them,
388
+ // without requiring `/arch:AVX512`. We gate on _MSC_VER to auto-enable targets:
389
+ // - _MSC_VER >= 1900 (VS 2015+): AVX2/FMA/F16C (Haswell)
390
+ // - _MSC_VER >= 1920 (VS 2019+): AVX-512 base (Skylake, Icelake), AVX-VNNI (Alder)
391
+ // - _MSC_VER >= 1944 (VS 2022 17.14+): BF16, FP16, VP2INTERSECT, VNNI-INT8 (Sierra), AMX
392
+ #if !defined(NK_TARGET_HASWELL) || (NK_TARGET_HASWELL && !NK_TARGET_X86_)
393
+ #if (defined(__AVX2__) && defined(__FMA__) && defined(__F16C__)) || (defined(_MSC_VER) && _MSC_VER >= 1900)
394
+ #define NK_TARGET_HASWELL 1
395
+ #else
396
+ #undef NK_TARGET_HASWELL
397
+ #define NK_TARGET_HASWELL 0
398
+ #endif // defined(__AVX2__)
399
+ #endif // !defined(NK_TARGET_HASWELL) || ...
400
+
401
+ // Compiling for x86: NK_TARGET_SKYLAKE, NK_TARGET_ICELAKE, NK_TARGET_GENOA,
402
+ // NK_TARGET_SAPPHIRE, NK_TARGET_TURIN, NK_TARGET_SIERRA
403
+ //
404
+ // To list all available macros for x86, take a recent compiler, like GCC 12 and run:
405
+ // gcc-12 -march=sapphirerapids -dM -E - < /dev/null | egrep "SSE|AVX" | sort
406
+ // On Arm machines you may want to check for other flags:
407
+ // gcc-12 -march=native -dM -E - < /dev/null | egrep "NEON|SVE|FP16|FMA" | sort
408
+ #if !defined(NK_TARGET_SKYLAKE) || (NK_TARGET_SKYLAKE && !NK_TARGET_X86_)
409
+ #if (defined(__AVX512F__) && defined(__AVX512CD__) && defined(__AVX512VL__) && defined(__AVX512DQ__) && \
410
+ defined(__AVX512BW__)) || \
411
+ (defined(_MSC_VER) && _MSC_VER >= 1920)
412
+ #define NK_TARGET_SKYLAKE 1
413
+ #else
414
+ #undef NK_TARGET_SKYLAKE
415
+ #define NK_TARGET_SKYLAKE 0
416
+ #endif
417
+ #endif // !defined(NK_TARGET_SKYLAKE) || ...
418
+
419
+ #if !defined(NK_TARGET_ICELAKE) || (NK_TARGET_ICELAKE && !NK_TARGET_X86_)
420
+ #if (defined(__AVX512VNNI__) && defined(__AVX512IFMA__) && defined(__AVX512BITALG__) && defined(__AVX512VBMI__) && \
421
+ defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__)) || \
422
+ (defined(_MSC_VER) && _MSC_VER >= 1920)
423
+ #define NK_TARGET_ICELAKE 1
424
+ #else
425
+ #undef NK_TARGET_ICELAKE
426
+ #define NK_TARGET_ICELAKE 0
427
+ #endif
428
+ #endif // !defined(NK_TARGET_ICELAKE) || ...
429
+
430
+ #if !defined(NK_TARGET_GENOA) || (NK_TARGET_GENOA && !NK_TARGET_X86_)
431
+ #if defined(__AVX512BF16__) || (defined(_MSC_VER) && _MSC_VER >= 1944)
432
+ #define NK_TARGET_GENOA 1
433
+ #else
434
+ #undef NK_TARGET_GENOA
435
+ #define NK_TARGET_GENOA 0
436
+ #endif
437
+ #endif // !defined(NK_TARGET_GENOA) || ...
438
+
439
+ #if !defined(NK_TARGET_SAPPHIRE) || (NK_TARGET_SAPPHIRE && !NK_TARGET_X86_)
440
+ #if defined(__AVX512FP16__) || (defined(_MSC_VER) && _MSC_VER >= 1944)
441
+ #define NK_TARGET_SAPPHIRE 1
442
+ #else
443
+ #undef NK_TARGET_SAPPHIRE
444
+ #define NK_TARGET_SAPPHIRE 0
445
+ #endif
446
+ #endif // !defined(NK_TARGET_SAPPHIRE) || ...
447
+
448
+ #if !defined(NK_TARGET_SAPPHIREAMX) || (NK_TARGET_SAPPHIREAMX && !NK_TARGET_X86_)
449
+ #if (defined(__AMX_TILE__) && defined(__AMX_BF16__) && defined(__AMX_INT8__)) || (defined(_MSC_VER) && _MSC_VER >= 1944)
450
+ #define NK_TARGET_SAPPHIREAMX 1
451
+ #else
452
+ #undef NK_TARGET_SAPPHIREAMX
453
+ #define NK_TARGET_SAPPHIREAMX 0
454
+ #endif
455
+ #endif // !defined(NK_TARGET_SAPPHIREAMX) || ...
456
+
457
+ #if !defined(NK_TARGET_GRANITEAMX) || (NK_TARGET_GRANITEAMX && !NK_TARGET_X86_)
458
+ #if (defined(__AMX_TILE__) && defined(__AMX_FP16__)) || (defined(_MSC_VER) && _MSC_VER >= 1944)
459
+ #define NK_TARGET_GRANITEAMX 1
460
+ #else
461
+ #undef NK_TARGET_GRANITEAMX
462
+ #define NK_TARGET_GRANITEAMX 0
463
+ #endif
464
+ #endif // !defined(NK_TARGET_GRANITEAMX) || ...
465
+
466
+ #if !defined(NK_TARGET_TURIN) || (NK_TARGET_TURIN && !NK_TARGET_X86_)
467
+ #if defined(__AVX512VP2INTERSECT__) || (defined(_MSC_VER) && _MSC_VER >= 1944)
468
+ #define NK_TARGET_TURIN 1
469
+ #else
470
+ #undef NK_TARGET_TURIN
471
+ #define NK_TARGET_TURIN 0
472
+ #endif
473
+ #endif // !defined(NK_TARGET_TURIN) || ...
474
+
475
+ #if !defined(NK_TARGET_ALDER) || (NK_TARGET_ALDER && !NK_TARGET_X86_)
476
+ #if defined(__AVXVNNI__) || (defined(_MSC_VER) && _MSC_VER >= 1920)
477
+ #define NK_TARGET_ALDER 1
478
+ #else
479
+ #undef NK_TARGET_ALDER
480
+ #define NK_TARGET_ALDER 0
481
+ #endif
482
+ #endif // !defined(NK_TARGET_ALDER) || ...
483
+
484
+ #if !defined(NK_TARGET_SIERRA) || (NK_TARGET_SIERRA && !NK_TARGET_X86_)
485
+ #if defined(__AVXVNNIINT8__) || (defined(_MSC_VER) && _MSC_VER >= 1944)
486
+ #define NK_TARGET_SIERRA 1
487
+ #else
488
+ #undef NK_TARGET_SIERRA
489
+ #define NK_TARGET_SIERRA 0
490
+ #endif
491
+ #endif // !defined(NK_TARGET_SIERRA) || ...
492
+
493
+ // Include the relevant intrinsics file - different for different OSes and ISAs
494
+ #if defined(_MSC_VER)
495
+ #include <intrin.h>
496
+ #elif NK_TARGET_ARM_
497
+ #if NK_TARGET_NEON
498
+ #include <arm_neon.h>
499
+ #endif
500
+ #if NK_TARGET_SVE || NK_TARGET_SVE2
501
+ #include <arm_sve.h>
502
+ #endif
503
+ #if NK_TARGET_SME || NK_TARGET_SME2 || NK_TARGET_SMEBI32
504
+ #include <arm_sme.h>
505
+ #endif
506
+ #elif NK_TARGET_HASWELL || NK_TARGET_SKYLAKE
507
+ #include <immintrin.h>
508
+ #elif NK_TARGET_RVV
509
+ #include <riscv_vector.h>
510
+ #elif NK_TARGET_V128RELAXED
511
+ #include <wasm_simd128.h>
512
+ #endif
513
+
514
+ #if !defined(NK_F64_DIVISION_EPSILON)
515
+ #define NK_F64_DIVISION_EPSILON (1e-15)
516
+ #endif
517
+
518
+ #if !defined(NK_F32_DIVISION_EPSILON)
519
+ #define NK_F32_DIVISION_EPSILON (1e-7)
520
+ #endif
521
+
522
+ #if !defined(NK_F16_DIVISION_EPSILON)
523
+ #define NK_F16_DIVISION_EPSILON (1e-3)
524
+ #endif
525
+
526
+ /**
527
+ * @brief The compile-time constant defining the capacity of `nk_tensor_position_t`.
528
+ * Matches `PyBUF_MAX_NDIM` by default.
529
+ */
530
+ #if !defined(NK_TENSOR_MAX_RANK)
531
+ #define NK_TENSOR_MAX_RANK (64)
532
+ #endif
533
+
534
/**
 *  @brief Aligns a variable to a 64-byte boundary.
 *
 *  Compiler extensions are preferred so the macro keeps working in C99 mode:
 *  `__declspec(align(64))` on MSVC and `__attribute__((aligned(64)))` on GCC/Clang.
 *  For any other toolchain the standard spelling is used when the language level
 *  provides one: `alignas(64)` in C++11 and newer, `_Alignas(64)` in C11 and newer.
 *  If none of the branches match, the macro is left undefined, exactly as before.
 *  Used internally and recommended for external users.
 */
#if defined(_MSC_VER)
#define NK_ALIGN64 __declspec(align(64))
#elif defined(__GNUC__) || defined(__clang__)
#define NK_ALIGN64 __attribute__((aligned(64)))
#elif defined(__cplusplus) && __cplusplus >= 201103L
#define NK_ALIGN64 alignas(64)
#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
#define NK_ALIGN64 _Alignas(64)
#endif
544
+
545
+ /**
546
+ * ARM Streaming attributes (require SME-capable compiler: GCC 14+, Clang 16+).
547
+ * NK_STREAMING_ marks functions that require streaming SVE mode (e.g. FCVTLT).
548
+ * NK_STREAMING_COMPATIBLE_ marks helpers callable from both streaming and non-streaming mode.
549
+ */
550
+ #if NK_TARGET_ARM_ && NK_TARGET_SME
551
+ #define NK_STREAMING_ __arm_streaming
552
+ #define NK_STREAMING_COMPATIBLE_ __arm_streaming_compatible
553
+ #else
554
+ #define NK_STREAMING_
555
+ #define NK_STREAMING_COMPATIBLE_
556
+ #endif
557
+
558
+ /**
559
+ * @brief Portable casts between SIMD vector types.
560
+ * MSVC typedefs `__m512bh`, `__m512h`, `__m256bh` as aliases for `__m512i`/`__m256i`,
561
+ * but rejects C-style casts between them. GCC/Clang define them as distinct types.
562
+ */
563
+ #if NK_TARGET_X86_
564
+ #if defined(_MSC_VER)
565
+ #define nk_m512bh_from_m512i_(x) (x)
566
+ #define nk_m512h_from_m512i_(x) (x)
567
+ #define nk_m512i_from_m512h_(x) (x)
568
+ #define nk_m256bh_from_m256i_(x) (x)
569
+ #define nk_m256i_from_m256bh_(x) (x)
570
+ #else
571
+ #define nk_m512bh_from_m512i_(x) ((__m512bh)(x))
572
+ #define nk_m512h_from_m512i_(x) ((__m512h)(x))
573
+ #define nk_m512i_from_m512h_(x) ((__m512i)(x))
574
+ #define nk_m256bh_from_m256i_(x) ((__m256bh)(x))
575
+ #define nk_m256i_from_m256bh_(x) ((__m256i)(x))
576
+ #endif
577
+ #endif
578
+
579
+ /** Copy `count` bytes from source to destination without relying on aligned or typed access */
580
+ #if defined(__GNUC__) || defined(__clang__)
581
+ #define nk_copy_bytes_(destination_ptr, source_ptr, count) __builtin_memcpy((destination_ptr), (source_ptr), count)
582
+ #else
583
+ #include <string.h> // `memcpy`
584
+ #define nk_copy_bytes_(destination_ptr, source_ptr, count) memcpy((destination_ptr), (source_ptr), count)
585
+ #endif
586
+
587
+ /** Macro to mark unused parameters (cleaner than (void)variable) */
588
+ #define nk_unused_(x) ((void)(x))
589
+
590
+ /**
591
+ * @brief C99 static array parameter annotation for minimum array size.
592
+ *
593
+ * In C, expands to `static n` enabling compiler bounds checking.
594
+ * In C++, expands to nothing as this syntax is not supported.
595
+ * @see https://lwn.net/Articles/1046840/
596
+ *
597
+ * Example usage:
598
+ * @code{.c}
599
+ * void hash_digest(uint8_t digest[nk_at_least_(32)]);
600
+ * void lookup(uint8_t const lut[nk_at_least_(256)]);
601
+ * @endcode
602
+ */
603
+ #if defined(__cplusplus) || defined(_MSC_VER)
604
+ #define nk_at_least_(n)
605
+ #else
606
+ #define nk_at_least_(n) static n
607
+ #endif
608
+
609
+ #ifdef __cplusplus
610
+ extern "C" {
611
+ #endif
612
+
613
+ /** @brief Packed 8-bit bit-vector (8 booleans in one byte), LSB = dimension 0.
614
+ * Used for Hamming distance and Jaccard similarity via popcount.
615
+ * Dimension count must be a multiple of 8; unused bits in the final byte must be zeroed. */
616
+ typedef unsigned char nk_u1x8_t;
617
+ /** @brief Packed 4-bit signed integer pair (2 × i4 in one byte), [high nibble : low nibble].
618
+ * Range per element: [−8, +7]. Elements sign-extended to i8 for arithmetic.
619
+ * Dimension count must be a multiple of 2; unused nibbles in the final byte must be zeroed. */
620
+ typedef unsigned char nk_i4x2_t;
621
+ /** @brief Packed 4-bit unsigned integer pair (2 × u4 in one byte), [high nibble : low nibble].
622
+ * Range per element: [0, 15]. Elements zero-extended to u8 for arithmetic.
623
+ * Dimension count must be a multiple of 2; unused nibbles in the final byte must be zeroed. */
624
+ typedef unsigned char nk_u4x2_t;
625
+
626
+ /** @brief 8-bit E4M3 float (OCP FP8): sign(1) + exponent(4) + mantissa(3), bias=7.
627
+ * Range: ±448, no infinities (all-ones exponent → NaN at 0x7F/0xFF).
628
+ * 114 of 254 finite values (44.9%) fall in [−1, +1]. */
629
+ typedef unsigned char nk_e4m3_t;
630
+ /** @brief 8-bit E5M2 float (OCP FP8): sign(1) + exponent(5) + mantissa(2), bias=15.
631
+ * Range: ±57 344, supports infinities at 0x7C/0xFC.
632
+ * 122 of 248 finite values (49.2%) fall in [−1, +1]. */
633
+ typedef unsigned char nk_e5m2_t;
634
+ /** @brief 6-bit E2M3 micro-float (OCP MX v1.0): sign(1) + exponent(2) + mantissa(3), bias=1.
635
+ * Range: ±7.5, no infinities or NaN. Only 64 total codes; 18 (28.1%) fall in [−1, +1]. */
636
+ typedef unsigned char nk_e2m3_t;
637
+ /** @brief 6-bit E3M2 micro-float (OCP MX v1.0): sign(1) + exponent(3) + mantissa(2), bias=3.
638
+ * Range: ±28, no infinities or NaN. Only 64 total codes; 26 (40.6%) fall in [−1, +1]. */
639
+ typedef unsigned char nk_e3m2_t;
640
+
641
+ /** @brief Signed 8-bit integer. Range: [−128, +127]. */
642
+ typedef signed char nk_i8_t;
643
+ /** @brief Unsigned 8-bit integer. Range: [0, 255]. */
644
+ typedef unsigned char nk_u8_t;
645
+ /** @brief Signed 16-bit integer. Range: [−32 768, +32 767]. */
646
+ typedef signed short nk_i16_t;
647
+ /** @brief Unsigned 16-bit integer. Range: [0, 65 535]. */
648
+ typedef unsigned short nk_u16_t;
649
+ /** @brief Signed 32-bit integer. Range: [−2³¹, +2³¹−1]. */
650
+ typedef signed int nk_i32_t;
651
+ /** @brief Unsigned 32-bit integer. Range: [0, 2³²−1]. */
652
+ typedef unsigned int nk_u32_t;
653
+ /* On LP64 targets (Linux ARM64, RISC-V 64), `long` and `long long` are both 64-bit but distinct types.
654
+ * NEON/RVV intrinsics on Linux expect `long*`, while Apple's NEON intrinsics expect `long long*`.
655
+ * Windows uses LLP64 where `long` is 32-bit, so it must use `long long` for 64-bit types. */
656
+ #if ((NK_TARGET_ARM_ && !defined(NK_DEFINED_APPLE_)) || NK_TARGET_RISCV_) && !defined(NK_DEFINED_WINDOWS_)
657
+ /** @brief Signed 64-bit integer. Range: [−2⁶³, +2⁶³−1]. */
658
+ typedef signed long nk_i64_t;
659
+ /** @brief Unsigned 64-bit integer. Range: [0, 2⁶⁴−1]. */
660
+ typedef unsigned long nk_u64_t;
661
+ #else
662
+ /** @brief Signed 64-bit integer. Range: [−2⁶³, +2⁶³−1]. */
663
+ typedef signed long long nk_i64_t;
664
+ /** @brief Unsigned 64-bit integer. Range: [0, 2⁶⁴−1]. */
665
+ typedef unsigned long long nk_u64_t;
666
+ #endif
667
+
668
+ /** @brief Single-precision (32-bit) IEEE 754 float. sign(1) + exponent(8) + mantissa(23), bias=127. */
669
+ typedef float nk_f32_t;
670
+ /** @brief Double-precision (64-bit) IEEE 754 float. sign(1) + exponent(11) + mantissa(52), bias=1023. */
671
+ typedef double nk_f64_t;
672
+
673
+ #if NK_TARGET_X86_ || NK_TARGET_ARM_ || NK_TARGET_RISCV_
674
+ #define NK_IS_64BIT_ 1
675
+ #else
676
+ #define NK_IS_64BIT_ 0
677
+ #endif
678
+
679
+ #if NK_IS_64BIT_
680
+ typedef nk_u64_t nk_size_t;
681
+ typedef nk_i64_t nk_ssize_t;
682
+ #else
683
+ typedef nk_u32_t nk_size_t;
684
+ typedef nk_i32_t nk_ssize_t;
685
+ #endif
686
+ typedef nk_f64_t nk_fmax_t;
687
+
688
+ #define NK_SIZE_MAX ((nk_size_t) - 1)
689
+
690
+ #define NK_F64_MAX 1.7976931348623157e+308
691
+ #define NK_F64_MIN (-1.7976931348623157e+308)
692
+ #define NK_F32_MAX 3.402823466e+38f
693
+ #define NK_F32_MIN (-3.402823466e+38f)
694
+
695
+ #define NK_I64_MAX 9223372036854775807LL
696
+ #define NK_I64_MIN (-9223372036854775807LL - 1LL)
697
+ #define NK_U64_MAX 18446744073709551615ULL
698
+ #define NK_U64_MIN 0x0ULL
699
+
700
+ #define NK_I32_MAX 2147483647
701
+ #define NK_I32_MIN (-2147483647 - 1)
702
+ #define NK_U32_MAX 4294967295U
703
+ #define NK_U32_MIN 0x0U
704
+
705
+ #define NK_I16_MAX 32767
706
+ #define NK_I16_MIN (-32767 - 1)
707
+ #define NK_U16_MAX 65535U
708
+ #define NK_U16_MIN 0x0U
709
+
710
+ #define NK_I8_MAX 127
711
+ #define NK_I8_MIN (-127 - 1)
712
+ #define NK_U8_MAX 255U
713
+ #define NK_U8_MIN 0x0U
714
+
715
+ #define NK_F16_MAX 0x7BFF // IEEE 754 binary16: +65504.0
716
+ #define NK_F16_MIN 0xFBFF // IEEE 754 binary16: -65504.0
717
+
718
+ #define NK_BF16_MAX 0x7F7F // BFloat16: ~+3.39e38
719
+ #define NK_BF16_MIN 0xFF7F // BFloat16: ~-3.39e38
720
+
721
+ #define NK_E4M3_MAX 0x7E // FP8 E4M3: +448.0
722
+ #define NK_E4M3_MIN 0xFE // FP8 E4M3: -448.0
723
+
724
+ #define NK_E5M2_MAX 0x7B // FP8 E5M2: +57344.0
725
+ #define NK_E5M2_MIN 0xFB // FP8 E5M2: -57344.0
726
+
727
+ #define NK_E2M3_MAX 0x1F // FP6 E2M3: +7.5
728
+ #define NK_E2M3_MIN 0x3F // FP6 E2M3: -7.5
729
+
730
+ #define NK_E3M2_MAX 0x1F // FP6 E3M2: +28.0
731
+ #define NK_E3M2_MIN 0x3F // FP6 E3M2: -28.0
732
+
733
+ #define NK_BITS_PER_BYTE 8
734
+
735
+ /**
736
+ * @brief Enumeration of supported scalar data types.
737
+ *
738
+ * Includes complex type descriptors which in C code would use the real counterparts,
739
+ * but the independent flags contain metadata to be passed between programming language
740
+ * interfaces.
741
+ */
742
+ typedef enum {
743
+ nk_dtype_unknown_k = 0, ///< Unknown data type
744
+ nk_u1_k = 1 << 1, ///< Single-bit values packed into 8-bit words
745
+
746
+ nk_i8_k = 1 << 2, ///< 8-bit signed integer
747
+ nk_i16_k = 1 << 3, ///< 16-bit signed integer
748
+ nk_i32_k = 1 << 4, ///< 32-bit signed integer
749
+ nk_i64_k = 1 << 5, ///< 64-bit signed integer
750
+
751
+ nk_u8_k = 1 << 6, ///< 8-bit unsigned integer
752
+ nk_u16_k = 1 << 7, ///< 16-bit unsigned integer
753
+ nk_u32_k = 1 << 8, ///< 32-bit unsigned integer
754
+ nk_u64_k = 1 << 9, ///< 64-bit unsigned integer
755
+
756
+ nk_f64_k = 1 << 10, ///< Double precision floating point
757
+ nk_f32_k = 1 << 11, ///< Single precision floating point
758
+ nk_f16_k = 1 << 12, ///< Half precision floating point
759
+ nk_bf16_k = 1 << 13, ///< Brain floating point
760
+
761
+ nk_e4m3_k = 1 << 14, ///< FP8 E4M3 floating point
762
+ nk_e5m2_k = 1 << 15, ///< FP8 E5M2 floating point
763
+ nk_i4_k = 1 << 16, ///< 4-bit signed integers packed into 8-bit words
764
+ nk_u4_k = 1 << 17, ///< 4-bit unsigned integers packed into 8-bit words
765
+ nk_e2m3_k = 1 << 18, ///< FP6 E2M3 floating point
766
+ nk_e3m2_k = 1 << 19, ///< FP6 E3M2 floating point
767
+
768
+ nk_f64c_k = 1 << 20, ///< Complex double precision floating point
769
+ nk_f32c_k = 1 << 21, ///< Complex single precision floating point
770
+ nk_f16c_k = 1 << 22, ///< Complex half precision floating point
771
+ nk_bf16c_k = 1 << 23, ///< Complex brain floating point
772
+ } nk_dtype_t;
773
+
774
+ typedef enum {
775
+ nk_dtype_family_unknown_k = 0,
776
+ nk_dtype_family_float_k,
777
+ nk_dtype_family_complex_float_k,
778
+ nk_dtype_family_int_k,
779
+ nk_dtype_family_uint_k,
780
+ } nk_dtype_family_t;
781
+
782
+ /** @brief Classifies the family of the dtype. */
783
+ NK_PUBLIC nk_dtype_family_t nk_dtype_family(nk_dtype_t dtype) {
784
+ switch (dtype) {
785
+ case nk_f64_k: return nk_dtype_family_float_k;
786
+ case nk_f32_k: return nk_dtype_family_float_k;
787
+ case nk_f16_k: return nk_dtype_family_float_k;
788
+ case nk_bf16_k: return nk_dtype_family_float_k;
789
+ case nk_e4m3_k: return nk_dtype_family_float_k;
790
+ case nk_e5m2_k: return nk_dtype_family_float_k;
791
+ case nk_e2m3_k: return nk_dtype_family_float_k;
792
+ case nk_e3m2_k: return nk_dtype_family_float_k;
793
+ case nk_f64c_k: return nk_dtype_family_complex_float_k;
794
+ case nk_f32c_k: return nk_dtype_family_complex_float_k;
795
+ case nk_f16c_k: return nk_dtype_family_complex_float_k;
796
+ case nk_bf16c_k: return nk_dtype_family_complex_float_k;
797
+ case nk_u1_k: return nk_dtype_family_uint_k;
798
+ case nk_u4_k: return nk_dtype_family_uint_k;
799
+ case nk_u8_k: return nk_dtype_family_uint_k;
800
+ case nk_u16_k: return nk_dtype_family_uint_k;
801
+ case nk_u32_k: return nk_dtype_family_uint_k;
802
+ case nk_u64_k: return nk_dtype_family_uint_k;
803
+ case nk_i4_k: return nk_dtype_family_int_k;
804
+ case nk_i8_k: return nk_dtype_family_int_k;
805
+ case nk_i16_k: return nk_dtype_family_int_k;
806
+ case nk_i32_k: return nk_dtype_family_int_k;
807
+ case nk_i64_k: return nk_dtype_family_int_k;
808
+ default: return nk_dtype_family_unknown_k;
809
+ }
810
+ }
811
+
812
+ /** @brief Returns the number of bits in a single scalar of a given type. */
813
+ NK_PUBLIC nk_size_t nk_dtype_bits(nk_dtype_t dtype) {
814
+ switch (dtype) {
815
+ case nk_f64_k: return 64;
816
+ case nk_f32_k: return 32;
817
+ case nk_f16_k: return 16;
818
+ case nk_bf16_k: return 16;
819
+ case nk_e4m3_k: return 8;
820
+ case nk_e5m2_k: return 8;
821
+ case nk_e2m3_k: return 8;
822
+ case nk_e3m2_k: return 8;
823
+ case nk_f64c_k: return 128;
824
+ case nk_f32c_k: return 64;
825
+ case nk_f16c_k: return 32;
826
+ case nk_bf16c_k: return 32;
827
+ case nk_u1_k: return 1;
828
+ case nk_u4_k: return 4;
829
+ case nk_u8_k: return 8;
830
+ case nk_u16_k: return 16;
831
+ case nk_u32_k: return 32;
832
+ case nk_u64_k: return 64;
833
+ case nk_i4_k: return 4;
834
+ case nk_i8_k: return 8;
835
+ case nk_i16_k: return 16;
836
+ case nk_i32_k: return 32;
837
+ case nk_i64_k: return 64;
838
+ default: return 0;
839
+ }
840
+ }
841
+
842
+ /** @brief Returns how many logical dimensions are packed into one storage value.
843
+ * For sub-byte types multiple dimensions share a single byte container.
844
+ * For byte-or-larger types this is always 1. */
845
+ NK_PUBLIC nk_size_t nk_dtype_dimensions_per_value(nk_dtype_t dtype) {
846
+ switch (dtype) {
847
+ case nk_u1_k: return 8;
848
+ case nk_i4_k: return 2;
849
+ case nk_u4_k: return 2;
850
+ default: return 1;
851
+ }
852
+ }
853
+
854
+ /** @brief Half-precision (16-bit) IEEE 754 float.
855
+ *
856
+ * Layout: sign(1) + exponent(5) + mantissa(10), bias=15.
857
+ * Range: ±65 504, epsilon at 1.0 ≈ 9.77×10⁻⁴. 30 722 of 63 488 finite values (48.4%) in [−1, +1].
858
+ *
859
+ * - GCC or Clang on 64-bit Arm: `__fp16`, may require `-mfp16-format` option.
860
+ * - GCC or Clang on 64-bit x86: `_Float16`.
861
+ * - Default: `unsigned short`.
862
+ */
863
+ #if !defined(NK_NATIVE_F16) || NK_NATIVE_F16
864
+ #if (defined(__GNUC__) || defined(__clang__)) && (defined(__ARM_ARCH) || defined(__aarch64__)) && \
865
+ (defined(__ARM_FP16_FORMAT_IEEE))
866
+ #undef NK_NATIVE_F16
867
+ #define NK_NATIVE_F16 1
868
+ typedef __fp16 nk_f16_t;
869
+ #elif ((defined(__GNUC__) || defined(__clang__)) && (defined(__x86_64__) || defined(__i386__)) && \
870
+ (defined(__AVX512FP16__)))
871
+ typedef _Float16 nk_f16_t;
872
+ #undef NK_NATIVE_F16
873
+ #define NK_NATIVE_F16 1
874
+ #else // Unknown compiler or architecture
875
+ #undef NK_NATIVE_F16
876
+ #define NK_NATIVE_F16 0
877
+ #endif // Unknown compiler or architecture
878
+ #endif // !NK_NATIVE_F16
879
+
880
+ #if !NK_NATIVE_F16
881
+ typedef unsigned short nk_f16_t;
882
+ #endif
883
+
884
+ #if !defined(NK_NATIVE_BF16) || NK_NATIVE_BF16
885
+ /** @brief BFloat16 (16-bit) float — truncated IEEE 754 single-precision.
886
+ *
887
+ * Layout: sign(1) + exponent(8) + mantissa(7), bias=127.
888
+ * Same dynamic range as f32, epsilon ≈ 7.81×10⁻³.
889
+ * 32 514 of 65 280 finite values (49.8%) in [−1, +1]. Wider range than f16 but lower precision.
890
+ *
891
+ * - GCC or Clang: `__bf16`
892
+ * - Default: `unsigned short`.
893
+ *
894
+ * The compilers have added `__bf16` support in compliance with the x86-64 psABI spec.
895
+ * The motivation for this new special type is summed up as:
896
+ *
897
+ * Currently `__bfloat16` is a typedef of short, which creates a problem where the
898
+ * compiler does not raise any alarms if it is used to add, subtract, multiply or
899
+ * divide, but the result of the calculation is actually meaningless.
900
+ * To solve this problem, a real scalar type `__Bfloat16` needs to be introduced.
901
+ * It is mainly used for intrinsics, not available for C standard operators.
902
+ * `__Bfloat16` will also be used for movement like passing parameter, load and store,
903
+ * vector initialization, vector shuffle, and etc. It creates a need for a
904
+ * corresponding psABI.
905
+ *
906
+ * @warning Apple Clang has a hard time with bf16.
907
+ * https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
908
+ * https://forums.developer.apple.com/forums/thread/726201
909
+ * https://www.phoronix.com/news/GCC-LLVM-bf16-BFloat16-Type
910
+ */
911
+ #if (defined(__GNUC__) || defined(__clang__)) && ((defined(__ARM_BF16_FORMAT_ALTERNATIVE)) || (defined(__AVX512BF16__)))
912
+ #undef NK_NATIVE_BF16
913
+ #define NK_NATIVE_BF16 1
914
+ typedef __bf16 nk_bf16_t;
915
+ #else // Unknown compiler or architecture
916
+ #undef NK_NATIVE_BF16
917
+ #define NK_NATIVE_BF16 0
918
+ #endif // Unknown compiler or architecture
919
+ #endif // !NK_NATIVE_BF16
920
+
921
+ #if !NK_NATIVE_BF16
922
+ typedef unsigned short nk_bf16_t;
923
+ #endif
924
+
925
+ /**
926
+ * @brief Alias for the half-precision floating-point type on Arm.
927
+ *
928
+ * Clang and GCC bring the `float16_t` symbol when you compile for Aarch64.
929
+ * MSVC lacks it, and its `vld1_f16`-like intrinsics are in reality macros,
930
+ * that cast to 16-bit integers internally, instead of using floats.
931
+ * Some of those are defined as aliases, so we use `#define` preprocessor
932
+ * directives instead of `typedef` to avoid errors.
933
+ */
934
+ #if NK_TARGET_ARM_
935
+ #if defined(_MSC_VER)
936
+ #define nk_f16_for_arm_simd_t nk_f16_t
937
+ #define nk_bf16_for_arm_simd_t nk_bf16_t
938
+ #else
939
+ #define nk_f16_for_arm_simd_t float16_t
940
+ #define nk_bf16_for_arm_simd_t bfloat16_t
941
+ #endif
942
+ #endif
943
+
944
+ /**
945
+ * RISC-V Vector (RVV) intrinsics use `_Float16` for half-precision floats.
946
+ * This is the standard C23 type, also available in GCC/Clang with RVV extensions.
947
+ */
948
+ #if NK_TARGET_RISCV_
949
+ #define nk_f16_for_rvv_intrinsics_t _Float16
950
+ #endif
951
+
952
+ /*
953
+ * Let's make sure the sizes of the types are as expected.
954
+ * In C the `_Static_assert` is only available with C11 and later.
955
+ */
956
+ #define NK_STATIC_ASSERT(cond, msg) typedef char static_assertion_##msg[(cond) ? 1 : -1]
957
+ NK_STATIC_ASSERT(sizeof(nk_u1x8_t) == 1, nk_u1x8_t_must_be_1_byte);
958
+ NK_STATIC_ASSERT(sizeof(nk_i4x2_t) == 1, nk_i4_t_must_be_1_byte);
959
+ NK_STATIC_ASSERT(sizeof(nk_u4x2_t) == 1, nk_u4_t_must_be_1_byte);
960
+ NK_STATIC_ASSERT(sizeof(nk_e4m3_t) == 1, nk_e4m3_t_must_be_1_byte);
961
+ NK_STATIC_ASSERT(sizeof(nk_e5m2_t) == 1, nk_e5m2_t_must_be_1_byte);
962
+ NK_STATIC_ASSERT(sizeof(nk_i8_t) == 1, nk_i8_t_must_be_1_byte);
963
+ NK_STATIC_ASSERT(sizeof(nk_u8_t) == 1, nk_u8_t_must_be_1_byte);
964
+ NK_STATIC_ASSERT(sizeof(nk_i16_t) == 2, nk_i16_t_must_be_2_bytes);
965
+ NK_STATIC_ASSERT(sizeof(nk_u16_t) == 2, nk_u16_t_must_be_2_bytes);
966
+ NK_STATIC_ASSERT(sizeof(nk_i32_t) == 4, nk_i32_t_must_be_4_bytes);
967
+ NK_STATIC_ASSERT(sizeof(nk_u32_t) == 4, nk_u32_t_must_be_4_bytes);
968
+ NK_STATIC_ASSERT(sizeof(nk_i64_t) == 8, nk_i64_t_must_be_8_bytes);
969
+ NK_STATIC_ASSERT(sizeof(nk_u64_t) == 8, nk_u64_t_must_be_8_bytes);
970
+ NK_STATIC_ASSERT(sizeof(nk_f32_t) == 4, nk_f32_t_must_be_4_bytes);
971
+ NK_STATIC_ASSERT(sizeof(nk_f64_t) == 8, nk_f64_t_must_be_8_bytes);
972
+ NK_STATIC_ASSERT(sizeof(nk_f16_t) == 2, nk_f16_t_must_be_2_bytes);
973
+ NK_STATIC_ASSERT(sizeof(nk_bf16_t) == 2, nk_bf16_t_must_be_2_bytes);
974
+
975
+ #define nk_assign_from_to_(src, dest) (*(dest) = *(src))
976
+
977
+ /** @brief 16-bit union for f16/bf16/u16/i16 bit manipulation. */
978
+ typedef union {
979
+ nk_u16_t u;
980
+ nk_i16_t i;
981
+ nk_f16_t f;
982
+ nk_bf16_t bf;
983
+ } nk_fui16_t;
984
+
985
+ /** @brief 32-bit union for f32/u32/i32 bit manipulation. */
986
+ typedef union {
987
+ nk_u32_t u;
988
+ nk_i32_t i;
989
+ nk_f32_t f;
990
+ } nk_fui32_t;
991
+
992
+ /** @brief 64-bit union for f64/u64/i64 bit manipulation. */
993
+ typedef union {
994
+ nk_u64_t u;
995
+ nk_i64_t i;
996
+ nk_f64_t f;
997
+ } nk_fui64_t;
998
+
999
+ /** @brief Half-precision (32-bit) complex number — {real: f16, imag: f16}. Kernel outputs widened to f32c. */
1000
+ typedef struct {
1001
+ nk_f16_t real;
1002
+ nk_f16_t imag;
1003
+ } nk_f16c_t;
1004
+
1005
+ /** @brief BFloat16 (32-bit) complex number — {real: bf16, imag: bf16}. Kernel outputs widened to f32c. */
1006
+ typedef struct {
1007
+ nk_bf16_t real;
1008
+ nk_bf16_t imag;
1009
+ } nk_bf16c_t;
1010
+
1011
+ /** @brief Single-precision (64-bit) complex number — {real: f32, imag: f32}. */
1012
+ typedef struct {
1013
+ nk_f32_t real;
1014
+ nk_f32_t imag;
1015
+ } nk_f32c_t;
1016
+
1017
+ /** @brief Double-precision (128-bit) complex number — {real: f64, imag: f64}. */
1018
+ typedef struct {
1019
+ nk_f64_t real;
1020
+ nk_f64_t imag;
1021
+ } nk_f64c_t;
1022
+
1023
+ /** @brief Small 4-byte memory slice viewable as different types. */
1024
+ typedef union nk_b32_vec_t {
1025
+ nk_u32_t u32;
1026
+ nk_i32_t i32;
1027
+ nk_f32_t f32;
1028
+ nk_u8_t u8s[4];
1029
+ nk_i8_t i8s[4];
1030
+ nk_u16_t u16s[2];
1031
+ nk_i16_t i16s[2];
1032
+ nk_e4m3_t e4m3s[4];
1033
+ nk_e5m2_t e5m2s[4];
1034
+ } nk_b32_vec_t;
1035
+
1036
+ /** @brief Small 8-byte memory slice viewable as different types. */
1037
+ typedef union nk_b64_vec_t {
1038
+ #if NK_TARGET_NEON
1039
+ uint8x8_t u8x8;
1040
+ uint16x4_t u16x4;
1041
+ uint32x2_t u32x2;
1042
+ int8x8_t i8x8;
1043
+ int16x4_t i16x4;
1044
+ int32x2_t i32x2;
1045
+ float32x2_t f32x2;
1046
+ #endif
1047
+ #if NK_TARGET_NEONHALF
1048
+ float16x4_t f16x4;
1049
+ #endif
1050
+ nk_u8_t u8s[8];
1051
+ nk_u16_t u16s[4];
1052
+ nk_u32_t u32s[2];
1053
+ nk_u64_t u64;
1054
+ nk_i8_t i8s[8];
1055
+ nk_i16_t i16s[4];
1056
+ nk_i32_t i32s[2];
1057
+ nk_i64_t i64;
1058
+ nk_f16_t f16s[4];
1059
+ nk_bf16_t bf16s[4];
1060
+ nk_f32_t f32s[2];
1061
+ } nk_b64_vec_t;
1062
+
1063
+ /** @brief Small 16-byte memory slice viewable as different types. */
1064
+ typedef union nk_b128_vec_t {
1065
+ #if NK_TARGET_HASWELL
1066
+ __m128i xmm;
1067
+ __m128d xmm_pd;
1068
+ __m128 xmm_ps;
1069
+ #endif
1070
+ #if NK_TARGET_V128RELAXED
1071
+ v128_t v128;
1072
+ #endif
1073
+ #if NK_TARGET_NEON
1074
+ uint8x16_t u8x16;
1075
+ uint16x8_t u16x8;
1076
+ uint32x4_t u32x4;
1077
+ uint64x2_t u64x2;
1078
+ int8x16_t i8x16;
1079
+ int16x8_t i16x8;
1080
+ int32x4_t i32x4;
1081
+ int64x2_t i64x2;
1082
+ float32x4_t f32x4;
1083
+ float64x2_t f64x2;
1084
+ #endif
1085
+ nk_u8_t u8s[16];
1086
+ nk_u16_t u16s[8];
1087
+ nk_u32_t u32s[4];
1088
+ nk_u64_t u64s[2];
1089
+ nk_i8_t i8s[16];
1090
+ nk_i16_t i16s[8];
1091
+ nk_i32_t i32s[4];
1092
+ nk_i64_t i64s[2];
1093
+ nk_f16_t f16s[8];
1094
+ nk_bf16_t bf16s[8];
1095
+ nk_e4m3_t e4m3s[16];
1096
+ nk_e5m2_t e5m2s[16];
1097
+ nk_e2m3_t e2m3s[16];
1098
+ nk_e3m2_t e3m2s[16];
1099
+ nk_f32_t f32s[4];
1100
+ nk_f64_t f64s[2];
1101
+ } nk_b128_vec_t;
1102
+
1103
+ /** @brief Small 32-byte memory slice viewable as different types. */
1104
+ typedef union nk_b256_vec_t {
1105
+ #if NK_TARGET_HASWELL
1106
+ __m256i ymm;
1107
+ __m256d ymm_pd;
1108
+ __m256 ymm_ps;
1109
+ __m128i xmms[2];
1110
+ #endif
1111
+ #if NK_TARGET_V128RELAXED
1112
+ v128_t v128s[2];
1113
+ #endif
1114
+ #if NK_TARGET_NEON
1115
+ uint8x16_t u8x16s[2];
1116
+ uint16x8_t u16x8s[2];
1117
+ uint32x4_t u32x4s[2];
1118
+ uint64x2_t u64x2s[2];
1119
+ int8x16_t i8x16s[2];
1120
+ int16x8_t i16x8s[2];
1121
+ int32x4_t i32x4s[2];
1122
+ int64x2_t i64x2s[2];
1123
+ float32x4_t f32x4s[2];
1124
+ float64x2_t f64x2s[2];
1125
+ #endif
1126
+ nk_u8_t u8s[32];
1127
+ nk_u16_t u16s[16];
1128
+ nk_u32_t u32s[8];
1129
+ nk_u64_t u64s[4];
1130
+ nk_i8_t i8s[32];
1131
+ nk_i16_t i16s[16];
1132
+ nk_i32_t i32s[8];
1133
+ nk_i64_t i64s[4];
1134
+ nk_f16_t f16s[16];
1135
+ nk_bf16_t bf16s[16];
1136
+ nk_e4m3_t e4m3s[32];
1137
+ nk_e5m2_t e5m2s[32];
1138
+ nk_e2m3_t e2m3s[32];
1139
+ nk_e3m2_t e3m2s[32];
1140
+ nk_f32_t f32s[8];
1141
+ nk_f64_t f64s[4];
1142
+ } nk_b256_vec_t;
1143
+
1144
+ /** @brief Small 64-byte memory slice viewable as different types.
1145
+ *
1146
+ * TODO: On GCC and Clang we use `__transparent_union__` attribute to allow implicit conversions
1147
+ * between the different vector types when passing them as function arguments. The most important side-effect
1148
+ * of this is that the argument of such type is passed to functions using the calling convention of the first
1149
+ * member of the union, which in our case is a register-based calling convention for SIMD types.
1150
+ */
1151
+ typedef union nk_b512_vec_t {
1152
+ #if NK_TARGET_SKYLAKE
1153
+ __m512i zmm;
1154
+ __m512d zmm_pd;
1155
+ __m512 zmm_ps;
1156
+ #endif
1157
+ #if NK_TARGET_HASWELL
1158
+ __m256i ymms[2];
1159
+ __m256d ymms_pd[2];
1160
+ __m256 ymms_ps[2];
1161
+ __m128i xmms[4];
1162
+ __m128d xmms_pd[4];
1163
+ __m128 xmms_ps[4];
1164
+ #endif
1165
+ #if NK_TARGET_NEON
1166
+ uint8x16_t u8x16s[4];
1167
+ uint16x8_t u16x8s[4];
1168
+ uint32x4_t u32x4s[4];
1169
+ uint64x2_t u64x2s[4];
1170
+ #endif
1171
+ nk_u8_t u8s[64];
1172
+ nk_u16_t u16s[32];
1173
+ nk_u32_t u32s[16];
1174
+ nk_u64_t u64s[8];
1175
+ nk_i8_t i8s[64];
1176
+ nk_i16_t i16s[32];
1177
+ nk_i32_t i32s[16];
1178
+ nk_i64_t i64s[8];
1179
+ nk_f16_t f16s[32];
1180
+ nk_bf16_t bf16s[32];
1181
+ nk_f32_t f32s[16];
1182
+ nk_f64_t f64s[8];
1183
+ nk_e4m3_t e4m3s[64];
1184
+ nk_e5m2_t e5m2s[64];
1185
+ nk_e2m3_t e2m3s[64];
1186
+ nk_e3m2_t e3m2s[64];
1187
+ } nk_b512_vec_t;
1188
+
1189
+ /**
1190
+ * @brief Advances the Multi-Dimensional iterator to the next set of indicies.
1191
+ * @param[in] extents The extents of the tensor, defined by an array of at least `rank` scalars.
1192
+ * @param[in] strides The @b signed strides of the tensor in bytes, defined by an array of at least `rank` scalars.
1193
+ * @param[in] rank The number of dimensions in the tensor (its rank).
1194
+ * @param[inout] coordinates The array of offsets along each of `rank` dimensions, which will be updated.
1195
+ * @param[inout] byte_offset The @b signed byte offset of the current element, which will be advanced.
1196
+ * @return 1 if the iterator was successfully advanced, 0 if the end of iteration was reached.
1197
+ *
1198
+ * For flexibility, the API is decoupled from from the `nk_tensor_position_t` structure, and
1199
+ * can be used on any-rank tensors, independent of the `NK_TENSOR_MAX_RANK` constant.
1200
+ */
1201
+ NK_PUBLIC int nk_tensor_position_next( //
1202
+ nk_size_t const *extents, nk_ssize_t const *strides, nk_size_t rank, //
1203
+ nk_size_t *coordinates, nk_ssize_t *byte_offset) {
1204
+ // Start from last dimension and move backward
1205
+ for (nk_size_t i = rank; i-- > 0;) {
1206
+ coordinates[i]++;
1207
+ *byte_offset += strides[i];
1208
+ if (coordinates[i] < extents[i]) return 1; // Successfully moved to the next index
1209
+ coordinates[i] = 0; // Reset this dimension counter
1210
+ *byte_offset -= strides[i] * extents[i]; // Discard the running progress along this dimension
1211
+ }
1212
+ // If we reach here, we've iterated over all elements
1213
+ return 0; // End of iteration
1214
+ }
1215
+
1216
+ /**
1217
+ * @brief Advances the Multi-Dimensional iterator to the provided coordinates, updating the byte offset.
1218
+ * @param[in] extents The extents of the tensor, defined by an array of at least `rank` scalars.
1219
+ * @param[in] strides The @b signed strides of the tensor in bytes, defined by an array of at least `rank` scalars.
1220
+ * @param[in] rank The number of dimensions in the tensor (its rank).
1221
+ * @param[in] coordinates The array of offsets along each of `rank` dimensions, which will be updated.
1222
+ * @param[out] byte_offset The byte offset of the current element, which will be advanced.
1223
+ * @return 1 if the offset was successfully advanced, 0 if the end of iteration was reached.
1224
+ */
1225
+ NK_PUBLIC int nk_tensor_position_linearize( //
1226
+ nk_size_t const *extents, nk_ssize_t const *strides, nk_size_t rank, //
1227
+ nk_size_t const *coordinates, nk_ssize_t *byte_offset) {
1228
+
1229
+ nk_ssize_t result = 0;
1230
+ for (nk_size_t i = 0; i < rank; i++) {
1231
+ // Ensure the coordinates is within bounds for the given dimension
1232
+ if (coordinates[i] >= extents[i]) return 0; // Invalid coordinates, out of bounds
1233
+ // Update the byte offset by multiplying the coordinates by the stride
1234
+ result += coordinates[i] * strides[i];
1235
+ }
1236
+ *byte_offset = result;
1237
+ return 1; // Successfully calculated global and byte offsets
1238
+ }
1239
+
1240
+ /**
1241
+ * @brief A @b beefy structure to iterate through Multi-Dimensional arrays.
1242
+ * Occupies 512 + 8 = 520 bytes on a 64-bit machine, or @b 9 cache-lines, by default.
1243
+ *
1244
+ * When advancing through a structure, its overall size and strides should be stored somewhere else.
1245
+ * The `byte_offset` starts at zero and grow monotonically during iteration, if the strides are positive.
1246
+ */
1247
+ typedef struct nk_tensor_position_t {
1248
+ nk_size_t coordinates[NK_TENSOR_MAX_RANK]; // Coordinate offsets along each dimension
1249
+ nk_ssize_t byte_offset; // Byte offset of the current element
1250
+ } nk_tensor_position_t;
1251
+
1252
+ NK_PUBLIC void nk_tensor_position_init(nk_tensor_position_t *tensor_position) {
1253
+ for (nk_size_t i = 0; i < NK_TENSOR_MAX_RANK; i++) tensor_position->coordinates[i] = 0;
1254
+ tensor_position->byte_offset = 0;
1255
+ }
1256
+
1257
+ /**
1258
+ * @brief A @b beefy structure describing the shape and memory layout of a Multi-Dimensional array.
1259
+ * Similar to `md::span` in C++20 and `numpy.ndarray` in Python, but with a focus on compatibility.
1260
+ * Occupies 512 + 512 + 8 = 2052 bytes on a 64-bit machine, or @b 17 cache-lines, by default.
1261
+ *
1262
+ * Unlike NumPy and the CPython "Buffer Protocol", we don't use `suboffsets` for pointer indirection.
1263
+ * The logic is that such layouts aren't friendly to conventional SIMD operations and dense matrix algorithms.
1264
+ * If the tensor is sparse, consider using a different data structure or a different memory layout.
1265
+ *
1266
+ * Most NumKong algorithms don't work with the entire structure, but expect the fields to be passed separately.
1267
+ * It would also require storing the @b start-pointer and the @b dtype/item-size separately, as it's not
1268
+ * stored inside the structure.
1269
+ */
1270
+ typedef struct nk_tensor_shape_t {
1271
+ nk_size_t extents[NK_TENSOR_MAX_RANK]; /// Number of elements along each dimension
1272
+ nk_ssize_t strides[NK_TENSOR_MAX_RANK]; /// Strides of the tensor in bytes
1273
+ nk_size_t rank; /// Number of dimensions in the tensor
1274
+ } nk_tensor_shape_t;
1275
+
1276
+ NK_PUBLIC void nk_tensor_shape_init(nk_tensor_shape_t *tensor_shape) {
1277
+ for (nk_size_t i = 0; i < NK_TENSOR_MAX_RANK; i++) tensor_shape->extents[i] = 0, tensor_shape->strides[i] = 0;
1278
+ tensor_shape->rank = 0;
1279
+ }
1280
+
1281
+ NK_INTERNAL nk_u32_t nk_u32_rol(nk_u32_t x, int n) { return (x << n) | (x >> (32 - n)); }
1282
+ NK_INTERNAL nk_u16_t nk_u16_rol(nk_u16_t x, int n) { return (x << n) | (x >> (16 - n)); }
1283
+ NK_INTERNAL nk_u8_t nk_u8_rol(nk_u8_t x, int n) { return (x << n) | (x >> (8 - n)); }
1284
+ NK_INTERNAL nk_u32_t nk_u32_ror(nk_u32_t x, int n) { return (x >> n) | (x << (32 - n)); }
1285
+ NK_INTERNAL nk_u16_t nk_u16_ror(nk_u16_t x, int n) { return (x >> n) | (x << (16 - n)); }
1286
+ NK_INTERNAL nk_u8_t nk_u8_ror(nk_u8_t x, int n) { return (x >> n) | (x << (8 - n)); }
1287
+
1288
+ /**
1289
+ * @brief SWAR population count for 64-bit integers.
1290
+ *
1291
+ * Classic algorithm from Hacker's Delight using parallel bit summation:
1292
+ * - Step 1: Count bits in pairs (2-bit sums)
1293
+ * - Step 2: Count bits in nibbles (4-bit sums)
1294
+ * - Step 3: Count bits in bytes (8-bit sums)
1295
+ * - Step 4: Horizontal sum via multiply - each byte contributes to bits 56-63
1296
+ *
1297
+ * Cost: ~12 ALU ops, zero memory access (vs 8 table lookups for byte-wise).
1298
+ */
1299
+ NK_INTERNAL nk_u64_t nk_u64_popcount_(nk_u64_t x) {
1300
+ x = x - ((x >> 1) & 0x5555555555555555ull);
1301
+ x = (x & 0x3333333333333333ull) + ((x >> 2) & 0x3333333333333333ull);
1302
+ x = (x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Full;
1303
+ return (x * 0x0101010101010101ull) >> 56;
1304
+ }
1305
+
1306
+ NK_INTERNAL unsigned char nk_u1x8_popcount_(nk_u1x8_t x) {
1307
+ static unsigned char lookup_table[256] = {
1308
+ 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, //
1309
+ 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
1310
+ 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
1311
+ 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
1312
+ 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
1313
+ 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
1314
+ 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
1315
+ 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};
1316
+ return lookup_table[x];
1317
+ }
1318
+
1319
+ /** @brief Divides the number rounding up to the next multiple of the given divisor. */
1320
+ NK_PUBLIC nk_size_t nk_size_divide_round_up_(nk_size_t number, nk_size_t divisor) NK_STREAMING_COMPATIBLE_ {
1321
+ return (number + divisor - 1) / divisor;
1322
+ }
1323
+
1324
+ /** @brief Rounds up the number to the next multiple of the given divisor. */
1325
+ NK_PUBLIC nk_size_t nk_size_round_up_to_multiple_(nk_size_t number, nk_size_t divisor) NK_STREAMING_COMPATIBLE_ {
1326
+ return nk_size_divide_round_up_(number, divisor) * divisor;
1327
+ }
1328
+
1329
+ NK_INTERNAL nk_f32_t nk_f32_abs_(nk_f32_t x) { return x < 0 ? -x : x; }
1330
+ NK_INTERNAL nk_f64_t nk_f64_abs_(nk_f64_t x) { return x < 0 ? -x : x; }
1331
+ NK_INTERNAL nk_i64_t nk_i64_abs_(nk_i64_t x) { return x < 0 ? -x : x; }
1332
+ NK_INTERNAL nk_u64_t nk_u64_abs_(nk_u64_t x) { return x; }
1333
+ NK_INTERNAL nk_i64_t nk_i32_abs_(nk_i32_t x) { return x < 0 ? -x : x; }
1334
+ NK_INTERNAL nk_u32_t nk_u32_abs_(nk_u32_t x) { return x; }
1335
+
1336
+ /** @brief Extract low (bits 0-3) unsigned nibble from packed u4x2 byte. */
1337
+ NK_INTERNAL nk_u8_t nk_u4x2_low_(nk_u4x2_t byte_val) { return byte_val & 0x0F; }
1338
+ /** @brief Extract high (bits 4-7) unsigned nibble from packed u4x2 byte. */
1339
+ NK_INTERNAL nk_u8_t nk_u4x2_high_(nk_u4x2_t byte_val) { return (byte_val >> 4) & 0x0F; }
1340
+
1341
+ /** @brief Extract low (bits 0-3) signed nibble from packed i4x2 byte as i8. */
1342
+ NK_INTERNAL nk_i8_t nk_i4x2_low_(nk_i4x2_t byte_val) { return (nk_i8_t)(((byte_val & 0x0F) ^ 8) - 8); }
1343
+ /** @brief Extract high (bits 4-7) signed nibble from packed i4x2 byte as i8. */
1344
+ NK_INTERNAL nk_i8_t nk_i4x2_high_(nk_i4x2_t byte_val) { return (nk_i8_t)((((byte_val >> 4) & 0x0F) ^ 8) - 8); }
1345
+
1346
+ /** @brief Extract n-th nibble (n=0: low, n=1: high) — branchless. */
1347
+ NK_INTERNAL nk_u8_t nk_u4x2_get_(nk_u4x2_t byte_val, int n) { return (byte_val >> ((n & 1) * 4)) & 0x0F; }
1348
+ NK_INTERNAL nk_i8_t nk_i4x2_get_(nk_i4x2_t byte_val, int n) {
1349
+ nk_u8_t nibble = (byte_val >> ((n & 1) * 4)) & 0x0F;
1350
+ return (nk_i8_t)((nibble ^ 8) - 8);
1351
+ }
1352
+
1353
+ /** @brief Extract bit at position n (0-7) from packed u1x8 byte. */
1354
+ NK_INTERNAL nk_u8_t nk_u1x8_get_(nk_u1x8_t byte_val, int n) { return (byte_val >> (n & 7)) & 1; }
1355
+
1356
+ NK_INTERNAL nk_f16_t nk_f16_from_u16_(nk_u16_t bits) {
1357
+ nk_fui16_t c;
1358
+ c.u = bits;
1359
+ return c.f;
1360
+ }
1361
+ NK_INTERNAL nk_bf16_t nk_bf16_from_u16_(nk_u16_t bits) {
1362
+ nk_fui16_t c;
1363
+ c.u = bits;
1364
+ return c.bf;
1365
+ }
1366
+
1367
+ /** @brief E4M3: NaN when (raw & 0x7F) == 0x7F (two NaN values: 0x7F, 0xFF). */
1368
+ NK_INTERNAL int nk_e4m3_is_nan_(nk_e4m3_t x) { return (x & 0x7F) == 0x7F; }
1369
+
1370
+ /** @brief E5M2: NaN when exponent=31 and mantissa!=0, i.e. (raw & 0x7F) > 0x7C.
1371
+ * Values: 0x7D-0x7F (positive), 0xFD-0xFF (negative). Infinity = 0x7C/0xFC is NOT NaN. */
1372
+ NK_INTERNAL int nk_e5m2_is_nan_(nk_e5m2_t x) { return (x & 0x7F) > 0x7C; }
1373
+
1374
+ /** @brief F16: NaN when (raw & 0x7FFF) > 0x7C00. */
1375
+ NK_INTERNAL int nk_f16_is_nan_(nk_u16_t x) { return (x & 0x7FFF) > 0x7C00; }
1376
+
1377
+ /** @brief BF16: NaN when (raw & 0x7FFF) > 0x7F80. */
1378
+ NK_INTERNAL int nk_bf16_is_nan_(nk_u16_t x) { return (x & 0x7FFF) > 0x7F80; }
1379
+
1380
+ #ifdef __cplusplus
1381
+ } // extern "C"
1382
+ #endif
1383
+
1384
+ #endif // NK_TYPES_H