numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,748 @@
1
+ /**
2
+ * @brief SIMD capability detection and thread configuration.
3
+ * @file include/numkong/capabilities.h
4
+ * @author Ash Vardanian
5
+ * @date February 6, 2026
6
+ *
7
+ * @section x86_targets Choosing x86 Target Generations
8
+ *
9
+ * It's important to provide fine-grained controls over AVX512 families, as they are very fragmented:
10
+ *
11
+ * - Intel Skylake servers: F, CD, VL, DQ, BW
12
+ * - Intel Cascade Lake workstations: F, CD, VL, DQ, BW, VNNI
13
+ * > In other words, it extends Skylake with VNNI support
14
+ * - Intel Sunny Cove (Ice Lake) servers:
15
+ * F, CD, VL, DQ, BW, VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ
16
+ * - AMD Zen4 (Genoa):
17
+ * F, CD, VL, DQ, BW, VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ, BF16
18
+ * > In other words, it extends Sunny Cove with BF16 support
19
+ * - Intel Golden Cove (Sapphire Rapids): extends Zen4 and Sunny Cove with FP16 support
20
+ * - AMD Zen5 (Turin): makes VP2INTERSECT cool again
21
+ *
22
+ * Intel Palm Cove was an irrelevant intermediate release extending Skylake with IFMA and VBMI.
23
+ * Intel Willow Cove was an irrelevant intermediate release extending Sunny Cove with VP2INTERSECT,
24
+ * which, among Intel CPUs, has to date only shipped in Tiger Lake laptops (AMD Zen5 later re-introduced it).
25
+ * Intel Cooper Lake was the only intermediate platform that supported BF16 but not FP16.
26
+ * It's mostly used in 4-socket and 8-socket high-memory configurations.
27
+ *
28
+ * For us, it makes sense to differentiate only these AVX512 generations:
29
+ *
30
+ * 1. Intel Skylake (pre 2019): supports single-precision dot-products.
31
+ * 2. Intel Ice Lake (2019-2021): advanced integer algorithms.
32
+ * 3. AMD Genoa (2023+): brain-floating point support.
33
+ * 4. Intel Sapphire Rapids (2023+): advanced mixed-precision float processing.
34
+ * 5. AMD Turin (2024+): advanced sparse algorithms.
35
+ *
36
+ * Beyond those, we support AVX2 for old Haswell generation CPUs, AVX2+VNNI for Alder Lake (12th/13th gen),
37
+ * and AVX2+VNNI-INT8 for Sierra Forest (adds native signed×signed and unsigned×unsigned 8-bit dot products).
38
+ *
39
+ * To list all available macros for x86, take a recent compiler, like GCC 12, and run:
40
+ * gcc-12 -march=sapphirerapids -dM -E - < /dev/null | egrep "SSE|AVX" | sort
41
+ * On Arm machines you may want to check for other flags:
42
+ * gcc-12 -march=native -dM -E - < /dev/null | egrep "NEON|SVE|FP16|FMA" | sort
43
+ *
44
+ * @section arm_targets Choosing Arm Target Generations
45
+ *
46
+ * Arm CPUs share design IP, but are produced by different vendors, potentially making the platform
47
+ * even more fragmented than x86. There are 2 important families of SIMD extensions - NEON and SVE.
48
+ *
49
+ * - Armv8-A: +fp, +simd
50
+ * - Armv8.1-A: armv8-a, +crc, +lse, +rdma
51
+ * - Armv8.2-A: armv8.1-a
52
+ * - Armv8.3-A: armv8.2-a, +pauth
53
+ * - Armv8.4-A: armv8.3-a, +flagm, +fp16fml, +dotprod
54
+ * - Armv8.5-A: armv8.4-a, +sb, +ssbs, +predres
55
+ * - Armv8.6-A: armv8.5-a, +bf16, +i8mm
56
+ * - Armv8.7-A: armv8.6-a, +ls64
57
+ * - Armv8.8-A: armv8.7-a, +mops
58
+ * - Armv8.9-A: armv8.8-a
59
+ * - Armv9-A: armv8.5-a, +sve, +sve2
60
+ * - Armv9.1-A: armv9-a, +bf16, +i8mm
61
+ * - Armv9.2-A: armv9.1-a, +ls64
62
+ * - Armv9.3-A: armv9.2-a, +mops
63
+ * - Armv9.4-A: armv9.3-a
64
+ *
65
+ * SVE has been optional since Armv8.2-A, but it's a requirement for Armv9.0-A.
66
+ * A 512-bit SVE variant has already been implemented on the Fugaku supercomputer.
67
+ * A more flexible version, 2x256 SVE, was implemented by the AWS Graviton3 ARM processor.
68
+ * Here are the most important recent families of CPU cores designed by Arm:
69
+ *
70
+ * - Neoverse N1: armv8.2-a, extended with Armv8.4 "dotprod" instructions.
71
+ * Used in AWS @b Graviton2 and Ampere @b Altra.
72
+ * https://developer.arm.com/Processors/Neoverse%20N1
73
+ * - Neoverse V1: armv8.4-a, extended with Armv8.6 bfloat/int8 "matmul" instructions.
74
+ * Used in AWS @b Graviton3, which also enables `sve`, `svebf16`, and `svei8mm`.
75
+ * https://developer.arm.com/Processors/Neoverse%20V1
76
+ * - Neoverse V2: armv9.0 with SVE2 and SVE bit-permutes
77
+ * Used in AWS @b Graviton4, NVIDIA @b Grace, Google @b Axion.
78
+ * https://developer.arm.com/Processors/Neoverse%20V2
79
+ * The N2 core is very similar to V2 and is used by Microsoft @b Cobalt.
80
+ * https://developer.arm.com/Processors/Neoverse%20N2
81
+ *
82
+ * On the consumer side, Apple is the biggest player with mobile @b A chips and desktop @b M chips.
83
+ * The M1 implements Armv8.5-A, both M2 and M3 implement Armv8.6-A, and the M4 adds SME (Scalable Matrix Extension).
84
+ *
85
+ * @section references References
86
+ *
87
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide
88
+ * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics
89
+ * - Detecting target CPU features at compile time: https://stackoverflow.com/a/28939692/2766161
90
+ */
91
+
92
+ #ifndef NK_CAPABILITIES_H
93
+ #define NK_CAPABILITIES_H
94
+
95
+ #include "numkong/types.h" // `nk_u64_t`, `NK_DEFINED_LINUX_`
96
+
97
+ #define NK_VERSION_MAJOR 7 // Library version: major component
98
+ #define NK_VERSION_MINOR 0 // Library version: minor component
99
+ #define NK_VERSION_PATCH 0 // Library version: patch component
100
+
101
+ /**
102
+ * @brief When non-zero, replaces compile-time dispatching with runtime dispatching.
103
+ * Then the `nk_dot_f32` function invokes the most advanced backend supported by the CPU
104
+ * that runs the program, rather than the most advanced backend supported by the CPU
105
+ * used to compile the library or the downstream application. Defaults to 0 (compile-time).
106
+ */
107
+ #if !defined(NK_DYNAMIC_DISPATCH)
108
+ #define NK_DYNAMIC_DISPATCH (0) // Disabled by default; define as 1 before including this header to enable
109
+ #endif
110
+
111
+ // On Apple Silicon, `mrs` is not allowed in user-space, so we need to use the `sysctl` API.
112
+ #if defined(NK_DEFINED_APPLE_)
113
+ #include <fenv.h> // `fesetenv` - part of C 99 standard
114
+ #include <sys/sysctl.h> // `sysctlbyname`
115
+ #endif
116
+
117
+ // Detect whether POSIX signal-handling extensions are available (Linux and FreeBSD only).
118
+ // POSIX extensions provide `sigaction`, `sigjmp_buf`, and `sigsetjmp` for safe signal handling.
119
+ // These are needed on Linux ARM for safely testing `mrs` instruction availability.
120
+ #if (defined(NK_DEFINED_LINUX_) || defined(NK_DEFINED_FREEBSD_)) && defined(_POSIX_VERSION) // NOTE(review): `_POSIX_VERSION` comes from <unistd.h>; confirm it is visible here, else this is always 0
121
+ #include <setjmp.h> // `sigjmp_buf`, `sigsetjmp`, `siglongjmp`
122
+ #include <signal.h> // `sigaction`, `SIGILL`
123
+ #define NK_HAS_POSIX_EXTENSIONS_ 1 // Signal-based probing APIs are usable
124
+ #else
125
+ #define NK_HAS_POSIX_EXTENSIONS_ 0 // No POSIX signal-handling APIs assumed
126
+ #endif
127
+
128
+ // On Linux x86/RISC-V, we need `syscall()` for AMX permission and hwprobe.
129
+ // With `-std=c11` glibc hides `syscall()` behind `_GNU_SOURCE`, but if any
130
+ // system header was included before us, `<features.h>` is already locked.
131
+ // So we forward-declare `syscall` ourselves, independent of feature-test macros; glibc always provides the symbol.
132
+ #if defined(NK_DEFINED_LINUX_) && (NK_TARGET_X86_ || NK_TARGET_RISCV_)
133
+ #include <sys/syscall.h> // `SYS_arch_prctl`, `SYS_riscv_hwprobe`
134
+ #ifdef __cplusplus
135
+ extern "C" long syscall(long, ...) noexcept; // NOTE(review): `noexcept` matches glibc's `__THROW`; a libc declaring it without one (e.g. musl) may make C++ reject this redeclaration - confirm
136
+ #else
137
+ extern long syscall(long, ...);
138
+ #endif
139
+ #if NK_TARGET_RISCV_
140
+ #include <sys/auxv.h> // `getauxval`, `AT_HWCAP`
141
+ #endif
142
+ #endif
143
+
144
+ // On FreeBSD RISC-V, we use elf_aux_info for capability detection
145
+ #if defined(NK_DEFINED_FREEBSD_) && NK_TARGET_RISCV_
146
+ #include <sys/auxv.h> // `elf_aux_info`, `AT_HWCAP`
147
+ #endif
148
+
149
+ // On Windows ARM, we use IsProcessorFeaturePresent API for capability detection
150
+ #if defined(NK_DEFINED_WINDOWS_) && NK_TARGET_ARM_
151
+ #include <processthreadsapi.h> // `IsProcessorFeaturePresent`
152
+ #endif
153
+
154
+ // On WASM with Emscripten, we use EM_JS for runtime capability detection
155
+ #if NK_TARGET_WASM_ && defined(__EMSCRIPTEN__)
156
+ #include <emscripten.h> // `EM_JS`
157
+ #endif
158
+
159
+ #ifdef __cplusplus
160
+ extern "C" {
161
+ #endif
162
+
163
+ /**
164
+ * @brief Enumeration of supported kernel kinds: metrics, element-wise ops, reductions, GEMM-like batches, and casts.
165
+ * Every non-zero value is a distinct printable ASCII character.
166
+ */
167
+ typedef enum {
168
+ nk_kernel_unknown_k = 0, ///< Unknown kernel kind
169
+
170
+ // Classics:
171
+ nk_kernel_dot_k = 'i', ///< Inner product
172
+ nk_kernel_vdot_k = 'v', ///< Complex inner product
173
+ nk_kernel_angular_k = 'a', ///< Angular (cosine) distance
174
+ nk_kernel_euclidean_k = 'e', ///< Euclidean distance
175
+ nk_kernel_sqeuclidean_k = '2', ///< Squared Euclidean distance
176
+
177
+ // Binary:
178
+ nk_kernel_hamming_k = 'h', ///< Hamming (or Manhattan) distance
179
+ nk_kernel_jaccard_k = 'j', ///< Jaccard (or Tanimoto) coefficient
180
+
181
+ // Curved Spaces:
182
+ nk_kernel_bilinear_k = 'b', ///< Bilinear form
183
+ nk_kernel_mahalanobis_k = 'm', ///< Mahalanobis distance
184
+
185
+ // Geospatial:
186
+ nk_kernel_haversine_k = 'o', ///< Haversine distance
187
+ nk_kernel_vincenty_k = 'O', ///< Vincenty distance (ellipsoidal geodesic)
188
+
189
+ // Probability:
190
+ nk_kernel_kld_k = 'k', ///< Kullback-Leibler divergence
191
+ nk_kernel_jsd_k = 's', ///< Jensen-Shannon divergence
192
+
193
+ // Mesh superposition:
194
+ nk_kernel_rmsd_k = 'r', ///< RMSD without optimal superposition
195
+ nk_kernel_kabsch_k = 'K', ///< Kabsch RMSD with optimal rotation
196
+ nk_kernel_umeyama_k = 'U', ///< Umeyama RMSD with optimal rotation and scale
197
+
198
+ // Sparse Sets:
199
+ nk_kernel_sparse_dot_k = 'd', ///< Sparse dot product with weighted indices
200
+ nk_kernel_sparse_intersect_k = 'x', ///< Equivalent to unnormalized Jaccard
201
+
202
+ // BLAS-like operations:
203
+ nk_kernel_each_scale_k = '*', ///< Element-wise Scale
204
+ nk_kernel_each_sum_k = '+', ///< Element-wise Sum
205
+ nk_kernel_each_blend_k = 'w', ///< Element-wise Weighted Sum
206
+ nk_kernel_each_fma_k = 'f', ///< Element-wise Fused Multiply-Add
207
+
208
+ // Trigonometric functions:
209
+ nk_kernel_each_sin_k = 'S', ///< Element-wise sine
210
+ nk_kernel_each_cos_k = 'C', ///< Element-wise cosine
211
+ nk_kernel_each_atan_k = 'A', ///< Element-wise arctangent
212
+
213
+ // Horizontal reductions:
214
+ nk_kernel_reduce_moments_k = 'R', ///< Horizontal moments reduction (sum + sum-of-squares)
215
+ nk_kernel_reduce_minmax_k = 'X', ///< Horizontal minmax reduction (min + argmin + max + argmax)
216
+
217
+ // GEMM-like batched dot products:
218
+ nk_kernel_dots_packed_size_k = 'P', ///< GEMM packed buffer size
219
+ nk_kernel_dots_pack_k = 'Q', ///< GEMM B matrix packing
220
+ nk_kernel_dots_packed_k = 'G', ///< GEMM computation against a packed B matrix
221
+ nk_kernel_dots_symmetric_k = 'y', ///< Symmetric Gram matrix (A x At)
222
+
223
+ // GEMM-like batched set similarity functions:
224
+ nk_kernel_hammings_packed_k = 'M', ///< Batched Hamming distances (packed B)
225
+ nk_kernel_hammings_symmetric_k = 'Y', ///< Symmetric Hamming distance matrix (A x At)
226
+ nk_kernel_jaccards_packed_k = 'p', ///< Batched Jaccard distances (packed B)
227
+ nk_kernel_jaccards_symmetric_k = 'Z', ///< Symmetric Jaccard distance matrix
228
+
229
+ // GEMM-like batched spatial distance functions:
230
+ nk_kernel_angulars_packed_k = 'N', ///< Batched angular distances (packed B)
231
+ nk_kernel_angulars_symmetric_k = 'n', ///< Symmetric angular distance matrix
232
+ nk_kernel_euclideans_packed_k = 'E', ///< Batched euclidean distances (packed B)
233
+ nk_kernel_euclideans_symmetric_k = 'D', ///< Symmetric euclidean distance matrix
234
+
235
+ // MaxSim late-interaction functions:
236
+ nk_kernel_maxsim_packed_size_k = 'L', ///< MaxSim packed buffer size
237
+ nk_kernel_maxsim_pack_k = 'l', ///< MaxSim vector packing
238
+ nk_kernel_maxsim_packed_k = 'T', ///< MaxSim late-interaction computation
239
+
240
+ nk_kernel_cast_k = '-', ///< Type casting from one type to another
241
+
242
+ } nk_kernel_kind_t;
243
+
244
+ /**
245
+ * @brief 64-bit bitmask of SIMD capabilities; each backend below is assigned its own distinct bit.
246
+ */
247
+ typedef nk_u64_t nk_capability_t;
248
+
249
+ /** @brief Serial (non-SIMD) fallback capability. Always available. */
250
+ #define nk_cap_serial_k ((nk_capability_t)1)
251
+
252
+ /** @brief Mask with all 64 bits set, matching any capability. */
253
+ #define nk_cap_any_k ((nk_capability_t)NK_U64_MAX)
254
+
255
+ #define nk_cap_neon_k ((nk_capability_t)1 << 1) // Arm NEON
256
+ #define nk_cap_haswell_k ((nk_capability_t)1 << 2) // x86 AVX2
257
+ #define nk_cap_skylake_k ((nk_capability_t)1 << 3) // x86 AVX-512 F, CD, VL, DQ, BW
258
+ #define nk_cap_neonhalf_k ((nk_capability_t)1 << 4)
259
+ #define nk_cap_neonsdot_k ((nk_capability_t)1 << 5)
260
+ #define nk_cap_neonfhm_k ((nk_capability_t)1 << 6)
261
+ #define nk_cap_icelake_k ((nk_capability_t)1 << 7) // x86 AVX-512 with VNNI, VPOPCNTDQ, VBMI2, BITALG, ...
262
+ #define nk_cap_genoa_k ((nk_capability_t)1 << 8) // x86 Ice Lake feature set + AVX-512 BF16
263
+ #define nk_cap_neonbfdot_k ((nk_capability_t)1 << 9)
264
+ #define nk_cap_sve_k ((nk_capability_t)1 << 10) // Arm SVE
265
+ #define nk_cap_svehalf_k ((nk_capability_t)1 << 11)
266
+ #define nk_cap_svesdot_k ((nk_capability_t)1 << 12)
267
+ #define nk_cap_alder_k ((nk_capability_t)1 << 13) // x86 AVX2 + VNNI
268
+ #define nk_cap_svebfdot_k ((nk_capability_t)1 << 14)
269
+ #define nk_cap_sve2_k ((nk_capability_t)1 << 15) // Arm SVE2
270
+ #define nk_cap_v128relaxed_k ((nk_capability_t)1 << 16) // WebAssembly 128-bit relaxed SIMD
271
+ #define nk_cap_sapphire_k ((nk_capability_t)1 << 17) // x86 Genoa feature set + AVX-512 FP16
272
+ #define nk_cap_sapphireamx_k ((nk_capability_t)1 << 18)
273
+ #define nk_cap_rvv_k ((nk_capability_t)1 << 19) // RISC-V Vector extension
274
+ #define nk_cap_rvvhalf_k ((nk_capability_t)1 << 20)
275
+ #define nk_cap_rvvbf16_k ((nk_capability_t)1 << 21)
276
+ #define nk_cap_graniteamx_k ((nk_capability_t)1 << 22)
277
+ #define nk_cap_turin_k ((nk_capability_t)1 << 23) // x86 AVX-512 with VP2INTERSECT
278
+ #define nk_cap_sme_k ((nk_capability_t)1 << 24) // Arm SME
279
+ #define nk_cap_sme2_k ((nk_capability_t)1 << 25)
280
+ #define nk_cap_smef64_k ((nk_capability_t)1 << 26)
281
+ #define nk_cap_smefa64_k ((nk_capability_t)1 << 27)
282
+ #define nk_cap_sve2p1_k ((nk_capability_t)1 << 28)
283
+ #define nk_cap_sme2p1_k ((nk_capability_t)1 << 29)
284
+ #define nk_cap_smehalf_k ((nk_capability_t)1 << 30)
285
+ #define nk_cap_smebf16_k ((nk_capability_t)1 << 31)
286
+ #define nk_cap_smelut2_k ((nk_capability_t)1 << 32)
287
+ #define nk_cap_rvvbb_k ((nk_capability_t)1 << 33)
288
+ #define nk_cap_sierra_k ((nk_capability_t)1 << 34) // x86 AVX2 + VNNI-INT8
289
+
290
+ typedef void (*nk_metric_dense_punned_t)(void const *a, void const *b, nk_size_t n, void *d); // Type-erased two-input kernel writing into `d`
291
+
292
+ typedef void (*nk_sparse_intersect_punned_t)(void const *a, void const *b, nk_size_t a_length, nk_size_t b_length,
293
+ void *result, nk_size_t *count); // Type-erased sparse-set intersection
294
+
295
+ typedef void (*nk_sparse_dot_punned_t)(void const *a, void const *b, void const *a_weights, void const *b_weights,
296
+ nk_size_t a_length, nk_size_t b_length, void *product); // Type-erased sparse dot product with weighted indices
297
+
298
+ typedef void (*nk_metric_curved_punned_t)(void const *a, void const *b, void const *c, nk_size_t n, void *d); // Type-erased curved-space kernel (bilinear, Mahalanobis)
299
+
300
+ typedef void (*nk_metric_geospatial_punned_t)(void const *a_lats, void const *a_lons, void const *b_lats,
301
+ void const *b_lons, nk_size_t n, void *results); // Type-erased geospatial kernel over latitude/longitude arrays
302
+
303
+ typedef void (*nk_each_scale_punned_t)(void const *a, nk_size_t n, void const *alpha, void const *beta, void *y); // Type-erased element-wise scale
304
+
305
+ typedef void (*nk_each_sum_punned_t)(void const *a, void const *b, nk_size_t n, void *y); // Type-erased element-wise sum
306
+
307
+ typedef void (*nk_each_blend_punned_t)(void const *a, void const *b, nk_size_t n, void const *alpha, void const *beta,
308
+ void *y); // Type-erased element-wise weighted sum
309
+
310
+ typedef void (*nk_each_fma_punned_t)(void const *a, void const *b, void const *c, nk_size_t n, void const *alpha,
311
+ void const *beta, void *y); // Type-erased element-wise fused multiply-add
312
+
313
+ typedef void (*nk_kernel_trigonometry_punned_t)(void const *x, nk_size_t n, void *y); // Type-erased element-wise sin/cos/atan
314
+
315
+ typedef void (*nk_metric_mesh_punned_t)(void const *a, void const *b, nk_size_t n, void *a_centroid, void *b_centroid,
316
+ void *rotation, void *scale, void *d); // Type-erased mesh superposition (RMSD, Kabsch, Umeyama)
317
+
318
+ typedef void (*nk_kernel_reduce_moments_punned_t)(void const *data, nk_size_t count, nk_size_t stride_bytes,
319
+ void *sum_ptr, void *sumsq_ptr); // Type-erased strided sum + sum-of-squares reduction
320
+
321
+ typedef void (*nk_kernel_reduce_minmax_punned_t)(void const *data, nk_size_t count, nk_size_t stride_bytes,
322
+ void *min_value, nk_size_t *min_index, void *max_value,
323
+ nk_size_t *max_index); // Type-erased strided min/argmin/max/argmax reduction
324
+
325
+ typedef nk_size_t (*nk_dots_packed_size_punned_t)(nk_size_t width, nk_size_t depth); // Type-erased GEMM packed buffer size query
325
+
326
+ typedef void (*nk_dots_pack_punned_t)(void const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
327
+ void *b_packed); // Type-erased packing of B into `b_packed`
328
+
329
+ typedef void (*nk_dots_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height, nk_size_t width,
330
+ nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride); // Type-erased GEMM of A against packed B, into C
331
+
332
+ typedef void (*nk_dots_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
333
+ void *result, nk_size_t result_stride, nk_size_t row_start,
334
+ nk_size_t row_count); // Type-erased symmetric Gram matrix for a slice of rows
336
+
337
+ typedef void (*nk_hammings_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height,
338
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
339
+
340
+ typedef void (*nk_hammings_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth,
341
+ nk_size_t stride, void *result, nk_size_t result_stride,
342
+ nk_size_t row_start, nk_size_t row_count);
343
+
344
+ typedef void (*nk_jaccards_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height,
345
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
346
+
347
+ typedef void (*nk_jaccards_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth,
348
+ nk_size_t stride, void *result, nk_size_t result_stride,
349
+ nk_size_t row_start, nk_size_t row_count);
350
+
351
+ typedef void (*nk_angulars_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height,
352
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
353
+ typedef void (*nk_angulars_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth,
354
+ nk_size_t stride, void *result, nk_size_t result_stride,
355
+ nk_size_t row_start, nk_size_t row_count);
356
+ typedef void (*nk_euclideans_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height,
357
+ nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
358
+ typedef void (*nk_euclideans_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth,
359
+ nk_size_t stride, void *result, nk_size_t result_stride,
360
+ nk_size_t row_start, nk_size_t row_count);
361
+
362
+ typedef void (*nk_maxsim_packed_punned_t)(void const *q_packed, void const *d_packed, nk_size_t query_count,
363
+ nk_size_t document_count, nk_size_t depth, void *result);
364
+
365
+ typedef void (*nk_kernel_cast_punned_t)(void const *from, nk_dtype_t from_type, nk_size_t n, void *to,
366
+ nk_dtype_t to_type);
367
+
368
+ typedef void (*nk_kernel_punned_t)(void *);
369
+
370
+ #if NK_TARGET_X86_
371
+
372
+ NK_PUBLIC int nk_configure_thread_x86_(nk_capability_t capabilities) {
373
+ #if defined(_MSC_VER)
374
+ unsigned int mxcsr = _mm_getcsr();
375
+ mxcsr |= 1 << 15;
376
+ mxcsr |= 1 << 6;
377
+ _mm_setcsr(mxcsr);
378
+ #else
379
+ unsigned int mxcsr;
380
+ __asm__ __volatile__("stmxcsr %0" : "=m"(mxcsr));
381
+ mxcsr |= 1 << 15;
382
+ mxcsr |= 1 << 6;
383
+ __asm__ __volatile__("ldmxcsr %0" : : "m"(mxcsr));
384
+ #endif
385
+
386
+ #if NK_TARGET_SAPPHIREAMX
387
+ if (capabilities & nk_cap_sapphireamx_k) {
388
+ #if defined(NK_DEFINED_LINUX_)
389
+ // Linux requires explicit permission for AMX tile state via arch_prctl syscall
390
+ int const ARCH_REQ_XCOMP_PERM = 0x1023;
391
+ unsigned long const XFEATURE_XTILEDATA = 18;
392
+ syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
393
+ #endif
394
+ // On Windows, AMX tile state is automatically enabled by the OS if hardware supports it.
395
+ // On FreeBSD, no explicit request is needed either.
396
+ nk_unused_(capabilities);
397
+ }
398
+ #else
399
+ nk_unused_(capabilities);
400
+ #endif
401
+ return 1;
402
+ }
403
+
404
+ NK_PUBLIC nk_capability_t nk_capabilities_x86_(void) {
405
+ union four_registers_t {
406
+ int array[4];
407
+ struct separate_t {
408
+ unsigned eax, ebx, ecx, edx;
409
+ } named;
410
+ } info1, info7, info7sub1;
411
+
412
+ #if defined(_MSC_VER)
413
+ __cpuidex(info1.array, 1, 0);
414
+ __cpuidex(info7.array, 7, 0);
415
+ __cpuidex(info7sub1.array, 7, 1);
416
+ #else
417
+ __asm__ __volatile__("cpuid"
418
+ : "=a"(info1.named.eax), "=b"(info1.named.ebx), "=c"(info1.named.ecx), "=d"(info1.named.edx)
419
+ : "a"(1), "c"(0));
420
+ __asm__ __volatile__("cpuid"
421
+ : "=a"(info7.named.eax), "=b"(info7.named.ebx), "=c"(info7.named.ecx), "=d"(info7.named.edx)
422
+ : "a"(7), "c"(0));
423
+ __asm__ __volatile__("cpuid"
424
+ : "=a"(info7sub1.named.eax), "=b"(info7sub1.named.ebx), "=c"(info7sub1.named.ecx),
425
+ "=d"(info7sub1.named.edx)
426
+ : "a"(7), "c"(1));
427
+ #endif
428
+
429
+ unsigned supports_avx2 = (info7.named.ebx & 0x00000020) != 0;
430
+ unsigned supports_f16c = (info1.named.ecx & 0x20000000) != 0;
431
+ unsigned supports_fma = (info1.named.ecx & 0x00001000) != 0;
432
+ unsigned supports_avx512f = (info7.named.ebx & 0x00010000) != 0;
433
+ unsigned supports_avx512fp16 = (info7.named.edx & 0x00800000) != 0;
434
+ unsigned supports_avx512vnni = (info7.named.ecx & 0x00000800) != 0;
435
+ unsigned supports_avx512ifma = (info7.named.ebx & 0x00200000) != 0;
436
+ unsigned supports_avx512bitalg = (info7.named.ecx & 0x00001000) != 0;
437
+ unsigned supports_avx512vbmi = (info7.named.ecx & 0x00000002) != 0;
438
+ unsigned supports_avx512vbmi2 = (info7.named.ecx & 0x00000040) != 0;
439
+ unsigned supports_avx512vpopcntdq = (info7.named.ecx & 0x00004000) != 0;
440
+ unsigned supports_avx512bf16 = (info7sub1.named.eax & 0x00000020) != 0;
441
+ unsigned supports_avx512vp2intersect = (info7.named.edx & 0x00000100) != 0;
442
+ unsigned supports_amx_tile = (info7.named.edx & 0x01000000) != 0;
443
+ unsigned supports_amx_bf16 = (info7.named.edx & 0x00400000) != 0;
444
+ unsigned supports_amx_int8 = (info7.named.edx & 0x02000000) != 0;
445
+ unsigned supports_amx_fp16 = (info7sub1.named.eax & 0x00200000) != 0;
446
+ unsigned supports_avxvnni = (info7sub1.named.eax & 0x00000010) != 0;
447
+ unsigned supports_avxvnniint8 = (info7sub1.named.edx & 0x00000010) != 0;
448
+
449
+ unsigned supports_haswell = supports_avx2 && supports_f16c && supports_fma;
450
+ unsigned supports_skylake = supports_avx512f;
451
+ unsigned supports_icelake = supports_avx512vnni && supports_avx512ifma && supports_avx512bitalg &&
452
+ supports_avx512vbmi && supports_avx512vbmi2 && supports_avx512vpopcntdq;
453
+ unsigned supports_genoa = supports_avx512bf16;
454
+ unsigned supports_sapphire = supports_avx512fp16;
455
+ unsigned supports_turin = supports_avx512vp2intersect && supports_avx512bf16;
456
+ unsigned supports_sierra = supports_haswell && supports_avxvnniint8;
457
+ unsigned supports_alder = supports_haswell && supports_avxvnni;
458
+ unsigned supports_sapphireamx = supports_amx_tile && supports_amx_bf16 && supports_amx_int8;
459
+ unsigned supports_graniteamx = supports_sapphireamx && supports_amx_fp16;
460
+
461
+ return (nk_capability_t)((nk_cap_haswell_k * supports_haswell) | (nk_cap_skylake_k * supports_skylake) |
462
+ (nk_cap_icelake_k * supports_icelake) | (nk_cap_genoa_k * supports_genoa) |
463
+ (nk_cap_sapphire_k * supports_sapphire) | (nk_cap_turin_k * supports_turin) |
464
+ (nk_cap_sierra_k * supports_sierra) | (nk_cap_alder_k * supports_alder) |
465
+ (nk_cap_sapphireamx_k * supports_sapphireamx) |
466
+ (nk_cap_graniteamx_k * supports_graniteamx) | (nk_cap_serial_k));
467
+ }
468
+
469
+ #endif // NK_TARGET_X86_
470
+
471
+ #if NK_TARGET_ARM_
472
+
473
+ #if defined(__clang__)
474
+ #pragma clang attribute push(__attribute__((target("arch=armv8.5-a+sve"))), apply_to = function)
475
+ #elif defined(__GNUC__)
476
+ #pragma GCC push_options
477
+ #pragma GCC target("arch=armv8.5-a+sve")
478
+ #endif
479
+
480
+ #if NK_HAS_POSIX_EXTENSIONS_
481
+ static sigjmp_buf nk_mrs_test_jump_buffer_;
482
+ static void nk_mrs_test_sigill_handler_(int sig) {
483
+ nk_unused_(sig);
484
+ siglongjmp(nk_mrs_test_jump_buffer_, 1);
485
+ }
486
+ #endif
487
+
488
+ NK_PUBLIC int nk_configure_thread_arm_(nk_capability_t capabilities) {
489
+ nk_unused_(capabilities);
490
+ #if defined(NK_DEFINED_APPLE_)
491
+ int is_success = fesetenv(FE_DFL_DISABLE_DENORMS_ENV) == 0;
492
+ return is_success;
493
+ #elif defined(NK_DEFINED_LINUX_) || defined(NK_DEFINED_FREEBSD_)
494
+ uint64_t fpcr;
495
+ __asm__ volatile("mrs %0, fpcr" : "=r"(fpcr));
496
+ fpcr |= (1 << 19);
497
+ fpcr |= (1 << 24);
498
+ fpcr |= (1 << 25);
499
+ __asm__ volatile("msr fpcr, %0" : : "r"(fpcr));
500
+ return 1;
501
+ #else
502
+ return 0;
503
+ #endif
504
+ }
505
+
506
+ NK_PUBLIC nk_capability_t nk_capabilities_arm_(void) {
507
+ #if defined(NK_DEFINED_APPLE_)
508
+ size_t size = sizeof(unsigned);
509
+ unsigned supports_neon = 0, supports_fp16 = 0, supports_fhm = 0, supports_bf16 = 0, supports_i8mm = 0;
510
+ unsigned supports_sme = 0, supports_sme2 = 0, supports_smef64 = 0, supports_smehalf = 0, supports_sme2p1 = 0;
511
+ if (sysctlbyname("hw.optional.neon", &supports_neon, &size, NULL, 0) != 0) supports_neon = 0;
512
+ if (sysctlbyname("hw.optional.arm.FEAT_FP16", &supports_fp16, &size, NULL, 0) != 0) supports_fp16 = 0;
513
+ if (sysctlbyname("hw.optional.arm.FEAT_FHM", &supports_fhm, &size, NULL, 0) != 0) supports_fhm = 0;
514
+ if (sysctlbyname("hw.optional.arm.FEAT_BF16", &supports_bf16, &size, NULL, 0) != 0) supports_bf16 = 0;
515
+ if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &supports_i8mm, &size, NULL, 0) != 0) supports_i8mm = 0;
516
+ if (sysctlbyname("hw.optional.arm.FEAT_SME", &supports_sme, &size, NULL, 0) != 0) supports_sme = 0;
517
+ if (sysctlbyname("hw.optional.arm.FEAT_SME2", &supports_sme2, &size, NULL, 0) != 0) supports_sme2 = 0;
518
+ if (sysctlbyname("hw.optional.arm.FEAT_SME_F64F64", &supports_smef64, &size, NULL, 0) != 0) supports_smef64 = 0;
519
+ if (sysctlbyname("hw.optional.arm.FEAT_SME_F16F16", &supports_smehalf, &size, NULL, 0) != 0) supports_smehalf = 0;
520
+ if (sysctlbyname("hw.optional.arm.FEAT_SME2p1", &supports_sme2p1, &size, NULL, 0) != 0) supports_sme2p1 = 0;
521
+
522
+ return (nk_capability_t)((nk_cap_neon_k * (supports_neon)) |
523
+ (nk_cap_neonhalf_k * (supports_neon && supports_fp16)) |
524
+ (nk_cap_neonfhm_k * (supports_neon && supports_fhm)) |
525
+ (nk_cap_neonbfdot_k * (supports_neon && supports_bf16)) |
526
+ (nk_cap_neonsdot_k * (supports_neon && supports_i8mm)) | (nk_cap_sme_k * (supports_sme)) |
527
+ (nk_cap_sme2_k * (supports_sme2)) | (nk_cap_sme2p1_k * (supports_sme2p1)) |
528
+ (nk_cap_smef64_k * (supports_smef64)) | (nk_cap_smehalf_k * (supports_smehalf)) |
529
+ (nk_cap_smebf16_k * (supports_sme)) | (nk_cap_serial_k));
530
+
531
+ #elif defined(NK_DEFINED_LINUX_) || defined(NK_DEFINED_FREEBSD_)
532
+
533
+ #if NK_HAS_POSIX_EXTENSIONS_
534
+ struct sigaction action_new, action_old;
535
+ action_new.sa_handler = nk_mrs_test_sigill_handler_;
536
+ sigemptyset(&action_new.sa_mask);
537
+ action_new.sa_flags = 0;
538
+
539
+ int mrs_works = 0;
540
+ if (sigaction(SIGILL, &action_new, &action_old) == 0) {
541
+ if (sigsetjmp(nk_mrs_test_jump_buffer_, 1) == 0) {
542
+ unsigned long midr_value;
543
+ __asm__ __volatile__("mrs %0, MIDR_EL1" : "=r"(midr_value));
544
+ mrs_works = 1;
545
+ }
546
+ sigaction(SIGILL, &action_old, NULL);
547
+ }
548
+
549
+ if (!mrs_works) return (nk_capability_t)(nk_cap_neon_k | nk_cap_serial_k);
550
+ #else
551
+ return (nk_capability_t)(nk_cap_neon_k | nk_cap_serial_k);
552
+ #endif
553
+
554
+ unsigned long id_aa64isar0_el1 = 0, id_aa64isar1_el1 = 0, id_aa64pfr0_el1 = 0, id_aa64zfr0_el1 = 0;
555
+
556
+ __asm__ __volatile__("mrs %0, ID_AA64ISAR0_EL1" : "=r"(id_aa64isar0_el1));
557
+ unsigned supports_integer_dot_products = ((id_aa64isar0_el1 >> 44) & 0xF) >= 1;
558
+ unsigned supports_fhm = ((id_aa64isar0_el1 >> 48) & 0xF) >= 1;
559
+ __asm__ __volatile__("mrs %0, ID_AA64ISAR1_EL1" : "=r"(id_aa64isar1_el1));
560
+ unsigned supports_i8mm = ((id_aa64isar1_el1 >> 52) & 0xF) >= 1;
561
+ unsigned supports_bf16 = ((id_aa64isar1_el1 >> 44) & 0xF) >= 1;
562
+
563
+ __asm__ __volatile__("mrs %0, ID_AA64PFR0_EL1" : "=r"(id_aa64pfr0_el1));
564
+ unsigned supports_sve = ((id_aa64pfr0_el1 >> 32) & 0xF) >= 1;
565
+ unsigned supports_fp16 = ((id_aa64pfr0_el1 >> 20) & 0xF) == 0x1;
566
+ unsigned supports_neon = ((id_aa64pfr0_el1 >> 20) & 0xF) != 0xF;
567
+
568
+ if (supports_sve) __asm__ __volatile__("mrs %0, ID_AA64ZFR0_EL1" : "=r"(id_aa64zfr0_el1));
569
+ unsigned supports_svesdotmm = ((id_aa64zfr0_el1 >> 44) & 0xF) >= 1;
570
+ unsigned supports_svebfdot = ((id_aa64zfr0_el1 >> 20) & 0xF) >= 1;
571
+ unsigned supports_sve2 = ((id_aa64zfr0_el1) & 0xF) >= 1;
572
+ unsigned supports_sve2p1 = ((id_aa64zfr0_el1) & 0xF) >= 2;
573
+
574
+ unsigned long id_aa64pfr1_el1 = 0, id_aa64smfr0_el1 = 0;
575
+ __asm__ __volatile__("mrs %0, ID_AA64PFR1_EL1" : "=r"(id_aa64pfr1_el1));
576
+ unsigned supports_sme = ((id_aa64pfr1_el1 >> 24) & 0xF) >= 1;
577
+
578
+ unsigned supports_sme2 = 0, supports_sme2p1 = 0;
579
+ unsigned supports_smef64 = 0, supports_smehalf = 0, supports_smebf16 = 0;
580
+ unsigned supports_smelut2 = 0, supports_smefa64 = 0;
581
+ if (supports_sme) {
582
+ __asm__ __volatile__("mrs %0, ID_AA64SMFR0_EL1" : "=r"(id_aa64smfr0_el1));
583
+ unsigned sme_version = (id_aa64smfr0_el1 >> 56) & 0xF;
584
+ supports_sme2 = sme_version >= 1;
585
+ supports_sme2p1 = sme_version >= 2;
586
+ supports_smef64 = (id_aa64smfr0_el1 >> 48) & 0x1;
587
+ supports_smehalf = (id_aa64smfr0_el1 >> 42) & 0x1;
588
+ supports_smebf16 = (id_aa64smfr0_el1 >> 44) & 0x1;
589
+ supports_smefa64 = (id_aa64smfr0_el1 >> 63) & 0x1;
590
+ }
591
+
592
+ return (nk_capability_t)((nk_cap_neon_k * (supports_neon)) |
593
+ (nk_cap_neonhalf_k * (supports_neon && supports_fp16)) |
594
+ (nk_cap_neonfhm_k * (supports_neon && supports_fhm)) |
595
+ (nk_cap_neonbfdot_k * (supports_neon && supports_bf16)) |
596
+ (nk_cap_neonsdot_k * (supports_neon && supports_i8mm && supports_integer_dot_products)) |
597
+ (nk_cap_sve_k * (supports_sve)) | (nk_cap_svehalf_k * (supports_sve && supports_fp16)) |
598
+ (nk_cap_svebfdot_k * (supports_sve && supports_svebfdot)) |
599
+ (nk_cap_svesdot_k * (supports_sve && supports_svesdotmm)) |
600
+ (nk_cap_sve2_k * (supports_sve2)) | (nk_cap_sve2p1_k * (supports_sve2p1)) |
601
+ (nk_cap_sme_k * (supports_sme)) | (nk_cap_sme2_k * (supports_sme2)) |
602
+ (nk_cap_sme2p1_k * (supports_sme2p1)) | (nk_cap_smef64_k * (supports_smef64)) |
603
+ (nk_cap_smehalf_k * (supports_smehalf)) | (nk_cap_smebf16_k * (supports_smebf16)) |
604
+ (nk_cap_smefa64_k * (supports_smefa64)) | (nk_cap_serial_k));
605
+ #elif defined(NK_DEFINED_WINDOWS_)
606
+
607
+ unsigned supports_neon = 0, supports_dp = 0;
608
+
609
+ #if defined(PF_ARM_V8_INSTRUCTIONS_AVAILABLE)
610
+ supports_neon = IsProcessorFeaturePresent(PF_ARM_V8_INSTRUCTIONS_AVAILABLE);
611
+ #endif
612
+ #if defined(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE)
613
+ supports_dp = IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE);
614
+ #endif
615
+
616
+ return (nk_capability_t)((nk_cap_neon_k * (supports_neon)) | (nk_cap_neonsdot_k * (supports_neon && supports_dp)) |
617
+ (nk_cap_serial_k));
618
+
619
+ #else
620
+ return (nk_capability_t)(nk_cap_neon_k | nk_cap_serial_k);
621
+ #endif
622
+ }
623
+
624
+ #if defined(__clang__)
625
+ #pragma clang attribute pop
626
+ #elif defined(__GNUC__)
627
+ #pragma GCC pop_options
628
+ #endif
629
+
630
+ #endif // NK_TARGET_ARM_
631
+
632
#if NK_TARGET_RISCV_

/**
 *  @brief  Detects the RISC-V vector kernel families the current CPU and kernel support.
 *
 *  Base RVV comes from AT_HWCAP bit 21 ('V' - 'A'). On Linux, the vector sub-extensions are read with the
 *  `riscv_hwprobe` syscall (number 258, key RISCV_HWPROBE_KEY_IMA_EXT_0). On kernels without hwprobe the
 *  syscall fails and only base RVV is reported. FreeBSD has no hwprobe, so only base RVV is detected there.
 *
 *  @return Bitmask of `nk_cap_*_k` flags; `nk_cap_serial_k` is always set.
 */
NK_PUBLIC nk_capability_t nk_capabilities_riscv_(void) {
#if defined(NK_DEFINED_LINUX_)
    unsigned long hwcap = getauxval(AT_HWCAP);
    nk_capability_t caps = nk_cap_serial_k;
    if (hwcap & (1UL << 21)) {
        caps |= nk_cap_rvv_k;
        // Layout of `struct riscv_hwprobe`: a signed 64-bit key and an unsigned 64-bit value on every XLEN,
        // so fixed 64-bit fields are used instead of `long`, which is only 32 bits on RV32.
        struct {
            long long key;
            unsigned long long value;
        } pairs[1] = {{4, 0}}; // 4 = RISCV_HWPROBE_KEY_IMA_EXT_0
        if (syscall(258, pairs, (unsigned long)1, (unsigned long)0, (void *)0, (unsigned)0) == 0) {
            if (pairs[0].value & (1ULL << 30)) caps |= nk_cap_rvvhalf_k; // RISCV_HWPROBE_EXT_ZVFH
            if (pairs[0].value & (1ULL << 54)) caps |= nk_cap_rvvbf16_k; // RISCV_HWPROBE_EXT_ZVFBFWMA
            if (pairs[0].value & (1ULL << 17)) caps |= nk_cap_rvvbb_k;   // RISCV_HWPROBE_EXT_ZVBB (48 is Zawrs)
        }
    }
    return caps;
#elif defined(NK_DEFINED_FREEBSD_)
    unsigned long hwcap = 0;
    elf_aux_info(AT_HWCAP, &hwcap, sizeof(hwcap));
    nk_capability_t caps = nk_cap_serial_k;
    if (hwcap & (1UL << 21)) {
        caps |= nk_cap_rvv_k;
        // FreeBSD lacks the Linux hwprobe syscall (258),
        // so zvfh/zvfbfwma/zvbb are compile-time only.
    }
    return caps;
#else
    return nk_cap_serial_k;
#endif
}

#endif // NK_TARGET_RISCV_
667
+
668
#if NK_TARGET_WASM_

#if defined(__EMSCRIPTEN__) && NK_DYNAMIC_DISPATCH && !defined(NK_PYODIDE_SIDE_MODULE)
// Standalone Emscripten dynamic dispatch: EM_JS probes defined in c/numkong.c.
extern int nk_has_v128(void);
extern int nk_has_relaxed(void);
#elif defined(__wasi__) && NK_DEFINED_WASI_
// WASI hosted (NK_WASI_HOSTED=ON): the host provides capability probes via imports.
__attribute__((__import_module__("env"), __import_name__("nk_has_v128"))) extern int nk_has_v128(void);
__attribute__((__import_module__("env"), __import_name__("nk_has_relaxed"))) extern int nk_has_relaxed(void);
#endif

/**
 *  @brief  Reports whether relaxed-SIMD WebAssembly kernels may run.
 *
 *  - Hosted builds (Emscripten with dynamic dispatch, or WASI with host imports), excluding Pyodide side
 *    modules: ask the host at runtime via `nk_has_relaxed()`. The compiler emitting relaxed opcodes does not
 *    prove the current runtime can execute them.
 *  - Otherwise, when compiled with `__wasm_relaxed_simd__` or `__wasm_simd128__`: report the capability
 *    unconditionally. The module could not have been instantiated by a runtime that rejects its opcodes.
 *  - Otherwise: serial only.
 *
 *  @return `nk_cap_serial_k`, optionally combined with `nk_cap_v128relaxed_k`.
 */
NK_PUBLIC nk_capability_t nk_capabilities_v128relaxed_(void) {
    nk_capability_t detected = nk_cap_serial_k;
#if ((defined(__EMSCRIPTEN__) && NK_DYNAMIC_DISPATCH) || (defined(__wasi__) && NK_DEFINED_WASI_)) && \
    !defined(NK_PYODIDE_SIDE_MODULE)
    if (nk_has_relaxed()) detected = nk_cap_serial_k | nk_cap_v128relaxed_k;
#elif defined(__wasm_relaxed_simd__) || defined(__wasm_simd128__)
    detected = nk_cap_serial_k | nk_cap_v128relaxed_k;
#endif
    return detected;
}

#endif // NK_TARGET_WASM_
698
+
699
+ NK_PUBLIC int nk_configure_thread_(nk_capability_t capabilities) {
700
+ #if NK_TARGET_X86_
701
+ return nk_configure_thread_x86_(capabilities);
702
+ #endif
703
+ #if NK_TARGET_ARM_
704
+ return nk_configure_thread_arm_(capabilities);
705
+ #endif
706
+ nk_unused_(capabilities);
707
+ return 1; // success — no platform-specific thread configuration needed
708
+ }
709
+
710
+ NK_PUBLIC nk_capability_t nk_capabilities_(void) {
711
+ #if NK_TARGET_X86_
712
+ return nk_capabilities_x86_();
713
+ #endif
714
+ #if NK_TARGET_ARM_
715
+ return nk_capabilities_arm_();
716
+ #endif
717
+ #if NK_TARGET_RISCV_
718
+ return nk_capabilities_riscv_();
719
+ #endif
720
+ #if NK_TARGET_WASM_
721
+ return nk_capabilities_v128relaxed_();
722
+ #endif
723
+ return nk_cap_serial_k;
724
+ }
725
+
726
+ #if NK_DYNAMIC_DISPATCH
727
+
728
+ NK_DYNAMIC nk_capability_t nk_capabilities(void);
729
+ NK_DYNAMIC int nk_configure_thread(nk_capability_t);
730
+ NK_DYNAMIC int nk_uses_dynamic_dispatch(void);
731
+ NK_DYNAMIC void nk_dispatch_table_update(nk_capability_t);
732
+ NK_DYNAMIC void nk_find_kernel_punned(nk_kernel_kind_t kind, nk_dtype_t dtype, nk_capability_t viable,
733
+ nk_kernel_punned_t *kernel_output, nk_capability_t *capability_output);
734
+
735
+ #else
736
+
737
+ NK_PUBLIC int nk_uses_dynamic_dispatch(void) { return 0; }
738
+ NK_PUBLIC int nk_configure_thread(nk_capability_t c) { return nk_configure_thread_(c); }
739
+ NK_PUBLIC nk_capability_t nk_capabilities(void) { return nk_capabilities_(); }
740
+ NK_PUBLIC void nk_dispatch_table_update(nk_capability_t caps) { nk_unused_(caps); }
741
+
742
+ #endif
743
+
744
+ #ifdef __cplusplus
745
+ }
746
+
747
+ #endif
748
+ #endif // NK_CAPABILITIES_H