numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -84,14 +84,14 @@
84
84
  * rounding (notably 3x faster on Genoa than Ice Lake). VFPCLASS detects NaN/Inf inputs for special
85
85
  * case handling. Division appears in tangent's final step but isn't on the critical path.
86
86
  *
87
- * Intrinsic Instruction Ice Genoa
88
- * _mm512_roundscale_ps VRNDSCALEPS (ZMM, ZMM, I8) 8c @ p0 3c @ p23
89
- * _mm512_roundscale_pd VRNDSCALEPD (ZMM, ZMM, I8) 8c @ p0 3c @ p23
90
- * _mm512_fpclass_ps_mask VFPCLASSPS (K, ZMM, I8) 3c @ p5 5c @ p01
91
- * _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM) 4c @ p0 4c @ p01
92
- * _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4c @ p01 4c @ p01
93
- * _mm256_div_ps VDIVPS (YMM, YMM, YMM) ~14c @ p0 ~11c @ p01
94
- * _mm256_div_pd VDIVPD (YMM, YMM, YMM) ~23c @ p0 ~13c @ p01
87
+ * Intrinsic Instruction Icelake Genoa
88
+ * _mm512_roundscale_ps VRNDSCALEPS (ZMM, ZMM, I8) 8cy @ p0+p0 3cy @ p23
89
+ * _mm512_roundscale_pd VRNDSCALEPD (ZMM, ZMM, I8) 8cy @ p0+p0 3cy @ p23
90
+ * _mm512_fpclass_ps_mask VFPCLASSPS (K, ZMM, I8) 3cy @ p5 5cy @ p01
91
+ * _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM) 4cy @ p0 4cy @ p01
92
+ * _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4cy @ p01 4cy @ p01
93
+ * _mm256_div_ps VDIVPS (YMM, YMM, YMM) ~11cy @ p0 ~11cy @ p01
94
+ * _mm256_div_pd VDIVPD (YMM, YMM, YMM) ~13cy @ p0 ~13cy @ p01
95
95
  *
96
96
  * @section arm_instructions Relevant ARM NEON/SVE Instructions
97
97
  *
@@ -99,14 +99,14 @@
99
99
  * fast rounding for range reduction. The 4-cycle FMA latency with 4 inst/cycle throughput allows
100
100
  * excellent pipelining when processing multiple elements.
101
101
  *
