numkong 7.0.0 → 7.4.2

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (315):
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -28,7 +28,7 @@
28
28
  *
29
29
  * @section references References
30
30
  *
31
- * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
31
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
32
32
  * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
33
33
  *
34
34
  */
@@ -84,7 +84,7 @@
84
84
  *
85
85
  * @section references References
86
86
  *
87
- * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide
87
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
88
88
  * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics
89
89
  * - Detecting target CPU features at compile time: https://stackoverflow.com/a/28939692/2766161
90
90
  */
@@ -95,8 +95,8 @@
95
95
  #include "numkong/types.h" // `nk_u64_t`, `NK_DEFINED_LINUX_`
96
96
 
97
97
  #define NK_VERSION_MAJOR 7
98
- #define NK_VERSION_MINOR 0
99
- #define NK_VERSION_PATCH 0
98
+ #define NK_VERSION_MINOR 4
99
+ #define NK_VERSION_PATCH 2
100
100
 
101
101
  /**
102
102
  * @brief Removes compile-time dispatching, and replaces it with runtime dispatching.
@@ -117,6 +117,9 @@
117
117
  // Detect POSIX extensions availability for signal handling.
118
118
  // POSIX extensions provide `sigaction`, `sigjmp_buf`, and `sigsetjmp` for safe signal handling.
119
119
  // These are needed on Linux ARM for safely testing `mrs` instruction availability.
120
+ #if defined(NK_DEFINED_LINUX_) || defined(NK_DEFINED_FREEBSD_)
121
+ #include <unistd.h> // `_POSIX_VERSION`
122
+ #endif
120
123
  #if (defined(NK_DEFINED_LINUX_) || defined(NK_DEFINED_FREEBSD_)) && defined(_POSIX_VERSION)
121
124
  #include <setjmp.h> // `sigjmp_buf`, `sigsetjmp`, `siglongjmp`
122
125
  #include <signal.h> // `sigaction`, `SIGILL`
@@ -141,6 +144,14 @@ extern long syscall(long, ...);
141
144
  #endif
142
145
  #endif
143
146
 
147
+ #if defined(NK_DEFINED_LINUX_) && NK_TARGET_LOONGARCH_
148
+ #include <sys/auxv.h> // `getauxval`, `AT_HWCAP`
149
+ #endif
150
+
151
+ #if defined(NK_DEFINED_LINUX_) && NK_TARGET_POWER_
152
+ #include <sys/auxv.h> // `getauxval`, `AT_HWCAP`
153
+ #endif
154
+
144
155
  // On FreeBSD RISC-V, we use elf_aux_info for capability detection
145
156
  #if defined(NK_DEFINED_FREEBSD_) && NK_TARGET_RISCV_
146
157
  #include <sys/auxv.h> // `elf_aux_info`, `AT_HWCAP`
@@ -286,8 +297,13 @@ typedef nk_u64_t nk_capability_t;
286
297
  #define nk_cap_smelut2_k ((nk_capability_t)1 << 32)
287
298
  #define nk_cap_rvvbb_k ((nk_capability_t)1 << 33)
288
299
  #define nk_cap_sierra_k ((nk_capability_t)1 << 34)
300
+ #define nk_cap_smebi32_k ((nk_capability_t)1 << 35)
301
+ #define nk_cap_loongsonasx_k ((nk_capability_t)1 << 36)
302
+ #define nk_cap_powervsx_k ((nk_capability_t)1 << 37)
303
+ #define nk_cap_diamond_k ((nk_capability_t)1 << 38)
304
+ #define nk_cap_neonfp8_k ((nk_capability_t)1 << 39)
289
305
 
290
- typedef void (*nk_metric_dense_punned_t)(void const *a, void const *b, nk_size_t n, void *d);
306
+ typedef void (*nk_metric_dense_punned_t)(void const *a, void const *b, nk_size_t dimensions, void *result);
291
307
 
292
308
  typedef void (*nk_sparse_intersect_punned_t)(void const *a, void const *b, nk_size_t a_length, nk_size_t b_length,
293
309
  void *result, nk_size_t *count);
@@ -295,25 +311,26 @@ typedef void (*nk_sparse_intersect_punned_t)(void const *a, void const *b, nk_si
295
311
  typedef void (*nk_sparse_dot_punned_t)(void const *a, void const *b, void const *a_weights, void const *b_weights,
296
312
  nk_size_t a_length, nk_size_t b_length, void *product);
297
313
 
298
- typedef void (*nk_metric_curved_punned_t)(void const *a, void const *b, void const *c, nk_size_t n, void *d);
314
+ typedef void (*nk_metric_curved_punned_t)(void const *a, void const *b, void const *c, nk_size_t dimensions,
315
+ void *result);
299
316
 
300
317
  typedef void (*nk_metric_geospatial_punned_t)(void const *a_lats, void const *a_lons, void const *b_lats,
301
- void const *b_lons, nk_size_t n, void *results);
318
+ void const *b_lons, nk_size_t count, void *results);
302
319
 
303
- typedef void (*nk_each_scale_punned_t)(void const *a, nk_size_t n, void const *alpha, void const *beta, void *y);
320
+ typedef void (*nk_each_scale_punned_t)(void const *a, nk_size_t count, void const *alpha, void const *beta, void *y);
304
321
 
305
- typedef void (*nk_each_sum_punned_t)(void const *a, void const *b, nk_size_t n, void *y);
322
+ typedef void (*nk_each_sum_punned_t)(void const *a, void const *b, nk_size_t count, void *y);
306
323
 
307
- typedef void (*nk_each_blend_punned_t)(void const *a, void const *b, nk_size_t n, void const *alpha, void const *beta,
308
- void *y);
324
+ typedef void (*nk_each_blend_punned_t)(void const *a, void const *b, nk_size_t count, void const *alpha,
325
+ void const *beta, void *y);
309
326
 
310
- typedef void (*nk_each_fma_punned_t)(void const *a, void const *b, void const *c, nk_size_t n, void const *alpha,
327
+ typedef void (*nk_each_fma_punned_t)(void const *a, void const *b, void const *c, nk_size_t count, void const *alpha,
311
328
  void const *beta, void *y);
312
329
 
313
- typedef void (*nk_kernel_trigonometry_punned_t)(void const *x, nk_size_t n, void *y);
330
+ typedef void (*nk_kernel_trigonometry_punned_t)(void const *x, nk_size_t count, void *y);
314
331
 
315
- typedef void (*nk_metric_mesh_punned_t)(void const *a, void const *b, nk_size_t n, void *a_centroid, void *b_centroid,
316
- void *rotation, void *scale, void *d);
332
+ typedef void (*nk_metric_mesh_punned_t)(void const *a, void const *b, nk_size_t points_count, void *a_centroid,
333
+ void *b_centroid, void *rotation, void *scale, void *result);
317
334
 
318
335
  typedef void (*nk_kernel_reduce_moments_punned_t)(void const *data, nk_size_t count, nk_size_t stride_bytes,
319
336
  void *sum_ptr, void *sumsq_ptr);
@@ -322,47 +339,51 @@ typedef void (*nk_kernel_reduce_minmax_punned_t)(void const *data, nk_size_t cou
322
339
  void *min_value, nk_size_t *min_index, void *max_value,
323
340
  nk_size_t *max_index);
324
341
 
325
- typedef nk_size_t (*nk_dots_packed_size_punned_t)(nk_size_t width, nk_size_t depth);
342
+ typedef nk_size_t (*nk_dots_packed_size_punned_t)(nk_size_t columns, nk_size_t depth);
326
343
 
327
- typedef void (*nk_dots_pack_punned_t)(void const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
344
+ typedef void (*nk_dots_pack_punned_t)(void const *b, nk_size_t columns, nk_size_t depth, nk_size_t b_stride_bytes,
328
345
  void *b_packed);
329
346
 
330
- typedef void (*nk_dots_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height, nk_size_t width,
331
- nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
347
+ typedef void (*nk_dots_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t rows, nk_size_t columns,
348
+ nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes);
332
349
 
333
- typedef void (*nk_dots_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
334
- void *result, nk_size_t result_stride, nk_size_t row_start,
335
- nk_size_t row_count);
350
+ typedef void (*nk_dots_symmetric_punned_t)(void const *vectors, nk_size_t vectors_count, nk_size_t depth,
351
+ nk_size_t stride_bytes, void *result, nk_size_t result_stride_bytes,
352
+ nk_size_t row_start, nk_size_t row_count);
336
353
 
337
- typedef void (*nk_hammings_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height,
338
- nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
354
+ typedef void (*nk_hammings_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t rows,
355
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_bytes,
356
+ nk_size_t c_stride_bytes);
339
357
 
340
- typedef void (*nk_hammings_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth,
341
- nk_size_t stride, void *result, nk_size_t result_stride,
358
+ typedef void (*nk_hammings_symmetric_punned_t)(void const *vectors, nk_size_t vectors_count, nk_size_t depth,
359
+ nk_size_t stride_bytes, void *result, nk_size_t result_stride_bytes,
342
360
  nk_size_t row_start, nk_size_t row_count);
343
361
 
344
- typedef void (*nk_jaccards_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height,
345
- nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
362
+ typedef void (*nk_jaccards_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t rows,
363
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_bytes,
364
+ nk_size_t c_stride_bytes);
346
365
 
347
- typedef void (*nk_jaccards_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth,
348
- nk_size_t stride, void *result, nk_size_t result_stride,
366
+ typedef void (*nk_jaccards_symmetric_punned_t)(void const *vectors, nk_size_t vectors_count, nk_size_t depth,
367
+ nk_size_t stride_bytes, void *result, nk_size_t result_stride_bytes,
349
368
  nk_size_t row_start, nk_size_t row_count);
350
369
 
351
- typedef void (*nk_angulars_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height,
352
- nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
353
- typedef void (*nk_angulars_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth,
354
- nk_size_t stride, void *result, nk_size_t result_stride,
370
+ typedef void (*nk_angulars_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t rows,
371
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_bytes,
372
+ nk_size_t c_stride_bytes);
373
+ typedef void (*nk_angulars_symmetric_punned_t)(void const *vectors, nk_size_t vectors_count, nk_size_t depth,
374
+ nk_size_t stride_bytes, void *result, nk_size_t result_stride_bytes,
355
375
  nk_size_t row_start, nk_size_t row_count);
356
- typedef void (*nk_euclideans_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t height,
357
- nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
358
- typedef void (*nk_euclideans_symmetric_punned_t)(void const *vectors, nk_size_t n_vectors, nk_size_t depth,
359
- nk_size_t stride, void *result, nk_size_t result_stride,
376
+ typedef void (*nk_euclideans_packed_punned_t)(void const *a, void const *b_packed, void *c, nk_size_t rows,
377
+ nk_size_t columns, nk_size_t depth, nk_size_t a_stride_bytes,
378
+ nk_size_t c_stride_bytes);
379
+ typedef void (*nk_euclideans_symmetric_punned_t)(void const *vectors, nk_size_t vectors_count, nk_size_t depth,
380
+ nk_size_t stride_bytes, void *result, nk_size_t result_stride_bytes,
360
381
  nk_size_t row_start, nk_size_t row_count);
361
382
 
362
383
  typedef void (*nk_maxsim_packed_punned_t)(void const *q_packed, void const *d_packed, nk_size_t query_count,
363
384
  nk_size_t document_count, nk_size_t depth, void *result);
364
385
 
365
- typedef void (*nk_kernel_cast_punned_t)(void const *from, nk_dtype_t from_type, nk_size_t n, void *to,
386
+ typedef void (*nk_kernel_cast_punned_t)(void const *from, nk_dtype_t from_type, nk_size_t count, void *to,
366
387
  nk_dtype_t to_type);
367
388
 
368
389
  typedef void (*nk_kernel_punned_t)(void *);
@@ -370,19 +391,6 @@ typedef void (*nk_kernel_punned_t)(void *);
370
391
  #if NK_TARGET_X86_
371
392
 
372
393
  NK_PUBLIC int nk_configure_thread_x86_(nk_capability_t capabilities) {
373
- #if defined(_MSC_VER)
374
- unsigned int mxcsr = _mm_getcsr();
375
- mxcsr |= 1 << 15;
376
- mxcsr |= 1 << 6;
377
- _mm_setcsr(mxcsr);
378
- #else
379
- unsigned int mxcsr;
380
- __asm__ __volatile__("stmxcsr %0" : "=m"(mxcsr));
381
- mxcsr |= 1 << 15;
382
- mxcsr |= 1 << 6;
383
- __asm__ __volatile__("ldmxcsr %0" : : "m"(mxcsr));
384
- #endif
385
-
386
394
  #if NK_TARGET_SAPPHIREAMX
387
395
  if (capabilities & nk_cap_sapphireamx_k) {
388
396
  #if defined(NK_DEFINED_LINUX_)
@@ -407,13 +415,17 @@ NK_PUBLIC nk_capability_t nk_capabilities_x86_(void) {
407
415
  struct separate_t {
408
416
  unsigned eax, ebx, ecx, edx;
409
417
  } named;
410
- } info1, info7, info7sub1;
418
+ } info0, info1, info7, info7sub1;
411
419
 
412
420
  #if defined(_MSC_VER)
421
+ __cpuidex(info0.array, 0, 0);
413
422
  __cpuidex(info1.array, 1, 0);
414
423
  __cpuidex(info7.array, 7, 0);
415
424
  __cpuidex(info7sub1.array, 7, 1);
416
425
  #else
426
+ __asm__ __volatile__("cpuid"
427
+ : "=a"(info0.named.eax), "=b"(info0.named.ebx), "=c"(info0.named.ecx), "=d"(info0.named.edx)
428
+ : "a"(0), "c"(0));
417
429
  __asm__ __volatile__("cpuid"
418
430
  : "=a"(info1.named.eax), "=b"(info1.named.ebx), "=c"(info1.named.ecx), "=d"(info1.named.edx)
419
431
  : "a"(1), "c"(0));
@@ -426,6 +438,7 @@ NK_PUBLIC nk_capability_t nk_capabilities_x86_(void) {
426
438
  : "a"(7), "c"(1));
427
439
  #endif
428
440
 
441
+ unsigned max_leaf = info0.named.eax;
429
442
  unsigned supports_avx2 = (info7.named.ebx & 0x00000020) != 0;
430
443
  unsigned supports_f16c = (info1.named.ecx & 0x20000000) != 0;
431
444
  unsigned supports_fma = (info1.named.ecx & 0x00001000) != 0;
@@ -446,12 +459,29 @@ NK_PUBLIC nk_capability_t nk_capabilities_x86_(void) {
446
459
  unsigned supports_avxvnni = (info7sub1.named.eax & 0x00000010) != 0;
447
460
  unsigned supports_avxvnniint8 = (info7sub1.named.edx & 0x00000010) != 0;
448
461
 
462
+ // AVX10.2 detection via CPUID leaf 0x24, subleaf 0.
463
+ // EBX[7:0] contains the AVX10 convergent vector ISA version number.
464
+ unsigned supports_avx10v2 = 0;
465
+ if (max_leaf >= 0x24) {
466
+ union four_registers_t info24;
467
+ #if defined(_MSC_VER)
468
+ __cpuidex(info24.array, 0x24, 0);
469
+ #else
470
+ __asm__ __volatile__("cpuid"
471
+ : "=a"(info24.named.eax), "=b"(info24.named.ebx), "=c"(info24.named.ecx),
472
+ "=d"(info24.named.edx)
473
+ : "a"(0x24), "c"(0));
474
+ #endif
475
+ supports_avx10v2 = (info24.named.ebx & 0xFF) >= 2;
476
+ }
477
+
449
478
  unsigned supports_haswell = supports_avx2 && supports_f16c && supports_fma;
450
479
  unsigned supports_skylake = supports_avx512f;
451
480
  unsigned supports_icelake = supports_avx512vnni && supports_avx512ifma && supports_avx512bitalg &&
452
481
  supports_avx512vbmi && supports_avx512vbmi2 && supports_avx512vpopcntdq;
453
482
  unsigned supports_genoa = supports_avx512bf16;
454
483
  unsigned supports_sapphire = supports_avx512fp16;
484
+ unsigned supports_diamond = supports_avx10v2 && supports_sapphire;
455
485
  unsigned supports_turin = supports_avx512vp2intersect && supports_avx512bf16;
456
486
  unsigned supports_sierra = supports_haswell && supports_avxvnniint8;
457
487
  unsigned supports_alder = supports_haswell && supports_avxvnni;
@@ -460,9 +490,9 @@ NK_PUBLIC nk_capability_t nk_capabilities_x86_(void) {
460
490
 
461
491
  return (nk_capability_t)((nk_cap_haswell_k * supports_haswell) | (nk_cap_skylake_k * supports_skylake) |
462
492
  (nk_cap_icelake_k * supports_icelake) | (nk_cap_genoa_k * supports_genoa) |
463
- (nk_cap_sapphire_k * supports_sapphire) | (nk_cap_turin_k * supports_turin) |
464
- (nk_cap_sierra_k * supports_sierra) | (nk_cap_alder_k * supports_alder) |
465
- (nk_cap_sapphireamx_k * supports_sapphireamx) |
493
+ (nk_cap_diamond_k * supports_diamond) | (nk_cap_sapphire_k * supports_sapphire) |
494
+ (nk_cap_turin_k * supports_turin) | (nk_cap_sierra_k * supports_sierra) |
495
+ (nk_cap_alder_k * supports_alder) | (nk_cap_sapphireamx_k * supports_sapphireamx) |
466
496
  (nk_cap_graniteamx_k * supports_graniteamx) | (nk_cap_serial_k));
467
497
  }
468
498
 
@@ -486,28 +516,66 @@ static void nk_mrs_test_sigill_handler_(int sig) {
486
516
  #endif
487
517
 
488
518
  NK_PUBLIC int nk_configure_thread_arm_(nk_capability_t capabilities) {
519
+ #if defined(_MSC_VER)
489
520
  nk_unused_(capabilities);
521
+ return 1;
522
+ #else
523
+ // FPCR.EBF (bit 13) — requires FEAT_EBF16:
524
+ // Enables fused BF16 dot-product semantics for BFDOT/BFMOPA/BFMMLA.
525
+ // Without it, each bf16×bf16 product is individually rounded (Round-to-Odd) before
526
+ // summation (3-way rounding). With EBF=1, intermediates are summed before rounding,
527
+ // matching x86 VDPBF16PS (Genoa/Sapphire Rapids) precision.
528
+ //
529
+ // FPCR.AH (bit 1) — requires FEAT_AFP — is intentionally NOT set:
530
+ // It enables alternate floating-point behavior (FEAT_RPRES: 12-bit FRECPE/FRSQRTE),
531
+ // but also changes sign-bit handling in ways that break CPython's `decimal` module
532
+ // (`Decimal.from_float()` drops the sign of negative values). The FEAT_RPRES benefit
533
+ // (saving one Newton-Raphson iteration) is not worth the process-wide side effects.
534
+ //
535
+ // EBF defaults to 0 on process creation (kernel zeroes FPCR). Setting it is
536
+ // ABI-legal per AAPCS64. Writing this bit on hardware without FEAT_EBF16
537
+ // is unsafe (it is RES0), so we gate on feature detection.
538
+ unsigned long fpcr_desired = 0;
539
+
490
540
  #if defined(NK_DEFINED_APPLE_)
491
- int is_success = fesetenv(FE_DFL_DISABLE_DENORMS_ENV) == 0;
492
- return is_success;
541
+ nk_unused_(capabilities);
542
+ size_t sysctl_size = sizeof(unsigned);
543
+ unsigned has_ebf16 = 0;
544
+ if (sysctlbyname("hw.optional.arm.FEAT_EBF16", &has_ebf16, &sysctl_size, NULL, 0) != 0) has_ebf16 = 0;
545
+ if (has_ebf16) fpcr_desired |= (1UL << 13);
546
+
493
547
  #elif defined(NK_DEFINED_LINUX_) || defined(NK_DEFINED_FREEBSD_)
494
- uint64_t fpcr;
495
- __asm__ volatile("mrs %0, fpcr" : "=r"(fpcr));
496
- fpcr |= (1 << 19);
497
- fpcr |= (1 << 24);
498
- fpcr |= (1 << 25);
499
- __asm__ volatile("msr fpcr, %0" : : "r"(fpcr));
500
- return 1;
548
+ // Read ID registers via MRS. Only safe if MRS is known to work — indicated by
549
+ // capabilities beyond basic NEON (nk_capabilities_arm_ validated MRS via sigaction probe).
550
+ if (capabilities & ~(nk_cap_neon_k | nk_cap_serial_k)) {
551
+ // FEAT_EBF16: ID_AA64ISAR1_EL1.BF16 bits [47:44] >= 0b0010
552
+ register unsigned long isar1_val __asm__("x0");
553
+ __asm__ __volatile__(".inst 0xD5380620" : "=r"(isar1_val)); // MRS x0, ID_AA64ISAR1_EL1
554
+ if (((isar1_val >> 44) & 0xF) >= 2) fpcr_desired |= (1UL << 13);
555
+ }
556
+ else { nk_unused_(capabilities); }
501
557
  #else
502
- return 0;
558
+ nk_unused_(capabilities);
503
559
  #endif
560
+
561
+ if (fpcr_desired) {
562
+ unsigned long fpcr_val;
563
+ __asm__ __volatile__("mrs %0, fpcr" : "=r"(fpcr_val));
564
+ if ((fpcr_val & fpcr_desired) != fpcr_desired) {
565
+ fpcr_val |= fpcr_desired;
566
+ __asm__ __volatile__("msr fpcr, %0" : : "r"(fpcr_val));
567
+ }
568
+ }
569
+ return 1;
570
+ #endif // _MSC_VER
504
571
  }
505
572
 
506
573
  NK_PUBLIC nk_capability_t nk_capabilities_arm_(void) {
507
574
  #if defined(NK_DEFINED_APPLE_)
508
575
  size_t size = sizeof(unsigned);
509
576
  unsigned supports_neon = 0, supports_fp16 = 0, supports_fhm = 0, supports_bf16 = 0, supports_i8mm = 0;
510
- unsigned supports_sme = 0, supports_sme2 = 0, supports_smef64 = 0, supports_smehalf = 0, supports_sme2p1 = 0;
577
+ unsigned supports_sme = 0, supports_sme2 = 0, supports_smef64 = 0, supports_smehalf = 0, supports_sme2p1 = 0,
578
+ supports_smebi32 = 0;
511
579
  if (sysctlbyname("hw.optional.neon", &supports_neon, &size, NULL, 0) != 0) supports_neon = 0;
512
580
  if (sysctlbyname("hw.optional.arm.FEAT_FP16", &supports_fp16, &size, NULL, 0) != 0) supports_fp16 = 0;
513
581
  if (sysctlbyname("hw.optional.arm.FEAT_FHM", &supports_fhm, &size, NULL, 0) != 0) supports_fhm = 0;
@@ -518,6 +586,7 @@ NK_PUBLIC nk_capability_t nk_capabilities_arm_(void) {
518
586
  if (sysctlbyname("hw.optional.arm.FEAT_SME_F64F64", &supports_smef64, &size, NULL, 0) != 0) supports_smef64 = 0;
519
587
  if (sysctlbyname("hw.optional.arm.FEAT_SME_F16F16", &supports_smehalf, &size, NULL, 0) != 0) supports_smehalf = 0;
520
588
  if (sysctlbyname("hw.optional.arm.FEAT_SME2p1", &supports_sme2p1, &size, NULL, 0) != 0) supports_sme2p1 = 0;
589
+ if (sysctlbyname("hw.optional.arm.SME_BI32I32", &supports_smebi32, &size, NULL, 0) != 0) supports_smebi32 = 0;
521
590
 
522
591
  return (nk_capability_t)((nk_cap_neon_k * (supports_neon)) |
523
592
  (nk_cap_neonhalf_k * (supports_neon && supports_fp16)) |
@@ -526,7 +595,8 @@ NK_PUBLIC nk_capability_t nk_capabilities_arm_(void) {
526
595
  (nk_cap_neonsdot_k * (supports_neon && supports_i8mm)) | (nk_cap_sme_k * (supports_sme)) |
527
596
  (nk_cap_sme2_k * (supports_sme2)) | (nk_cap_sme2p1_k * (supports_sme2p1)) |
528
597
  (nk_cap_smef64_k * (supports_smef64)) | (nk_cap_smehalf_k * (supports_smehalf)) |
529
- (nk_cap_smebf16_k * (supports_sme)) | (nk_cap_serial_k));
598
+ (nk_cap_smebf16_k * (supports_sme)) | (nk_cap_smebi32_k * (supports_smebi32)) |
599
+ (nk_cap_serial_k));
530
600
 
531
601
  #elif defined(NK_DEFINED_LINUX_) || defined(NK_DEFINED_FREEBSD_)
532
602
 
@@ -539,8 +609,8 @@ NK_PUBLIC nk_capability_t nk_capabilities_arm_(void) {
539
609
  int mrs_works = 0;
540
610
  if (sigaction(SIGILL, &action_new, &action_old) == 0) {
541
611
  if (sigsetjmp(nk_mrs_test_jump_buffer_, 1) == 0) {
542
- unsigned long midr_value;
543
- __asm__ __volatile__("mrs %0, MIDR_EL1" : "=r"(midr_value));
612
+ register unsigned long midr_value __asm__("x0");
613
+ __asm__ __volatile__(".inst 0xD5380000" : "=r"(midr_value)); // MRS x0, MIDR_EL1
544
614
  mrs_works = 1;
545
615
  }
546
616
  sigaction(SIGILL, &action_old, NULL);
@@ -551,57 +621,82 @@ NK_PUBLIC nk_capability_t nk_capabilities_arm_(void) {
551
621
  return (nk_capability_t)(nk_cap_neon_k | nk_cap_serial_k);
552
622
  #endif
553
623
 
554
- unsigned long id_aa64isar0_el1 = 0, id_aa64isar1_el1 = 0, id_aa64pfr0_el1 = 0, id_aa64zfr0_el1 = 0;
624
+ unsigned long id_aa64isar0_el1 = 0, id_aa64isar1_el1 = 0, id_aa64pfr0_el1 = 0, id_aa64zfr0_el1 = 0,
625
+ id_aa64fpfr0_el1 = 0;
555
626
 
556
- __asm__ __volatile__("mrs %0, ID_AA64ISAR0_EL1" : "=r"(id_aa64isar0_el1));
627
+ register unsigned long __isar0 __asm__("x0");
628
+ __asm__ __volatile__(".inst 0xD5380600" : "=r"(__isar0)); // MRS x0, ID_AA64ISAR0_EL1
629
+ id_aa64isar0_el1 = __isar0;
557
630
  unsigned supports_integer_dot_products = ((id_aa64isar0_el1 >> 44) & 0xF) >= 1;
558
631
  unsigned supports_fhm = ((id_aa64isar0_el1 >> 48) & 0xF) >= 1;
559
- __asm__ __volatile__("mrs %0, ID_AA64ISAR1_EL1" : "=r"(id_aa64isar1_el1));
632
+ register unsigned long __isar1 __asm__("x0");
633
+ __asm__ __volatile__(".inst 0xD5380620" : "=r"(__isar1)); // MRS x0, ID_AA64ISAR1_EL1
634
+ id_aa64isar1_el1 = __isar1;
560
635
  unsigned supports_i8mm = ((id_aa64isar1_el1 >> 52) & 0xF) >= 1;
561
636
  unsigned supports_bf16 = ((id_aa64isar1_el1 >> 44) & 0xF) >= 1;
562
637
 
563
- __asm__ __volatile__("mrs %0, ID_AA64PFR0_EL1" : "=r"(id_aa64pfr0_el1));
638
+ register unsigned long __pfr0 __asm__("x0");
639
+ __asm__ __volatile__(".inst 0xD5380400" : "=r"(__pfr0)); // MRS x0, ID_AA64PFR0_EL1
640
+ id_aa64pfr0_el1 = __pfr0;
564
641
  unsigned supports_sve = ((id_aa64pfr0_el1 >> 32) & 0xF) >= 1;
565
642
  unsigned supports_fp16 = ((id_aa64pfr0_el1 >> 20) & 0xF) == 0x1;
566
643
  unsigned supports_neon = ((id_aa64pfr0_el1 >> 20) & 0xF) != 0xF;
567
644
 
568
- if (supports_sve) __asm__ __volatile__("mrs %0, ID_AA64ZFR0_EL1" : "=r"(id_aa64zfr0_el1));
645
+ if (supports_sve) {
646
+ register unsigned long __zfr0 __asm__("x0");
647
+ __asm__ __volatile__(".inst 0xD5380480" : "=r"(__zfr0)); // MRS x0, ID_AA64ZFR0_EL1
648
+ id_aa64zfr0_el1 = __zfr0;
649
+ }
569
650
  unsigned supports_svesdotmm = ((id_aa64zfr0_el1 >> 44) & 0xF) >= 1;
570
651
  unsigned supports_svebfdot = ((id_aa64zfr0_el1 >> 20) & 0xF) >= 1;
571
652
  unsigned supports_sve2 = ((id_aa64zfr0_el1) & 0xF) >= 1;
572
653
  unsigned supports_sve2p1 = ((id_aa64zfr0_el1) & 0xF) >= 2;
573
654
 
655
+ register unsigned long __fpfr0 __asm__("x0");
656
+ __asm__ __volatile__(".inst 0xD53804E0" : "=r"(__fpfr0)); // MRS x0, ID_AA64FPFR0_EL1
657
+ id_aa64fpfr0_el1 = __fpfr0;
658
+ unsigned supports_fp8dot4 = ((id_aa64fpfr0_el1 >> 29) & 0x1) >= 1;
659
+
574
660
  unsigned long id_aa64pfr1_el1 = 0, id_aa64smfr0_el1 = 0;
575
- __asm__ __volatile__("mrs %0, ID_AA64PFR1_EL1" : "=r"(id_aa64pfr1_el1));
661
+ register unsigned long __pfr1 __asm__("x0");
662
+ __asm__ __volatile__(".inst 0xD5380420" : "=r"(__pfr1)); // MRS x0, ID_AA64PFR1_EL1
663
+ id_aa64pfr1_el1 = __pfr1;
576
664
  unsigned supports_sme = ((id_aa64pfr1_el1 >> 24) & 0xF) >= 1;
577
665
 
578
666
  unsigned supports_sme2 = 0, supports_sme2p1 = 0;
579
667
  unsigned supports_smef64 = 0, supports_smehalf = 0, supports_smebf16 = 0;
580
- unsigned supports_smelut2 = 0, supports_smefa64 = 0;
668
+ unsigned supports_smebi32 = 0, supports_smelut2 = 0, supports_smefa64 = 0;
581
669
  if (supports_sme) {
582
- __asm__ __volatile__("mrs %0, ID_AA64SMFR0_EL1" : "=r"(id_aa64smfr0_el1));
670
+ // MRS x0, ID_AA64SMFR0_EL1 (S3_0_C0_C4_5) — encoded as raw .inst because some
671
+ // assemblers (Clang 21 in Android NDK r29) reject the symbolic register name.
672
+ // Encoding: 0xD53804A0 = MRS x0, op0=3, op1=0, CRn=0, CRm=4, op2=5, Rt=0.
673
+ register unsigned long __smfr0 __asm__("x0");
674
+ __asm__ __volatile__(".inst 0xD53804A0" : "=r"(__smfr0));
675
+ id_aa64smfr0_el1 = __smfr0;
583
676
  unsigned sme_version = (id_aa64smfr0_el1 >> 56) & 0xF;
584
677
  supports_sme2 = sme_version >= 1;
585
678
  supports_sme2p1 = sme_version >= 2;
586
679
  supports_smef64 = (id_aa64smfr0_el1 >> 48) & 0x1;
587
680
  supports_smehalf = (id_aa64smfr0_el1 >> 42) & 0x1;
588
681
  supports_smebf16 = (id_aa64smfr0_el1 >> 44) & 0x1;
682
+ supports_smebi32 = (id_aa64smfr0_el1 >> 33) & 0x1;
589
683
  supports_smefa64 = (id_aa64smfr0_el1 >> 63) & 0x1;
590
684
  }
591
685
 
592
- return (nk_capability_t)((nk_cap_neon_k * (supports_neon)) |
593
- (nk_cap_neonhalf_k * (supports_neon && supports_fp16)) |
594
- (nk_cap_neonfhm_k * (supports_neon && supports_fhm)) |
595
- (nk_cap_neonbfdot_k * (supports_neon && supports_bf16)) |
596
- (nk_cap_neonsdot_k * (supports_neon && supports_i8mm && supports_integer_dot_products)) |
597
- (nk_cap_sve_k * (supports_sve)) | (nk_cap_svehalf_k * (supports_sve && supports_fp16)) |
598
- (nk_cap_svebfdot_k * (supports_sve && supports_svebfdot)) |
599
- (nk_cap_svesdot_k * (supports_sve && supports_svesdotmm)) |
600
- (nk_cap_sve2_k * (supports_sve2)) | (nk_cap_sve2p1_k * (supports_sve2p1)) |
601
- (nk_cap_sme_k * (supports_sme)) | (nk_cap_sme2_k * (supports_sme2)) |
602
- (nk_cap_sme2p1_k * (supports_sme2p1)) | (nk_cap_smef64_k * (supports_smef64)) |
603
- (nk_cap_smehalf_k * (supports_smehalf)) | (nk_cap_smebf16_k * (supports_smebf16)) |
604
- (nk_cap_smefa64_k * (supports_smefa64)) | (nk_cap_serial_k));
686
+ return (
687
+ nk_capability_t)((nk_cap_neon_k * (supports_neon)) | (nk_cap_neonhalf_k * (supports_neon && supports_fp16)) |
688
+ (nk_cap_neonfhm_k * (supports_neon && supports_fhm)) |
689
+ (nk_cap_neonbfdot_k * (supports_neon && supports_bf16)) |
690
+ (nk_cap_neonsdot_k * (supports_neon && supports_i8mm && supports_integer_dot_products)) |
691
+ (nk_cap_neonfp8_k * (supports_neon && supports_fp8dot4)) | //
692
+ (nk_cap_sve_k * (supports_sve)) | (nk_cap_svehalf_k * (supports_sve && supports_fp16)) |
693
+ (nk_cap_svebfdot_k * (supports_sve && supports_svebfdot)) |
694
+ (nk_cap_svesdot_k * (supports_sve && supports_svesdotmm)) | (nk_cap_sve2_k * (supports_sve2)) |
695
+ (nk_cap_sve2p1_k * (supports_sve2p1)) | (nk_cap_sme_k * (supports_sme)) |
696
+ (nk_cap_sme2_k * (supports_sme2)) | (nk_cap_sme2p1_k * (supports_sme2p1)) |
697
+ (nk_cap_smef64_k * (supports_smef64)) | (nk_cap_smehalf_k * (supports_smehalf)) |
698
+ (nk_cap_smebf16_k * (supports_smebf16)) | (nk_cap_smebi32_k * (supports_smebi32)) |
699
+ (nk_cap_smefa64_k * (supports_smefa64)) | (nk_cap_serial_k));
605
700
  #elif defined(NK_DEFINED_WINDOWS_)
606
701
 
607
702
  unsigned supports_neon = 0, supports_dp = 0;
@@ -665,6 +760,40 @@ NK_PUBLIC nk_capability_t nk_capabilities_riscv_(void) {
665
760
 
666
761
  #endif // NK_TARGET_RISCV_
667
762
 
763
+ #if NK_TARGET_LOONGARCH_
764
+
765
+ NK_PUBLIC nk_capability_t nk_capabilities_loongarch_(void) {
766
+ #if defined(NK_DEFINED_LINUX_)
767
+ unsigned long hwcap = getauxval(AT_HWCAP);
768
+ nk_capability_t caps = nk_cap_serial_k;
769
+ // LoongArch HWCAP bit 5 = LASX (256-bit SIMD)
770
+ if (hwcap & (1UL << 5)) caps |= nk_cap_loongsonasx_k;
771
+ return caps;
772
+ #else
773
+ return nk_cap_serial_k;
774
+ #endif
775
+ }
776
+
777
+ #endif // NK_TARGET_LOONGARCH_
778
+
779
+ #if NK_TARGET_POWER_
780
+
781
+ NK_PUBLIC nk_capability_t nk_capabilities_power_(void) {
782
+ #if defined(NK_DEFINED_LINUX_)
783
+ unsigned long hwcap = getauxval(AT_HWCAP);
784
+ unsigned long hwcap2 = getauxval(AT_HWCAP2);
785
+ nk_capability_t caps = nk_cap_serial_k;
786
+ nk_unused_(hwcap2);
787
+ // PPC_FEATURE_HAS_VSX = 0x00000080
788
+ if (hwcap & 0x00000080) caps |= nk_cap_powervsx_k;
789
+ return caps;
790
+ #else
791
+ return nk_cap_serial_k;
792
+ #endif
793
+ }
794
+
795
+ #endif // NK_TARGET_POWER_
796
+
668
797
  #if NK_TARGET_WASM_
669
798
 
670
799
  #if defined(__EMSCRIPTEN__) && NK_DYNAMIC_DISPATCH && !defined(NK_PYODIDE_SIDE_MODULE)
@@ -710,22 +839,86 @@ NK_PUBLIC int nk_configure_thread_(nk_capability_t capabilities) {
710
839
  NK_PUBLIC nk_capability_t nk_capabilities_(void) {
711
840
  #if NK_TARGET_X86_
712
841
  return nk_capabilities_x86_();
842
+ #elif NK_TARGET_ARM_
843
+ return nk_capabilities_arm_();
844
+ #elif NK_TARGET_RISCV_
845
+ return nk_capabilities_riscv_();
846
+ #elif NK_TARGET_LOONGARCH_
847
+ return nk_capabilities_loongarch_();
848
+ #elif NK_TARGET_POWER_
849
+ return nk_capabilities_power_();
850
+ #elif NK_TARGET_WASM_
851
+ return nk_capabilities_v128relaxed_();
852
+ #else
853
+ return nk_cap_serial_k;
854
+ #endif
855
+ }
856
+
857
+ /**
858
+ * @brief Returns a bitmask of all capabilities the library was compiled with,
859
+ * regardless of whether the current CPU supports them at runtime.
860
+ */
861
+ NK_PUBLIC nk_capability_t nk_capabilities_compiled_(void) {
862
+ nk_capability_t caps = nk_cap_serial_k;
863
+ #if NK_TARGET_X86_
864
+ caps |= nk_cap_haswell_k * NK_TARGET_HASWELL;
865
+ caps |= nk_cap_skylake_k * NK_TARGET_SKYLAKE;
866
+ caps |= nk_cap_icelake_k * NK_TARGET_ICELAKE;
867
+ caps |= nk_cap_genoa_k * NK_TARGET_GENOA;
868
+ caps |= nk_cap_sapphire_k * NK_TARGET_SAPPHIRE;
869
+ caps |= nk_cap_sapphireamx_k * NK_TARGET_SAPPHIREAMX;
870
+ caps |= nk_cap_graniteamx_k * NK_TARGET_GRANITEAMX;
871
+ caps |= nk_cap_diamond_k * NK_TARGET_DIAMOND;
872
+ caps |= nk_cap_turin_k * NK_TARGET_TURIN;
873
+ caps |= nk_cap_alder_k * NK_TARGET_ALDER;
874
+ caps |= nk_cap_sierra_k * NK_TARGET_SIERRA;
713
875
  #endif
714
876
  #if NK_TARGET_ARM_
715
- return nk_capabilities_arm_();
877
+ caps |= nk_cap_neon_k * NK_TARGET_NEON;
878
+ caps |= nk_cap_neonhalf_k * NK_TARGET_NEONHALF;
879
+ caps |= nk_cap_neonsdot_k * NK_TARGET_NEONSDOT;
880
+ caps |= nk_cap_neonbfdot_k * NK_TARGET_NEONBFDOT;
881
+ caps |= nk_cap_neonfhm_k * NK_TARGET_NEONFHM;
882
+ caps |= nk_cap_neonfp8_k * NK_TARGET_NEONFP8;
883
+ caps |= nk_cap_sve_k * NK_TARGET_SVE;
884
+ caps |= nk_cap_svehalf_k * NK_TARGET_SVEHALF;
885
+ caps |= nk_cap_svesdot_k * NK_TARGET_SVESDOT;
886
+ caps |= nk_cap_svebfdot_k * NK_TARGET_SVEBFDOT;
887
+ caps |= nk_cap_sve2_k * NK_TARGET_SVE2;
888
+ caps |= nk_cap_sve2p1_k * NK_TARGET_SVE2P1;
889
+ caps |= nk_cap_sme_k * NK_TARGET_SME;
890
+ caps |= nk_cap_sme2_k * NK_TARGET_SME2;
891
+ caps |= nk_cap_sme2p1_k * NK_TARGET_SME2P1;
892
+ caps |= nk_cap_smef64_k * NK_TARGET_SMEF64;
893
+ caps |= nk_cap_smehalf_k * NK_TARGET_SMEHALF;
894
+ caps |= nk_cap_smebf16_k * NK_TARGET_SMEBF16;
895
+ caps |= nk_cap_smebi32_k * NK_TARGET_SMEBI32;
896
+ caps |= nk_cap_smelut2_k * NK_TARGET_SMELUT2;
897
+ caps |= nk_cap_smefa64_k * NK_TARGET_SMEFA64;
716
898
  #endif
717
899
  #if NK_TARGET_RISCV_
718
- return nk_capabilities_riscv_();
900
+ caps |= nk_cap_rvv_k * NK_TARGET_RVV;
901
+ caps |= nk_cap_rvvhalf_k * NK_TARGET_RVVHALF;
902
+ caps |= nk_cap_rvvbf16_k * NK_TARGET_RVVBF16;
903
+ caps |= nk_cap_rvvbb_k * NK_TARGET_RVVBB;
904
+ #endif
905
+ #if NK_TARGET_LOONGARCH_
906
+ caps |= nk_cap_loongsonasx_k * NK_TARGET_LOONGSONASX;
907
+ #endif
908
+ #if NK_TARGET_POWER_
909
+ caps |= nk_cap_powervsx_k * NK_TARGET_POWERVSX;
719
910
  #endif
720
911
  #if NK_TARGET_WASM_
721
- return nk_capabilities_v128relaxed_();
912
+ caps |= nk_cap_v128relaxed_k * NK_TARGET_V128RELAXED;
722
913
  #endif
723
- return nk_cap_serial_k;
914
+ return caps;
724
915
  }