102
- * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
103
- * vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
104
- * vfmaq_f64 FMLA.D (vec) 4c @ V0123 4c @ V0123 4c @ V0123
105
- * vrndaq_f32 FRINTA.S 2c @ V0123 2c @ V01 2c @ V01
102
+ * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
103
+ * vfmaq_f32 FMLA.S (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
104
+ * vfmaq_f64 FMLA.D (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
105
+ * vrndaq_f32 FRINTA.S 2cy @ V0123 2cy @ V01 2cy @ V01
106
106
  *
107
107
  * @section references References
108
108
  *
109
- * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
109
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
110
110
  * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
111
111
  *
112
112
  */
@@ -91,7 +91,7 @@ void atan(in_type_ const *in, std::size_t n, in_type_ *out) noexcept {
91
91
 
92
92
  namespace ashvardanian::numkong {
93
93
 
94
- #pragma region - Tensor Trigonometric
94
+ #pragma region Tensor Trigonometric
95
95
 
96
96
  /** @brief Elementwise sin into pre-allocated output. */
97
97
  template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
@@ -159,7 +159,7 @@ tensor<value_type_, allocator_type_, max_rank_> try_atan(tensor_view<value_type_
159
159
  return result;
160
160
  }
161
161
 
162
- #pragma endregion - Tensor Trigonometric
162
+ #pragma endregion Tensor Trigonometric
163
163
 
164
164
  } // namespace ashvardanian::numkong
165
165
 
@@ -36,6 +36,29 @@
36
36
  * @see https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
37
37
  * @see FP8 Formats for Deep Learning: https://arxiv.org/pdf/2209.05433
38
38
  * @see ONNX Float8 Types: https://onnx.ai/onnx/technical/float8.html
39
+ *
40
+ * @section fp6_types FP6 Numeric Types
41
+ *
42
+ * The OCP Microscaling (MX) v1.0 specification defines two 6-bit floating-point formats
43
+ * for block-scaled quantization. Both are "FN" (finite-numeric): all bit patterns map
44
+ * to real numbers with no Inf or NaN codes. Stored byte-aligned with 2 bits of padding.
45
+ *
46
+ * Format Bias Sign Exp Mant Range Subnormals Infinity NaN Standard
47
+ * E2M3 1 1 2 3 ±7.5 14 of 64 ❌ No ❌ OCP MX v1.0
48
+ * E3M2 3 1 3 2 ±28 6 of 64 ❌ No ❌ OCP MX v1.0
49
+ *
50
+ * E2M3 favors mantissa precision (3 bits) for narrow dynamic range — ideal for activations.
51
+ * E3M2 favors exponent range (3 bits) for wider dynamic range — suited for weights.
52
+ * Both follow IEEE 754 subnormal rules: when exp=0, the implicit leading bit is 0,
53
+ * giving value = (-1)^s × 0.mmm × 2^(1-bias). This provides gradual underflow to zero.
54
+ *
55
+ * No hardware directly computes on FP6. On Arm with FEAT_FP8DOT4, E2M3 values can be
56
+ * losslessly promoted to E4M3 (same mantissa width, rebias exponent by +6) and E3M2 to
57
+ * E5M2 (same mantissa width, rebias exponent by +12), then fed to FDOT instructions.
58
+ * Subnormal values (exp=0) require normalization during this promotion.
59
+ *
60
+ * @see https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
61
+ * @see https://arxiv.org/abs/2401.14112 (FP6-LLM paper)
39
62
  */
40
63
  #ifndef NK_TYPES_H
41
64
  #define NK_TYPES_H
@@ -68,6 +91,9 @@
68
91
  #if defined(__GNUC__) || defined(__clang__)
69
92
  #define NK_PUBLIC __attribute__((unused)) inline static
70
93
  #define NK_INTERNAL __attribute__((always_inline)) inline static
94
+ #elif defined(_MSC_VER)
95
+ #define NK_PUBLIC inline static
96
+ #define NK_INTERNAL __forceinline static
71
97
  #else
72
98
  #define NK_PUBLIC inline static
73
99
  #define NK_INTERNAL inline static
@@ -85,6 +111,14 @@
85
111
  #define NK_DYNAMIC NK_PUBLIC
86
112
  #endif // NK_DYNAMIC_DISPATCH
87
113
 
114
+ // Vector union types use type punning by design (write as f16, read as f32, etc.).
115
+ // Without this, GCC at -O2 assumes strict aliasing and may optimize away valid accesses.
116
+ #if defined(__GNUC__) || defined(__clang__)
117
+ #define NK_MAY_ALIAS_ __attribute__((may_alias))
118
+ #else
119
+ #define NK_MAY_ALIAS_
120
+ #endif
121
+
88
122
  // Allow SIMD kernels to redirect small inputs to serial implementations.
89
123
  // Enabled by default for production use. Tests and benchmarks may disable
90
124
  // this to isolate SIMD path behavior on small inputs.
@@ -93,6 +127,7 @@
93
127
  #endif
94
128
 
95
129
  // Compiling for Arm: NK_TARGET_ARM_
130
+ // https://arm-software.github.io/acle/main/acle.html
96
131
  #if !defined(NK_TARGET_ARM_)
97
132
  #if defined(__aarch64__) || defined(_M_ARM64)
98
133
  #define NK_TARGET_ARM_ 1
@@ -102,6 +137,7 @@
102
137
  #endif // !defined(NK_TARGET_ARM_)
103
138
 
104
139
  // Compiling for x86: NK_TARGET_X86_
140
+ // https://www.intel.com/content/www/us/en/docs/dpcpp-cpp-compiler/developer-guide-reference/2024-2/additional-predefined-macros.html
105
141
  #if !defined(NK_TARGET_X86_)
106
142
  #if defined(__x86_64__) || defined(_M_X64)
107
143
  #define NK_TARGET_X86_ 1
@@ -119,6 +155,24 @@
119
155
  #endif // defined(__riscv) && (__riscv_xlen == 64)
120
156
  #endif // !defined(NK_TARGET_RISCV_)
121
157
 
158
+ // Compiling for LoongArch: NK_TARGET_LOONGARCH_
159
+ #if !defined(NK_TARGET_LOONGARCH_)
160
+ #if defined(__loongarch__)
161
+ #define NK_TARGET_LOONGARCH_ 1
162
+ #else
163
+ #define NK_TARGET_LOONGARCH_ 0
164
+ #endif // defined(__loongarch__)
165
+ #endif // !defined(NK_TARGET_LOONGARCH_)
166
+
167
+ // Compiling for Power: NK_TARGET_POWER_
168
+ #if !defined(NK_TARGET_POWER_)
169
+ #if defined(__powerpc64__) || defined(__ppc64__) || defined(_ARCH_PPC64)
170
+ #define NK_TARGET_POWER_ 1
171
+ #else
172
+ #define NK_TARGET_POWER_ 0
173
+ #endif // defined(__powerpc64__) || defined(__ppc64__) || defined(_ARCH_PPC64)
174
+ #endif // !defined(NK_TARGET_POWER_)
175
+
122
176
  // Compiling for WASM: NK_TARGET_WASM_
123
177
  #if !defined(NK_TARGET_WASM_)
124
178
  #if defined(__wasm__) || defined(__EMSCRIPTEN__)
@@ -191,56 +245,93 @@
191
245
  #endif // defined(__riscv_zvbb) && (__riscv_zvbb > 0)
192
246
  #endif // !defined(NK_TARGET_RVVBB) || ...
193
247
 
248
+ // Compiling for LoongArch LASX (256-bit SIMD): NK_TARGET_LOONGSONASX
249
+ // LASX provides 32 × 256-bit vector registers, widening integer multiply-accumulate,
250
+ // and f32-to-f64 conversion (xvfcvtl_d_s / xvfcvth_d_s) but no widening FMA.
251
+ #if !defined(NK_TARGET_LOONGSONASX) || (NK_TARGET_LOONGSONASX && !NK_TARGET_LOONGARCH_)
252
+ #if defined(__loongarch_asx)
253
+ #define NK_TARGET_LOONGSONASX 1
254
+ #else
255
+ #undef NK_TARGET_LOONGSONASX
256
+ #define NK_TARGET_LOONGSONASX 0
257
+ #endif // defined(__loongarch_asx)
258
+ #endif // !defined(NK_TARGET_LOONGSONASX) || ...
259
+
260
+ // Compiling for Power VSX (128-bit SIMD, POWER9+ baseline): NK_TARGET_POWERVSX
261
+ // VSX provides 64 × 128-bit registers, FMA (vec_madd), vec_msum (multiply-sum), hardware f16
262
+ // conversion (vec_extract_fp32_from_shorth/l), length-limited loads (vec_xl_len), per-byte
263
+ // popcount (vec_popcnt), and vec_cmpne. Requires POWER9 (ISA 3.0) or newer.
264
+ #if !defined(NK_TARGET_POWERVSX) || (NK_TARGET_POWERVSX && !NK_TARGET_POWER_)
265
+ #if defined(__VSX__) && defined(__POWER9_VECTOR__)
266
+ #define NK_TARGET_POWERVSX 1
267
+ #else
268
+ #undef NK_TARGET_POWERVSX
269
+ #define NK_TARGET_POWERVSX 0
270
+ #endif // defined(__VSX__)
271
+ #endif // !defined(NK_TARGET_POWERVSX) || ...
272
+
194
273
  // Compiling for Arm: NK_TARGET_NEON
195
274
  #if !defined(NK_TARGET_NEON) || (NK_TARGET_NEON && !NK_TARGET_ARM_)
196
- #if defined(__ARM_NEON)
275
+ #if defined(__ARM_NEON) || (defined(_MSC_VER) && defined(_M_ARM64))
197
276
  #define NK_TARGET_NEON 1
198
277
  #else
199
278
  #undef NK_TARGET_NEON
200
279
  #define NK_TARGET_NEON 0
201
- #endif // defined(__ARM_NEON)
280
+ #endif // defined(__ARM_NEON) || ...
202
281
  #endif // !defined(NK_TARGET_NEON) || ...
203
282
 
204
- // Compiling for Arm: NK_TARGET_NEONSDOT
283
+ // Compiling for Arm: NK_TARGET_NEONSDOT (FEAT_DotProd, optional from ARMv8.1, mandatory at ARMv8.4 with AdvSIMD)
205
284
  #if !defined(NK_TARGET_NEONSDOT) || (NK_TARGET_NEONSDOT && !NK_TARGET_ARM_)
206
- #if defined(__ARM_NEON)
285
+ #if defined(__ARM_FEATURE_DOTPROD) || (defined(_MSC_VER) && defined(_M_ARM64) && __ARM_ARCH >= 804)
207
286
  #define NK_TARGET_NEONSDOT 1
208
287
  #else
209
288
  #undef NK_TARGET_NEONSDOT
210
289
  #define NK_TARGET_NEONSDOT 0
211
- #endif // defined(__ARM_NEON)
290
+ #endif
212
291
  #endif // !defined(NK_TARGET_NEONSDOT) || ...
213
292
 
214
- // Compiling for Arm: NK_TARGET_NEONHALF
293
+ // Compiling for Arm: NK_TARGET_NEONHALF (FEAT_FP16, optional from ARMv8.2, mandatory at ARMv9.0 with AdvSIMD)
215
294
  #if !defined(NK_TARGET_NEONHALF) || (NK_TARGET_NEONHALF && !NK_TARGET_ARM_)
216
- #if defined(__ARM_NEON)
295
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || (defined(_MSC_VER) && defined(_M_ARM64) && __ARM_ARCH >= 802)
217
296
  #define NK_TARGET_NEONHALF 1
218
297
  #else
219
298
  #undef NK_TARGET_NEONHALF
220
299
  #define NK_TARGET_NEONHALF 0
221
- #endif // defined(__ARM_NEON)
300
+ #endif
222
301
  #endif // !defined(NK_TARGET_NEONHALF) || ...
223
302
 
224
- // Compiling for Arm: NK_TARGET_NEONFHM (FEAT_FHM - FMLAL/FMLSL widening ops)
303
+ // Compiling for Arm: NK_TARGET_NEONFHM (FEAT_FHM, optional from ARMv8.1, mandatory at ARMv8.4 with FP16)
225
304
  #if !defined(NK_TARGET_NEONFHM) || (NK_TARGET_NEONFHM && !NK_TARGET_ARM_)
226
- #if defined(__ARM_NEON)
305
+ #if defined(__ARM_FEATURE_FP16_FML) || (defined(_MSC_VER) && defined(_M_ARM64) && __ARM_ARCH >= 804)
227
306
  #define NK_TARGET_NEONFHM 1
228
307
  #else
229
308
  #undef NK_TARGET_NEONFHM
230
309
  #define NK_TARGET_NEONFHM 0
231
- #endif // defined(__ARM_NEON)
310
+ #endif
232
311
  #endif // !defined(NK_TARGET_NEONFHM) || ...
233
312
 
234
- // Compiling for Arm: NK_TARGET_NEONBFDOT
313
+ // Compiling for Arm: NK_TARGET_NEONBFDOT (FEAT_BF16, optional from ARMv8.2, mandatory at ARMv8.6 with FP)
235
314
  #if !defined(NK_TARGET_NEONBFDOT) || (NK_TARGET_NEONBFDOT && !NK_TARGET_ARM_)
236
- #if defined(__ARM_NEON)
315
+ #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || (defined(_MSC_VER) && defined(_M_ARM64) && __ARM_ARCH >= 806)
237
316
  #define NK_TARGET_NEONBFDOT 1
238
317
  #else
239
318
  #undef NK_TARGET_NEONBFDOT
240
319
  #define NK_TARGET_NEONBFDOT 0
241
- #endif // defined(__ARM_NEON)
320
+ #endif
242
321
  #endif // !defined(NK_TARGET_NEONBFDOT) || ...
243
322
 
323
+ // Compiling for Arm: NK_TARGET_NEONFP8 (NEON FP8 extensions, FEAT_FP8DOT4)
324
+ // ACLE macro __ARM_FEATURE_FP8DOT4 defined by GCC 15+ and Clang 21+ when +fp8dot4 is enabled.
325
+ // Older compilers lack mfloat8x16_t and the fp8dot4 target attribute entirely.
326
+ #if !defined(NK_TARGET_NEONFP8) || (NK_TARGET_NEONFP8 && !NK_TARGET_ARM_)
327
+ #if defined(__ARM_FEATURE_FP8DOT4)
328
+ #define NK_TARGET_NEONFP8 1
329
+ #else
330
+ #undef NK_TARGET_NEONFP8
331
+ #define NK_TARGET_NEONFP8 0
332
+ #endif // defined(__ARM_FEATURE_FP8DOT4)
333
+ #endif // !defined(NK_TARGET_NEONFP8) || ...
334
+
244
335
  // Compiling for Arm: NK_TARGET_SVE
245
336
  #if !defined(NK_TARGET_SVE) || (NK_TARGET_SVE && !NK_TARGET_ARM_)
246
337
  #if defined(__ARM_FEATURE_SVE)
@@ -316,20 +407,26 @@
316
407
  #endif // defined(__ARM_FEATURE_SME2)
317
408
  #endif // !defined(NK_TARGET_SME2) || ...
318
409
 
410
+ // Compiling for Arm: NK_TARGET_SME2P1 (FEAT_SME2p1)
411
+ // ACLE macro: __ARM_FEATURE_SME2p1 (note lowercase 'p')
319
412
  #if !defined(NK_TARGET_SME2P1) || (NK_TARGET_SME2P1 && !NK_TARGET_ARM_)
413
+ #if defined(__ARM_FEATURE_SME2p1)
414
+ #define NK_TARGET_SME2P1 1
415
+ #else
320
416
  #undef NK_TARGET_SME2P1
321
417
  #define NK_TARGET_SME2P1 0
322
- #endif
418
+ #endif // defined(__ARM_FEATURE_SME2p1)
419
+ #endif // !defined(NK_TARGET_SME2P1) || ...
323
420
 
324
421
  // AppleClang 17 exposes SME sub-features through `arm_sme.h` builtin aliases,
325
422
  // not dedicated `__ARM_FEATURE_*` predefines for every matrix subtype.
326
423
  #if !defined(NK_TARGET_SMEF64) || (NK_TARGET_SMEF64 && !NK_TARGET_ARM_)
327
- #if defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za64_f64_m)
424
+ #if defined(__ARM_FEATURE_SME_F64F64) || (defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za64_f64_m))
328
425
  #define NK_TARGET_SMEF64 1
329
426
  #else
330
427
  #undef NK_TARGET_SMEF64
331
428
  #define NK_TARGET_SMEF64 0
332
- #endif // defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za64_f64_m)
429
+ #endif // defined(__ARM_FEATURE_SME_F64F64) || ...
333
430
  #endif // !defined(NK_TARGET_SMEF64) || ...
334
431
 
335
432
  #if !defined(NK_TARGET_SMEBI32) || (NK_TARGET_SMEBI32 && !NK_TARGET_ARM_)
@@ -342,7 +439,7 @@
342
439
  #endif // !defined(NK_TARGET_SMEBI32) || ...
343
440
 
344
441
  #if !defined(NK_TARGET_SMEHALF) || (NK_TARGET_SMEHALF && !NK_TARGET_ARM_)
345
- #if defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za32_f16_m)
442
+ #if defined(__ARM_FEATURE_SME_F16F16) || (defined(__has_builtin) && __has_builtin(__builtin_sme_svmopa_za32_f16_m))
346
443
  #define NK_TARGET_SMEHALF 1
347
444
  #else
348
445
  #undef NK_TARGET_SMEHALF
@@ -368,10 +465,15 @@
368
465
  #endif // defined(__has_builtin) && __has_builtin(__builtin_sme_svluti2_lane_zt_u8)
369
466
  #endif // !defined(NK_TARGET_SMELUT2) || ...
370
467
 
468
+ // Compiling for Arm: NK_TARGET_SMEFA64 (FEAT_SME_FA64, full SVE2 in streaming mode)
371
469
  #if !defined(NK_TARGET_SMEFA64) || (NK_TARGET_SMEFA64 && !NK_TARGET_ARM_)
470
+ #if defined(__ARM_FEATURE_SME_FA64)
471
+ #define NK_TARGET_SMEFA64 1
472
+ #else
372
473
  #undef NK_TARGET_SMEFA64
373
474
  #define NK_TARGET_SMEFA64 0
374
- #endif
475
+ #endif // defined(__ARM_FEATURE_SME_FA64)
476
+ #endif // !defined(NK_TARGET_SMEFA64) || ...
375
477
 
376
478
  // Compiling for x86: NK_TARGET_HASWELL
377
479
  //
@@ -433,9 +535,22 @@
433
535
  #else
434
536
  #undef NK_TARGET_GENOA
435
537
  #define NK_TARGET_GENOA 0
436
- #endif
538
+ #endif // defined(__AVX512BF16__) || ...
437
539
  #endif // !defined(NK_TARGET_GENOA) || ...
438
540
 
541
// x86: NK_TARGET_DIAMOND - AVX10.2 (Diamond Rapids).
// GCC and Clang define `__AVX10_2__` when AVX10.2 is enabled (e.g. `-mavx10.2-512`).
// NOTE(review): minimum versions were listed as GCC 14+ / Clang 19+; AVX10.2 options appear to have first
// shipped in GCC 15 / Clang 20 - confirm against the toolchain release notes.
// MSVC reports `__AVX10_VER__ >= 2` under `/arch:AVX10.2` (VS 2026+, unreleased when this was written).
#if !defined(NK_TARGET_DIAMOND) || (NK_TARGET_DIAMOND && !NK_TARGET_X86_)
#if !defined(__AVX10_2__) && !(defined(__AVX10_VER__) && __AVX10_VER__ >= 2)
#undef NK_TARGET_DIAMOND
#define NK_TARGET_DIAMOND 0
#else
#define NK_TARGET_DIAMOND 1
#endif // defined(__AVX10_2__) || ...
#endif // !defined(NK_TARGET_DIAMOND) || ...
553
+
439
554
  #if !defined(NK_TARGET_SAPPHIRE) || (NK_TARGET_SAPPHIRE && !NK_TARGET_X86_)
440
555
  #if defined(__AVX512FP16__) || (defined(_MSC_VER) && _MSC_VER >= 1944)
441
556
  #define NK_TARGET_SAPPHIRE 1
@@ -490,10 +605,10 @@
490
605
  #endif
491
606
  #endif // !defined(NK_TARGET_SIERRA) || ...
492
607
 
493
- // Include the relevant intrinsics file - different for different OSes and ISAs
608
+ // Include the relevant intrinsics headers
494
609
  #if defined(_MSC_VER)
495
610
  #include <intrin.h>
496
- #elif NK_TARGET_ARM_
611
+ #endif
497
612
  #if NK_TARGET_NEON
498
613
  #include <arm_neon.h>
499
614
  #endif
@@ -503,11 +618,20 @@
503
618
  #if NK_TARGET_SME || NK_TARGET_SME2 || NK_TARGET_SMEBI32
504
619
  #include <arm_sme.h>
505
620
  #endif
506
- #elif NK_TARGET_HASWELL || NK_TARGET_SKYLAKE
621
+ #if NK_TARGET_HASWELL || NK_TARGET_SKYLAKE
507
622
  #include <immintrin.h>
508
- #elif NK_TARGET_RVV
623
+ #endif
624
+ #if NK_TARGET_RVV
509
625
  #include <riscv_vector.h>
510
- #elif NK_TARGET_V128RELAXED
626
+ #endif
627
+ #if NK_TARGET_LOONGSONASX
628
+ #include <lsxintrin.h> // `__m128i` for LSX SIMD
629
+ #include <lasxintrin.h> // `__m256i` for LASX SIMD
630
+ #endif
631
+ #if NK_TARGET_POWERVSX
632
+ #include <altivec.h>
633
+ #endif
634
+ #if NK_TARGET_V128RELAXED
511
635
  #include <wasm_simd128.h>
512
636
  #endif
513
637
 
@@ -516,11 +640,11 @@
516
640
  #endif
517
641
 
518
642
  #if !defined(NK_F32_DIVISION_EPSILON)
519
- #define NK_F32_DIVISION_EPSILON (1e-7)
643
+ #define NK_F32_DIVISION_EPSILON (1e-7f)
520
644
  #endif
521
645
 
522
646
  #if !defined(NK_F16_DIVISION_EPSILON)
523
- #define NK_F16_DIVISION_EPSILON (1e-3)
647
+ #define NK_F16_DIVISION_EPSILON (1e-3f)
524
648
  #endif
525
649
 
526
650
  /**
@@ -576,6 +700,27 @@
576
700
  #endif
577
701
  #endif
578
702
 
703
/* Power VSX vector typedefs.
 * <altivec.h> may define `bool`, `vector`, and `pixel` as macros, which collide with C++ keywords and identifiers.
 * NumKong spells the vector keyword as `__vector` everywhere, so C++ builds simply drop those macros.
 */
#if NK_TARGET_POWERVSX
#if defined(__cplusplus)
#undef pixel
#undef vector
#undef bool
#endif // defined(__cplusplus)

/* Unsigned integer lane views of a 128-bit VSX register. */
typedef __vector unsigned char nk_vu8x16_t;
typedef __vector unsigned short nk_vu16x8_t;
typedef __vector unsigned int nk_vu32x4_t;
typedef __vector unsigned long long nk_vu64x2_t;

/* Signed integer lane views of a 128-bit VSX register. */
typedef __vector signed char nk_vi8x16_t;
typedef __vector signed short nk_vi16x8_t;
typedef __vector signed int nk_vi32x4_t;
typedef __vector signed long long nk_vi64x2_t;

/* Floating-point lane views of a 128-bit VSX register. */
typedef __vector float nk_vf32x4_t;
typedef __vector double nk_vf64x2_t;
#endif // NK_TARGET_POWERVSX
723
+
579
724
  /** Copy 16 bits (2 bytes) from source to destination */
580
725
  #if defined(__GNUC__) || defined(__clang__)
581
726
  #define nk_copy_bytes_(destination_ptr, source_ptr, count) __builtin_memcpy((destination_ptr), (source_ptr), count)
@@ -632,10 +777,16 @@ typedef unsigned char nk_e4m3_t;
632
777
  * 122 of 248 finite values (49.2%) fall in [−1, +1]. */
633
778
  typedef unsigned char nk_e5m2_t;
634
779
  /** @brief 6-bit E2M3 micro-float (OCP MX v1.0): sign(1) + exponent(2) + mantissa(3), bias=1.
635
- * Range: ±7.5, no infinities or NaN. Only 64 total codes; 18 (28.1%) fall in [−1, +1]. */
780
+ * Stored as 0b00SEEMMM with 2 bits of padding. Range: ±7.5, no infinities or NaN.
781
+ * 64 total codes: 48 normal, 14 subnormal (exp=0, mant≠0), 2 zeros (±0).
782
+ * 18 of 64 values (28.1%) fall in [−1, +1]. Subnormal values span [±0.125, ±0.875].
783
+ * Losslessly promotable to E4M3 by rebiasing exponent +6 (normals) or normalizing (subnormals). */
636
784
  typedef unsigned char nk_e2m3_t;
637
785
  /** @brief 6-bit E3M2 micro-float (OCP MX v1.0): sign(1) + exponent(3) + mantissa(2), bias=3.
638
- * Range: ±28, supports infinities. Only 64 total codes; 26 (40.6%) fall in [−1, +1]. */
786
+ * Stored as 0b00SEEEMM with 2 bits of padding. Range: ±28, no infinities or NaN.
787
+ * 64 total codes: 56 normal, 6 subnormal (exp=0, mant≠0), 2 zeros (±0).
788
+ * 26 of 64 values (40.6%) fall in [−1, +1]. Subnormal values span [±0.0625, ±0.1875].
789
+ * Losslessly promotable to E5M2 by rebiasing exponent +12 (normals) or normalizing (subnormals). */
639
790
  typedef unsigned char nk_e3m2_t;
640
791
 
641
792
  /** @brief Signed 8-bit integer. Range: [−128, +127]. */
@@ -670,7 +821,7 @@ typedef float nk_f32_t;
670
821
  /** @brief Double-precision (64-bit) IEEE 754 float. sign(1) + exponent(11) + mantissa(52), bias=1023. */
671
822
  typedef double nk_f64_t;
672
823
 
673
- #if NK_TARGET_X86_ || NK_TARGET_ARM_ || NK_TARGET_RISCV_
824
+ #if NK_TARGET_X86_ || NK_TARGET_ARM_ || NK_TARGET_RISCV_ || NK_TARGET_POWER_ || NK_TARGET_LOONGARCH_
674
825
  #define NK_IS_64BIT_ 1
675
826
  #else
676
827
  #define NK_IS_64BIT_ 0
@@ -712,11 +863,17 @@ typedef nk_f64_t nk_fmax_t;
712
863
  #define NK_U8_MAX 255U
713
864
  #define NK_U8_MIN 0x0U
714
865
 
715
/* Raw storage patterns of the largest finite magnitudes; plain integers, usable in constant expressions. */
#define NK_F16_MAX_AS_U16 0x7BFF // IEEE 754 binary16: +65504.0
#define NK_F16_MIN_AS_U16 0xFBFF // IEEE 754 binary16: -65504.0

/* The same limits as `nk_f16_t` values, bit-cast from the patterns above (function calls, not constants). */
#define NK_F16_MAX nk_u16_as_f16_(NK_F16_MAX_AS_U16)
#define NK_F16_MIN nk_u16_as_f16_(NK_F16_MIN_AS_U16)

#define NK_BF16_MAX_AS_U16 0x7F7F // BFloat16: ~+3.39e38
#define NK_BF16_MIN_AS_U16 0xFF7F // BFloat16: ~-3.39e38

/* The same limits as `nk_bf16_t` values, bit-cast from the patterns above (function calls, not constants). */
#define NK_BF16_MAX nk_u16_as_bf16_(NK_BF16_MAX_AS_U16)
#define NK_BF16_MIN nk_u16_as_bf16_(NK_BF16_MIN_AS_U16)
720
877
 
721
878
  #define NK_E4M3_MAX 0x7E // FP8 E4M3: +448.0
722
879
  #define NK_E4M3_MIN 0xFE // FP8 E4M3: -448.0
@@ -842,7 +999,7 @@ NK_PUBLIC nk_size_t nk_dtype_bits(nk_dtype_t dtype) {
842
999
  /** @brief Returns how many logical dimensions are packed into one storage value.
843
1000
  * For sub-byte types multiple dimensions share a single byte container.
844
1001
  * For byte-or-larger types this is always 1. */
845
- NK_PUBLIC nk_size_t nk_dtype_dimensions_per_value(nk_dtype_t dtype) {
1002
+ NK_PUBLIC nk_size_t nk_dimensions_per_value(nk_dtype_t dtype) {
846
1003
  switch (dtype) {
847
1004
  case nk_u1_k: return 8;
848
1005
  case nk_i4_k: return 2;
@@ -975,7 +1132,7 @@ NK_STATIC_ASSERT(sizeof(nk_bf16_t) == 2, nk_bf16_t_must_be_2_bytes);
975
1132
  #define nk_assign_from_to_(src, dest) (*(dest) = *(src))
976
1133
 
977
1134
  /** @brief 16-bit union for f16/bf16/u16/i16 bit manipulation. */
978
- typedef union {
1135
+ typedef union NK_MAY_ALIAS_ {
979
1136
  nk_u16_t u;
980
1137
  nk_i16_t i;
981
1138
  nk_f16_t f;
@@ -983,14 +1140,14 @@ typedef union {
983
1140
  } nk_fui16_t;
984
1141
 
985
1142
  /** @brief 32-bit union for f32/u32/i32 bit manipulation. */
986
- typedef union {
1143
+ typedef union NK_MAY_ALIAS_ {
987
1144
  nk_u32_t u;
988
1145
  nk_i32_t i;
989
1146
  nk_f32_t f;
990
1147
  } nk_fui32_t;
991
1148
 
992
1149
  /** @brief 64-bit union for f64/u64/i64 bit manipulation. */
993
- typedef union {
1150
+ typedef union NK_MAY_ALIAS_ {
994
1151
  nk_u64_t u;
995
1152
  nk_i64_t i;
996
1153
  nk_f64_t f;
@@ -1021,7 +1178,7 @@ typedef struct {
1021
1178
  } nk_f64c_t;
1022
1179
 
1023
1180
  /** @brief Small 4-byte memory slice viewable as different types. */
1024
- typedef union nk_b32_vec_t {
1181
+ typedef union NK_MAY_ALIAS_ nk_b32_vec_t {
1025
1182
  nk_u32_t u32;
1026
1183
  nk_i32_t i32;
1027
1184
  nk_f32_t f32;
@@ -1034,7 +1191,7 @@ typedef union nk_b32_vec_t {
1034
1191
  } nk_b32_vec_t;
1035
1192
 
1036
1193
  /** @brief Small 8-byte memory slice viewable as different types. */
1037
- typedef union nk_b64_vec_t {
1194
+ typedef union NK_MAY_ALIAS_ nk_b64_vec_t {
1038
1195
  #if NK_TARGET_NEON
1039
1196
  uint8x8_t u8x8;
1040
1197
  uint16x4_t u16x4;
@@ -1061,8 +1218,8 @@ typedef union nk_b64_vec_t {
1061
1218
  } nk_b64_vec_t;
1062
1219
 
1063
1220
  /** @brief Small 16-byte memory slice viewable as different types. */
1064
- typedef union nk_b128_vec_t {
1065
- #if NK_TARGET_HASWELL
1221
+ typedef union NK_MAY_ALIAS_ nk_b128_vec_t {
1222
+ #if NK_TARGET_HASWELL || NK_TARGET_LOONGSONASX
1066
1223
  __m128i xmm;
1067
1224
  __m128d xmm_pd;
1068
1225
  __m128 xmm_ps;
@@ -1082,6 +1239,22 @@ typedef union nk_b128_vec_t {
1082
1239
  float32x4_t f32x4;
1083
1240
  float64x2_t f64x2;
1084
1241
  #endif
1242
+ #if NK_TARGET_NEONHALF
1243
+ float16x8_t f16x8;
1244
+ #endif
1245
+ #if NK_TARGET_POWERVSX
1246
+ nk_vu8x16_t vu8x16;
1247
+ nk_vu16x8_t vu16x8;
1248
+ nk_vu32x4_t vu32x4;
1249
+ nk_vu64x2_t vu64x2;
1250
+ nk_vi8x16_t vi8x16;
1251
+ nk_vi16x8_t vi16x8;
1252
+ nk_vi32x4_t vi32x4;
1253
+ nk_vi64x2_t vi64x2;
1254
+ nk_vf32x4_t vf32x4;
1255
+ nk_vf64x2_t vf64x2;
1256
+ #endif
1257
+
1085
1258
  nk_u8_t u8s[16];
1086
1259
  nk_u16_t u16s[8];
1087
1260
  nk_u32_t u32s[4];
@@ -1101,8 +1274,8 @@ typedef union nk_b128_vec_t {
1101
1274
  } nk_b128_vec_t;
1102
1275
 
1103
1276
  /** @brief Small 32-byte memory slice viewable as different types. */
1104
- typedef union nk_b256_vec_t {
1105
- #if NK_TARGET_HASWELL
1277
+ typedef union NK_MAY_ALIAS_ nk_b256_vec_t {
1278
+ #if NK_TARGET_HASWELL || NK_TARGET_LOONGSONASX
1106
1279
  __m256i ymm;
1107
1280
  __m256d ymm_pd;
1108
1281
  __m256 ymm_ps;
@@ -1123,6 +1296,19 @@ typedef union nk_b256_vec_t {
1123
1296
  float32x4_t f32x4s[2];
1124
1297
  float64x2_t f64x2s[2];
1125
1298
  #endif
1299
+ #if NK_TARGET_POWERVSX
1300
+ nk_vu8x16_t vu8x16s[2];
1301
+ nk_vu16x8_t vu16x8s[2];
1302
+ nk_vu32x4_t vu32x4s[2];
1303
+ nk_vu64x2_t vu64x2s[2];
1304
+ nk_vi8x16_t vi8x16s[2];
1305
+ nk_vi16x8_t vi16x8s[2];
1306
+ nk_vi32x4_t vi32x4s[2];
1307
+ nk_vi64x2_t vi64x2s[2];
1308
+ nk_vf32x4_t vf32x4s[2];
1309
+ nk_vf64x2_t vf64x2s[2];
1310
+ #endif
1311
+
1126
1312
  nk_u8_t u8s[32];
1127
1313
  nk_u16_t u16s[16];
1128
1314
  nk_u32_t u32s[8];
@@ -1148,7 +1334,7 @@ typedef union nk_b256_vec_t {
1148
1334
  * of this is that the argument of such type is passed to functions using the calling convention of the first
1149
1335
  * member of the union, which in our case is a register-based calling convention for SIMD types.
1150
1336
  */
1151
- typedef union nk_b512_vec_t {
1337
+ typedef union NK_MAY_ALIAS_ nk_b512_vec_t {
1152
1338
  #if NK_TARGET_SKYLAKE
1153
1339
  __m512i zmm;
1154
1340
  __m512d zmm_pd;
@@ -1353,17 +1539,28 @@ NK_INTERNAL nk_i8_t nk_i4x2_get_(nk_i4x2_t byte_val, int n) {
1353
1539
  /** @brief Extract bit at position n (0-7) from packed u1x8 byte. */
1354
1540
  NK_INTERNAL nk_u8_t nk_u1x8_get_(nk_u1x8_t byte_val, int n) { return (byte_val >> (n & 7)) & 1; }
1355
1541
 
1356
- NK_INTERNAL nk_f16_t nk_f16_from_u16_(nk_u16_t bits) {
1542
+ NK_INTERNAL nk_f16_t nk_u16_as_f16_(nk_u16_t bits) {
1357
1543
  nk_fui16_t c;
1358
1544
  c.u = bits;
1359
1545
  return c.f;
1360
1546
  }
1361
- NK_INTERNAL nk_bf16_t nk_bf16_from_u16_(nk_u16_t bits) {
1547
+ NK_INTERNAL nk_u16_t nk_f16_as_u16_(nk_f16_t x) {
1548
+ nk_fui16_t c;
1549
+ c.f = x;
1550
+ return c.u;
1551
+ }
1552
+ NK_INTERNAL nk_bf16_t nk_u16_as_bf16_(nk_u16_t bits) {
1362
1553
  nk_fui16_t c;
1363
1554
  c.u = bits;
1364
1555
  return c.bf;
1365
1556
  }
1366
1557
 
1558
+ NK_INTERNAL void nk_f64_from_i64_(nk_i64_t const *src, nk_f64_t *dest) { *dest = (nk_f64_t)*src; }
1559
+ NK_INTERNAL void nk_f64_from_u64_(nk_u64_t const *src, nk_f64_t *dest) { *dest = (nk_f64_t)*src; }
1560
+ NK_INTERNAL void nk_f32_from_i32_(nk_i32_t const *src, nk_f32_t *dest) { *dest = (nk_f32_t)*src; }
1561
+ NK_INTERNAL void nk_f32_from_u32_(nk_u32_t const *src, nk_f32_t *dest) { *dest = (nk_f32_t)*src; }
1562
+ NK_INTERNAL void nk_f32_from_f64_(nk_f64_t const *src, nk_f32_t *dest) { *dest = (nk_f32_t)*src; }
1563
+
1367
1564
  /** @brief E4M3: NaN when (raw & 0x7F) == 0x7F (two NaN values: 0x7F, 0xFF). */
1368
1565
  NK_INTERNAL int nk_e4m3_is_nan_(nk_e4m3_t x) { return (x & 0x7F) == 0x7F; }
1369
1566
 
@@ -1372,10 +1569,51 @@ NK_INTERNAL int nk_e4m3_is_nan_(nk_e4m3_t x) { return (x & 0x7F) == 0x7F; }
1372
1569
  NK_INTERNAL int nk_e5m2_is_nan_(nk_e5m2_t x) { return (x & 0x7F) > 0x7C; }
1373
1570
 
1374
1571
  /** @brief F16: NaN when (raw & 0x7FFF) > 0x7C00. */
1375
- NK_INTERNAL int nk_f16_is_nan_(nk_u16_t x) { return (x & 0x7FFF) > 0x7C00; }
1572
+ NK_INTERNAL int nk_f16_is_nan_(nk_f16_t x) {
1573
+ nk_fui16_t x_fui;
1574
+ x_fui.f = x;
1575
+ return (x_fui.u & 0x7FFF) > 0x7C00;
1576
+ }
1376
1577
 
1377
1578
  /** @brief BF16: NaN when (raw & 0x7FFF) > 0x7F80. */
1378
- NK_INTERNAL int nk_bf16_is_nan_(nk_u16_t x) { return (x & 0x7FFF) > 0x7F80; }
1579
+ NK_INTERNAL int nk_bf16_is_nan_(nk_bf16_t x) {
1580
+ nk_fui16_t x_fui;
1581
+ x_fui.bf = x;
1582
+ return (x_fui.u & 0x7FFF) > 0x7F80;
1583
+ }
1584
+
1585
+ /* Safe SVE vector-length queries usable from non-streaming context.
1586
+ * On Apple M4 (and other SME-only-SVE cores), SVE instructions like CNTW/CNTH/CNTB
1587
+ * trap with SIGILL outside streaming mode. These helpers bracket the query with
1588
+ * SMSTART SM / SMSTOP SM so the calling function's ABI is unchanged.
1589
+ * Inside `__arm_locally_streaming` functions the plain `svcntXX()` intrinsics are fine.
1590
+ */
1591
+ #if NK_TARGET_ARM_ && NK_TARGET_SME
1592
+ /** @brief Streaming SVL byte-element count (SVL/8) via SMSTART SM bracket. */
1593
+ NK_INTERNAL nk_size_t nk_sme_cntb_(void) {
1594
+ nk_u64_t r;
1595
+ __asm__ __volatile__("smstart sm\n\t" "cntb %0\n\t" "smstop sm" : "=r"(r));
1596
+ return (nk_size_t)r;
1597
+ }
1598
+ /** @brief Streaming SVL half-element count (SVL/16) via SMSTART SM bracket. */
1599
+ NK_INTERNAL nk_size_t nk_sme_cnth_(void) {
1600
+ nk_u64_t r;
1601
+ __asm__ __volatile__("smstart sm\n\t" "cnth %0\n\t" "smstop sm" : "=r"(r));
1602
+ return (nk_size_t)r;
1603
+ }
1604
+ /** @brief Streaming SVL word-element count (SVL/32) via SMSTART SM bracket. */
1605
+ NK_INTERNAL nk_size_t nk_sme_cntw_(void) {
1606
+ nk_u64_t r;
1607
+ __asm__ __volatile__("smstart sm\n\t" "cntw %0\n\t" "smstop sm" : "=r"(r));
1608
+ return (nk_size_t)r;
1609
+ }
1610
+ /** @brief Streaming SVL double-element count (SVL/64) via SMSTART SM bracket. */
1611
+ NK_INTERNAL nk_size_t nk_sme_cntd_(void) {
1612
+ nk_u64_t r;
1613
+ __asm__ __volatile__("smstart sm\n\t" "cntd %0\n\t" "smstop sm" : "=r"(r));
1614
+ return (nk_size_t)r;
1615
+ }
1616
+ #endif
1379
1617
 
1380
1618
  #ifdef __cplusplus
1381
1619
  } // extern "C"