725
916
 
726
917
  #if NK_DYNAMIC_DISPATCH
727
918
 
728
919
  NK_DYNAMIC nk_capability_t nk_capabilities(void);
920
+ NK_DYNAMIC nk_capability_t nk_capabilities_available(void);
921
+ NK_DYNAMIC nk_capability_t nk_capabilities_compiled(void);
729
922
  NK_DYNAMIC int nk_configure_thread(nk_capability_t);
730
923
  NK_DYNAMIC int nk_uses_dynamic_dispatch(void);
731
924
  NK_DYNAMIC void nk_dispatch_table_update(nk_capability_t);
@@ -737,6 +930,8 @@ NK_DYNAMIC void nk_find_kernel_punned(nk_kernel_kind_t kind, nk_dtype_t dtype, n
737
930
  NK_PUBLIC int nk_uses_dynamic_dispatch(void) { return 0; }
738
931
  NK_PUBLIC int nk_configure_thread(nk_capability_t c) { return nk_configure_thread_(c); }
739
932
  NK_PUBLIC nk_capability_t nk_capabilities(void) { return nk_capabilities_(); }
933
+ NK_PUBLIC nk_capability_t nk_capabilities_available(void) { return nk_capabilities_() & nk_capabilities_compiled_(); }
934
+ NK_PUBLIC nk_capability_t nk_capabilities_compiled(void) { return nk_capabilities_compiled_(); }
740
935
  NK_PUBLIC void nk_dispatch_table_update(nk_capability_t caps) { nk_unused_(caps); }
741
936
 
742
937
  #endif