numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,34 +8,34 @@
8
8
  *
9
9
  * @section neon_cast_instructions ARM NEON Conversion Instructions
10
10
  *
11
- * Float ↔ integer conversions (Cortex-A76 class):
11
+ * Float ↔ integer conversions:
12
12
  *
13
- * Intrinsic Instruction Latency Throughput
14
- * vcvtq_f32_s32 SCVTF (V.4S, V.4S) 3cy 2/cy
15
- * vcvtq_f32_u32 UCVTF (V.4S, V.4S) 3cy 2/cy
16
- * vcvtq_s32_f32 FCVTZS (V.4S, V.4S) 3cy 2/cy
17
- * vcvtq_u32_f32 FCVTZU (V.4S, V.4S) 3cy 2/cy
13
+ * Intrinsic Instruction A76 M5
14
+ * vcvtq_f32_s32 SCVTF (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
15
+ * vcvtq_f32_u32 UCVTF (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
16
+ * vcvtq_s32_f32 FCVTZS (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
17
+ * vcvtq_u32_f32 FCVTZU (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
18
18
  *
19
19
  * Float precision conversions:
20
20
  *
21
- * Intrinsic Instruction Latency Throughput
22
- * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy
23
- * vcvt_f16_f32 FCVTN (V.4H, V.4S) 3cy 2/cy
24
- * vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy 2/cy
25
- * vcvt_f32_f64 FCVTN (V.2S, V.2D) 3cy 2/cy
21
+ * Intrinsic Instruction A76 M5
22
+ * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy @ 2p 3cy @ 4p
23
+ * vcvt_f16_f32 FCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
24
+ * vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy @ 2p 3cy @ 4p
25
+ * vcvt_f32_f64 FCVTN (V.2S, V.2D) 3cy @ 2p 3cy @ 4p
26
26
  *
27
27
  * Integer narrowing with saturation:
28
28
  *
29
- * Intrinsic Instruction Latency Throughput
30
- * vqmovn_s32 SQXTN (V.4H, V.4S) 3cy 2/cy
31
- * vqmovn_u32 UQXTN (V.4H, V.4S) 3cy 2/cy
32
- * vqmovun_s32 SQXTUN (V.4H, V.4S) 3cy 2/cy
29
+ * Intrinsic Instruction A76 M5
30
+ * vqmovn_s32 SQXTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
31
+ * vqmovn_u32 UQXTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
32
+ * vqmovun_s32 SQXTUN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
33
33
  *
34
34
  * BF16 support (ARMv8.6-A+):
35
35
  *
36
- * Intrinsic Instruction Latency Throughput
37
- * vcvtq_low_bf16_f32 BFCVTN (V.4H, V.4S) 3cy 1/cy
38
- * vcvtq_high_bf16_f32 BFCVTN2 (V.8H, V.4S) 3cy 1/cy
36
+ * Intrinsic Instruction A76 M5
37
+ * vcvtq_low_bf16_f32 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
38
+ * vcvtq_high_bf16_f32 BFCVTN2 (V.8H, V.4S) 3cy @ 2p 3cy @ 4p
39
39
  *
40
40
  * BF16 conversions on baseline NEON (emulated via bit shifts):
41
41
  * - bf16 → f32: vmovl_u16 + vshlq_n_u32 by 16
@@ -68,18 +68,18 @@ extern "C" {
68
68
  #endif
69
69
 
70
70
  NK_PUBLIC void nk_f16_to_f32_neon(nk_f16_t const *src, nk_f32_t *dest) {
71
- float16x4_t f16vec = vreinterpret_f16_u16(vld1_dup_u16((nk_u16_t const *)src));
72
- float32x4_t f32vec = vcvt_f32_f16(f16vec);
73
- *dest = vgetq_lane_f32(f32vec, 0);
71
+ float16x4_t f16_f16x4 = vreinterpret_f16_u16(vld1_dup_u16((nk_u16_t const *)src));
72
+ float32x4_t f32_f32x4 = vcvt_f32_f16(f16_f16x4);
73
+ *dest = vgetq_lane_f32(f32_f32x4, 0);
74
74
  }
75
75
 
76
76
  NK_PUBLIC void nk_f32_to_f16_neon(nk_f32_t const *src, nk_f16_t *dest) {
77
- float32x4_t f32vec = vdupq_n_f32(*src);
78
- float16x4_t f16vec = vcvt_f16_f32(f32vec);
79
- vst1_lane_u16((nk_u16_t *)dest, vreinterpret_u16_f16(f16vec), 0);
77
+ float32x4_t f32_f32x4 = vdupq_n_f32(*src);
78
+ float16x4_t f16_f16x4 = vcvt_f16_f32(f32_f32x4);
79
+ vst1_lane_u16((nk_u16_t *)dest, vreinterpret_u16_f16(f16_f16x4), 0);
80
80
  }
81
81
 
82
- #pragma region - Type Punned Loads and Stores
82
+ #pragma region Type Punned Loads and Stores
83
83
 
84
84
  /** @brief Type-agnostic 128-bit full load (NEON). */
85
85
  NK_INTERNAL void nk_load_b128_neon_(void const *src, nk_b128_vec_t *dst) {
@@ -104,73 +104,64 @@ NK_INTERNAL void nk_store_b256_neon_(nk_b256_vec_t const *src, void *dst) {
104
104
  /** @brief Type-agnostic 64-bit full load (NEON). */
105
105
  NK_INTERNAL void nk_load_b64_neon_(void const *src, nk_b64_vec_t *dst) { dst->u8x8 = vld1_u8((nk_u8_t const *)src); }
106
106
 
107
- #pragma endregion - Type Punned Loads and Stores
107
+ #pragma endregion Type Punned Loads and Stores
108
108
 
109
- #pragma region - Vectorized Conversions
109
+ #pragma region Vectorized Conversions
110
110
 
111
- /** @brief Convert 4x e4m3 → f32x4 via bit manipulation (NEON).
112
- * E4M3FN format: S EEEE MMM (bias=7). No ∞ representation.
113
- * Only exp=15, mant=7 (0x7F) is NaN; exp=15, mant ∈ [0,6] are valid normals (max=448). */
111
+ /** @brief Convert 4x e4m3 → f32x4 via Giesen magic-multiply (NEON).
112
+ * Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
113
+ * Handles zero, subnormals, and normals in a single VMUL. NaN fixup for magnitude 0x7F.
114
+ * https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/ */
114
115
  NK_INTERNAL float32x4_t nk_e4m3x4_to_f32x4_neon_(nk_b32_vec_t src) {
115
116
  uint8x8_t e4m3_u8x8 = vcreate_u8(src.u32);
116
117
  uint16x8_t e4m3_u16x8 = vmovl_u8(e4m3_u8x8);
117
118
  uint32x4_t e4m3_u32x4 = vmovl_u16(vget_low_u16(e4m3_u16x8));
118
- uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e4m3_u32x4, vdupq_n_u32(0x80)), 24);
119
- uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e4m3_u32x4, 3), vdupq_n_u32(0x0F));
120
- uint32x4_t mant_u32x4 = vandq_u32(e4m3_u32x4, vdupq_n_u32(0x07));
121
119
 
122
- // Normal path: f32 = sign | ((exp+120)<<23) | (mant<<20)
123
- uint32x4_t f32_exp_u32x4 = vshlq_n_u32(vaddq_u32(exp_u32x4, vdupq_n_u32(120)), 23);
124
- uint32x4_t f32_mant_u32x4 = vshlq_n_u32(mant_u32x4, 20);
125
- uint32x4_t normal_u32x4 = vorrq_u32(sign_u32x4, vorrq_u32(f32_exp_u32x4, f32_mant_u32x4));
120
+ // Extract sign: (raw & 0x80) << 24 → f32 sign bit
121
+ uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e4m3_u32x4, vdupq_n_u32(0x80)), 24);
122
+ // Strip sign to get 7-bit magnitude, shift left by 20 so E4M3 exponent overlaps f32 exponent
123
+ uint32x4_t nonsign_u32x4 = vandq_u32(e4m3_u32x4, vdupq_n_u32(0x7F));
124
+ uint32x4_t shifted_u32x4 = vshlq_n_u32(nonsign_u32x4, 20);
126
125
 
127
- // Subnormal path (exp=0, mant ≠ 0): value = ±mantissa × 2⁻⁹
128
- float32x4_t subnormal_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mant_u32x4), 1.0f / 512.0f);
129
- uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
126
+ // Magic multiply: reinterpret as f32 × 2^120 rebiases from E4M3 (bias=7) to f32 (bias=127).
127
+ float32x4_t result_f32x4 = vmulq_f32(vreinterpretq_f32_u32(shifted_u32x4),
128
+ vreinterpretq_f32_u32(vdupq_n_u32(0x7B800000))); // 2^120
130
129
 
131
- // NaN path: E4M3FN only has NaN when exp=15 AND mant=7 (0x7F or 0xFF)
130
+ // NaN fixup: E4M3FN NaN only at magnitude 0x7F — force to f32 quiet NaN
131
+ uint32x4_t is_nan_mask_u32x4 = vceqq_u32(nonsign_u32x4, vdupq_n_u32(0x7F));
132
132
  uint32x4_t nan_u32x4 = vorrq_u32(sign_u32x4, vdupq_n_u32(0x7FC00000));
133
- uint32x4_t is_nan_mask = vandq_u32(vceqq_u32(exp_u32x4, vdupq_n_u32(15)), vceqq_u32(mant_u32x4, vdupq_n_u32(7)));
133
+ uint32x4_t result_u32x4 = vbslq_u32(is_nan_mask_u32x4, nan_u32x4, vreinterpretq_u32_f32(result_f32x4));
134
134
 
135
- // Blend paths: subnormal when exp=0, NaN when exp=15 && mant=7, else normal
136
- uint32x4_t exp_zero_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
137
- uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
138
- result_u32x4 = vbslq_u32(is_nan_mask, nan_u32x4, result_u32x4);
139
- return vreinterpretq_f32_u32(result_u32x4);
135
+ // Restore sign
136
+ return vreinterpretq_f32_u32(vorrq_u32(result_u32x4, sign_u32x4));
140
137
  }
141
138
 
142
- /** @brief Convert 4x e5m2 → f32x4 via bit manipulation (NEON).
143
- * E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mant<<21.
144
- * Handles subnormals (exp=0, mant ≠ 0), inf (exp=31, mant=0), and nan (exp=31, mant ≠ 0). */
139
+ /** @brief Convert 4x e5m2 → f32x4 via Giesen magic-multiply (NEON).
140
+ * Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
141
+ * Handles zero, subnormals, and normals in a single VMUL. Inf/NaN fixup for exp=31.
142
+ * https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/ */
145
143
  NK_INTERNAL float32x4_t nk_e5m2x4_to_f32x4_neon_(nk_b32_vec_t src) {
146
144
  uint8x8_t e5m2_u8x8 = vcreate_u8(src.u32);
147
145
  uint16x8_t e5m2_u16x8 = vmovl_u8(e5m2_u8x8);
148
146
  uint32x4_t e5m2_u32x4 = vmovl_u16(vget_low_u16(e5m2_u16x8));
149
- uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e5m2_u32x4, vdupq_n_u32(0x80)), 24);
150
- uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e5m2_u32x4, 2), vdupq_n_u32(0x1F));
151
- uint32x4_t mant_u32x4 = vandq_u32(e5m2_u32x4, vdupq_n_u32(0x03));
152
147
 
153
- // Normal path: f32 = sign | ((exp+112)<<23) | (mant<<21)
154
- uint32x4_t f32_exp_u32x4 = vshlq_n_u32(vaddq_u32(exp_u32x4, vdupq_n_u32(112)), 23);
155
- uint32x4_t f32_mant_u32x4 = vshlq_n_u32(mant_u32x4, 21);
156
- uint32x4_t normal_u32x4 = vorrq_u32(sign_u32x4, vorrq_u32(f32_exp_u32x4, f32_mant_u32x4));
148
+ // Extract sign: (raw & 0x80) << 24 → f32 sign bit
149
+ uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e5m2_u32x4, vdupq_n_u32(0x80)), 24);
150
+ // Strip sign to get 7-bit magnitude, shift left by 21 so E5M2 exponent overlaps f32 exponent
151
+ uint32x4_t nonsign_u32x4 = vandq_u32(e5m2_u32x4, vdupq_n_u32(0x7F));
152
+ uint32x4_t shifted_u32x4 = vshlq_n_u32(nonsign_u32x4, 21);
157
153
 
158
- // Subnormal path (exp=0, mant ≠ 0): value = ±mantissa × 2⁻¹⁶
159
- float32x4_t subnormal_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mant_u32x4), 1.0f / 65536.0f);
160
- uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
154
+ // Magic multiply: reinterpret as f32 × 2^112 rebiases from E5M2 (bias=15) to f32 (bias=127).
155
+ float32x4_t result_f32x4 = vmulq_f32(vreinterpretq_f32_u32(shifted_u32x4),
156
+ vreinterpretq_f32_u32(vdupq_n_u32(0x77800000))); // 2^112
161
157
 
162
- // Special path (exp=31): inf (mant=0) or nan (mant≠0)
163
- uint32x4_t infinity_u32x4 = vorrq_u32(sign_u32x4, vdupq_n_u32(0x7F800000));
164
- uint32x4_t nan_u32x4 = vorrq_u32(sign_u32x4, vdupq_n_u32(0x7FC00000));
165
- uint32x4_t mant_zero_mask = vceqq_u32(mant_u32x4, vdupq_n_u32(0));
166
- uint32x4_t special_u32x4 = vbslq_u32(mant_zero_mask, infinity_u32x4, nan_u32x4);
158
+ // Inf/NaN fixup: nonsign > 123 means exp=31 — force f32 exponent to 255
159
+ uint32x4_t is_infnan_u32x4 = vcgtq_u32(nonsign_u32x4, vdupq_n_u32(123));
160
+ uint32x4_t result_u32x4 = vorrq_u32(vreinterpretq_u32_f32(result_f32x4),
161
+ vandq_u32(is_infnan_u32x4, vdupq_n_u32(0x7F800000)));
167
162
 
168
- // Blend paths based on exponent value
169
- uint32x4_t exp_zero_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
170
- uint32x4_t exp_max_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(31));
171
- uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
172
- result_u32x4 = vbslq_u32(exp_max_mask, special_u32x4, result_u32x4);
173
- return vreinterpretq_f32_u32(result_u32x4);
163
+ // Restore sign
164
+ return vreinterpretq_f32_u32(vorrq_u32(result_u32x4, sign_u32x4));
174
165
  }
175
166
 
176
167
  /** @brief Convert 8x e4m3 → f16x8 via bit manipulation (NEON).
@@ -190,19 +181,20 @@ NK_INTERNAL float16x8_t nk_e4m3x8_to_f16x8_neon_(uint8x8_t e4m3_u8x8) {
190
181
  // Subnormal path (exp=0, mant ≠ 0): E4M3 subnormal value = mant × 2⁻⁹ = mant ÷ 512
191
182
  // Compute arithmetically: mant → f32 → multiply → f16
192
183
  float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 512.0f);
193
- float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(mant_u16x8))), 1.0f / 512.0f);
184
+ float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 512.0f);
194
185
  uint16x8_t subnormal_abs_u16x8 = vreinterpretq_u16_f16(
195
186
  vcombine_f16(vcvt_f16_f32(subnormal_low_f32x4), vcvt_f16_f32(subnormal_high_f32x4)));
196
187
  uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
197
188
 
198
189
  // NaN path: E4M3FN only has NaN when exp=15 AND mant=7 (0x7F or 0xFF)
199
190
  uint16x8_t nan_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7E00)); // F16 quiet NaN
200
- uint16x8_t is_nan_mask = vandq_u16(vceqq_u16(exp_u16x8, vdupq_n_u16(15)), vceqq_u16(mant_u16x8, vdupq_n_u16(7)));
191
+ uint16x8_t is_nan_mask_u16x8 = vandq_u16(vceqq_u16(exp_u16x8, vdupq_n_u16(15)),
192
+ vceqq_u16(mant_u16x8, vdupq_n_u16(7)));
201
193
 
202
194
  // Blend paths: subnormal when exp=0, NaN when exp=15 && mant=7, else normal
203
- uint16x8_t exp_zero_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
204
- uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask, subnormal_u16x8, normal_u16x8);
205
- result_u16x8 = vbslq_u16(is_nan_mask, nan_u16x8, result_u16x8);
195
+ uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
196
+ uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
197
+ result_u16x8 = vbslq_u16(is_nan_mask_u16x8, nan_u16x8, result_u16x8);
206
198
  return vreinterpretq_f16_u16(result_u16x8);
207
199
  }
208
200
 
@@ -232,8 +224,8 @@ NK_INTERNAL void nk_e4m3x16_to_f16x8x2_neon_(uint8x16_t input_u8x16, float16x8_t
232
224
  0x58, 0x58, 0x59, 0x59, 0x5A, 0x5A, 0x5B, 0x5B, 0x5C, 0x5C, 0x5D, 0x5D, 0x5E, 0x5E, 0x5F, 0x7E,
233
225
  };
234
226
 
235
- uint8x16x4_t lut_q0 = vld1q_u8_x4(table_q0_u8x64);
236
- uint8x16x4_t lut_q1 = vld1q_u8_x4(table_q1_u8x64);
227
+ uint8x16x4_t lut_q0_u8x16x4 = vld1q_u8_x4(table_q0_u8x64);
228
+ uint8x16x4_t lut_q1_u8x16x4 = vld1q_u8_x4(table_q1_u8x64);
237
229
 
238
230
  // Strip sign bit, work with 7-bit absolute value
239
231
  uint8x16_t sign_u8x16 = vandq_u8(input_u8x16, vdupq_n_u8(0x80));
@@ -241,9 +233,9 @@ NK_INTERNAL void nk_e4m3x16_to_f16x8x2_neon_(uint8x16_t input_u8x16, float16x8_t
241
233
 
242
234
  // High byte via 2× VQTBL4 on unsigned index, then OR sign back.
243
235
  // VQTBL4 returns 0 for out-of-range indices (>= 64), so results OR together cleanly.
244
- uint8x16_t high_q0_u8x16 = vqtbl4q_u8(lut_q0, abs_u8x16);
236
+ uint8x16_t high_q0_u8x16 = vqtbl4q_u8(lut_q0_u8x16x4, abs_u8x16);
245
237
  uint8x16_t offset_q1_u8x16 = vsubq_u8(abs_u8x16, vdupq_n_u8(64));
246
- uint8x16_t high_q1_u8x16 = vqtbl4q_u8(lut_q1, offset_q1_u8x16);
238
+ uint8x16_t high_q1_u8x16 = vqtbl4q_u8(lut_q1_u8x16x4, offset_q1_u8x16);
247
239
  uint8x16_t high_bytes_u8x16 = vorrq_u8(vorrq_u8(high_q0_u8x16, high_q1_u8x16), sign_u8x16);
248
240
 
249
241
  // Low byte: (lsb << 7), masked to 0 for subnormals (exp=0) and NaN (exp=15, mant=7)
@@ -290,14 +282,14 @@ NK_INTERNAL float16x8_t nk_e2m3x8_to_f16x8_neon_(uint8x8_t e2m3_u8x8) {
290
282
  // Subnormal path (exp=0): E2M3 subnormal = mant / 8
291
283
  // Compute via f32: mant → f32 → multiply → f16
292
284
  float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 8.0f);
293
- float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(mant_u16x8))), 1.0f / 8.0f);
285
+ float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 8.0f);
294
286
  uint16x8_t subnormal_abs_u16x8 = vreinterpretq_u16_f16(
295
287
  vcombine_f16(vcvt_f16_f32(subnormal_low_f32x4), vcvt_f16_f32(subnormal_high_f32x4)));
296
288
  uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
297
289
 
298
290
  // Blend: use subnormal result when exp=0, else normal
299
- uint16x8_t exp_zero_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
300
- uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask, subnormal_u16x8, normal_u16x8);
291
+ uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
292
+ uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
301
293
 
302
294
  return vreinterpretq_f16_u16(result_u16x8);
303
295
  }
@@ -323,14 +315,14 @@ NK_INTERNAL float16x8_t nk_e3m2x8_to_f16x8_neon_(uint8x8_t e3m2_u8x8) {
323
315
  // Subnormal path (exp=0): E3M2 subnormal = mant × 2^(-2) × (1/4) = mant / 16
324
316
  // Compute via f32: mant → f32 → multiply → f16
325
317
  float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 16.0f);
326
- float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(mant_u16x8))), 1.0f / 16.0f);
318
+ float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 16.0f);
327
319
  uint16x8_t subnormal_abs_u16x8 = vreinterpretq_u16_f16(
328
320
  vcombine_f16(vcvt_f16_f32(subnormal_low_f32x4), vcvt_f16_f32(subnormal_high_f32x4)));
329
321
  uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
330
322
 
331
323
  // Blend: use subnormal result when exp=0, else normal
332
- uint16x8_t exp_zero_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
333
- uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask, subnormal_u16x8, normal_u16x8);
324
+ uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
325
+ uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
334
326
 
335
327
  return vreinterpretq_f16_u16(result_u16x8);
336
328
  }
@@ -442,43 +434,43 @@ NK_INTERNAL uint8x8_t nk_f16x8_to_e4m3x8_neon_(float16x8_t f16x8) {
442
434
  uint16x8_t f16_mant_u16x8 = vandq_u16(bits_u16x8, vdupq_n_u16(0x03FF));
443
435
 
444
436
  // Rebias exponent: F16 bias=15 → E4M3 bias=7, subtract 8
445
- int16x8_t e4m3_exp_s16x8 = vsubq_s16(vreinterpretq_s16_u16(f16_exp_u16x8), vdupq_n_s16(8));
437
+ int16x8_t e4m3_exp_i16x8 = vsubq_s16(vreinterpretq_s16_u16(f16_exp_u16x8), vdupq_n_s16(8));
446
438
 
447
439
  // Detect special cases
448
- uint16x8_t is_f16_zero = vceqq_u16(vandq_u16(bits_u16x8, vdupq_n_u16(0x7FFF)), vdupq_n_u16(0));
449
- uint16x8_t is_f16_special = vceqq_u16(f16_exp_u16x8, vdupq_n_u16(31)); // inf or nan
450
- uint16x8_t is_f16_nan = vandq_u16(is_f16_special, vcgtq_u16(f16_mant_u16x8, vdupq_n_u16(0)));
451
- uint16x8_t is_underflow = vcltq_s16(e4m3_exp_s16x8, vdupq_n_s16(1)); // exp < 1 → subnormal/zero
452
- uint16x8_t is_overflow = vcgtq_s16(e4m3_exp_s16x8, vdupq_n_s16(15)); // exp > 15 → overflow
440
+ uint16x8_t is_f16_zero_u16x8 = vceqq_u16(vandq_u16(bits_u16x8, vdupq_n_u16(0x7FFF)), vdupq_n_u16(0));
441
+ uint16x8_t is_f16_special_u16x8 = vceqq_u16(f16_exp_u16x8, vdupq_n_u16(31)); // inf or nan
442
+ uint16x8_t is_f16_nan_u16x8 = vandq_u16(is_f16_special_u16x8, vcgtq_u16(f16_mant_u16x8, vdupq_n_u16(0)));
443
+ uint16x8_t is_underflow_u16x8 = vcltq_s16(e4m3_exp_i16x8, vdupq_n_s16(1)); // exp < 1 → subnormal/zero
444
+ uint16x8_t is_overflow_u16x8 = vcgtq_s16(e4m3_exp_i16x8, vdupq_n_s16(15)); // exp > 15 → overflow
453
445
 
454
446
  // Normal path with RNE rounding: round mantissa from 10 to 3 bits
455
447
  // RNE: add (0x3F + lsb) where lsb = bit 7 of mantissa
456
448
  uint16x8_t lsb_u16x8 = vandq_u16(vshrq_n_u16(f16_mant_u16x8, 7), vdupq_n_u16(1));
457
449
  uint16x8_t rounded_mant_u16x8 = vaddq_u16(f16_mant_u16x8, vaddq_u16(vdupq_n_u16(0x3F), lsb_u16x8));
458
450
  uint16x8_t carry_u16x8 = vshrq_n_u16(rounded_mant_u16x8, 10); // Mantissa overflow → carry to exponent
459
- e4m3_exp_s16x8 = vaddq_s16(e4m3_exp_s16x8, vreinterpretq_s16_u16(carry_u16x8));
451
+ e4m3_exp_i16x8 = vaddq_s16(e4m3_exp_i16x8, vreinterpretq_s16_u16(carry_u16x8));
460
452
  uint16x8_t e4m3_mant_u16x8 = vandq_u16(vshrq_n_u16(rounded_mant_u16x8, 7), vdupq_n_u16(0x07));
461
453
  e4m3_mant_u16x8 = vbicq_u16(e4m3_mant_u16x8, vceqq_u16(carry_u16x8, vdupq_n_u16(1))); // Clear mant if carry
462
454
 
463
455
  // Recheck overflow after rounding (carry might have pushed us over)
464
- is_overflow = vorrq_u16(is_overflow, vcgtq_s16(e4m3_exp_s16x8, vdupq_n_s16(15)));
456
+ is_overflow_u16x8 = vorrq_u16(is_overflow_u16x8, vcgtq_s16(e4m3_exp_i16x8, vdupq_n_s16(15)));
465
457
 
466
458
  // Clamp exponent to [1, 15] for normal values
467
- int16x8_t clamped_exp_s16x8 = vmaxq_s16(e4m3_exp_s16x8, vdupq_n_s16(1));
468
- clamped_exp_s16x8 = vminq_s16(clamped_exp_s16x8, vdupq_n_s16(15));
459
+ int16x8_t clamped_exp_i16x8 = vmaxq_s16(e4m3_exp_i16x8, vdupq_n_s16(1));
460
+ clamped_exp_i16x8 = vminq_s16(clamped_exp_i16x8, vdupq_n_s16(15));
469
461
 
470
462
  // E4M3FN quirk: exp=15, mant=7 is NaN, so clamp mantissa to 6 when exp=15
471
- uint16x8_t is_max_exp = vceqq_s16(clamped_exp_s16x8, vdupq_n_s16(15));
472
- e4m3_mant_u16x8 = vbslq_u16(is_max_exp, vminq_u16(e4m3_mant_u16x8, vdupq_n_u16(6)), e4m3_mant_u16x8);
463
+ uint16x8_t is_max_exp_u16x8 = vceqq_s16(clamped_exp_i16x8, vdupq_n_s16(15));
464
+ e4m3_mant_u16x8 = vbslq_u16(is_max_exp_u16x8, vminq_u16(e4m3_mant_u16x8, vdupq_n_u16(6)), e4m3_mant_u16x8);
473
465
 
474
466
  // Assemble normal result
475
467
  uint16x8_t normal_result_u16x8 = vorrq_u16(
476
- sign_byte_u16x8, vorrq_u16(vshlq_n_u16(vreinterpretq_u16_s16(clamped_exp_s16x8), 3), e4m3_mant_u16x8));
468
+ sign_byte_u16x8, vorrq_u16(vshlq_n_u16(vreinterpretq_u16_s16(clamped_exp_i16x8), 3), e4m3_mant_u16x8));
477
469
 
478
470
  // Subnormal path: E4M3 subnormal = mant × 2⁻⁹
479
471
  // Use float conversion for correctness: abs(f16) × 512, round to int, clamp to [0,7]
480
472
  float32x4_t abs_low_f32x4 = vabsq_f32(vcvt_f32_f16(vget_low_f16(f16x8)));
481
- float32x4_t abs_high_f32x4 = vabsq_f32(vcvt_f32_f16(vget_high_f16(f16x8)));
473
+ float32x4_t abs_high_f32x4 = vabsq_f32(vcvt_high_f32_f16(f16x8));
482
474
  float32x4_t scaled_low_f32x4 = vmulq_n_f32(abs_low_f32x4, 512.0f);
483
475
  float32x4_t scaled_high_f32x4 = vmulq_n_f32(abs_high_f32x4, 512.0f);
484
476
  int32x4_t subnormal_mantissa_low_i32x4 = vcvtnq_s32_f32(scaled_low_f32x4); // Round to nearest even
@@ -492,17 +484,18 @@ NK_INTERNAL uint8x8_t nk_f16x8_to_e4m3x8_neon_(float16x8_t f16x8) {
492
484
  uint16x8_t subnormal_result_u16x8 = vorrq_u16(sign_byte_u16x8, subnormal_mant_u16x8);
493
485
 
494
486
  // Special values: E4M3FN has no ∞, max normal = 0x7E (exp=15, mant=6 = 448)
495
- uint16x8_t e4m3_max = vorrq_u16(sign_byte_u16x8, vdupq_n_u16(0x7E)); // ±448 (exp=15, mant=6)
496
- uint16x8_t e4m3_nan = vorrq_u16(sign_byte_u16x8, vdupq_n_u16(0x7F)); // ±NaN (exp=15, mant=7)
497
- uint16x8_t e4m3_zero = sign_byte_u16x8; // ±0
487
+ uint16x8_t e4m3_max_u16x8 = vorrq_u16(sign_byte_u16x8, vdupq_n_u16(0x7E)); // ±448 (exp=15, mant=6)
488
+ uint16x8_t e4m3_nan_u16x8 = vorrq_u16(sign_byte_u16x8, vdupq_n_u16(0x7F)); // ±NaN (exp=15, mant=7)
489
+ uint16x8_t e4m3_zero_u16x8 = sign_byte_u16x8; // ±0
498
490
 
499
491
  // Blend results (order matters: later conditions override earlier)
500
492
  uint16x8_t result_u16x8 = normal_result_u16x8;
501
- result_u16x8 = vbslq_u16(is_underflow, subnormal_result_u16x8, result_u16x8);
502
- result_u16x8 = vbslq_u16(is_overflow, e4m3_max, result_u16x8);
503
- result_u16x8 = vbslq_u16(is_f16_special, e4m3_max, result_u16x8); // F16 inf → E4M3 max (no inf in E4M3FN)
504
- result_u16x8 = vbslq_u16(is_f16_nan, e4m3_nan, result_u16x8); // F16 nan → E4M3 nan
505
- result_u16x8 = vbslq_u16(is_f16_zero, e4m3_zero, result_u16x8); // Preserve ±0
493
+ result_u16x8 = vbslq_u16(is_underflow_u16x8, subnormal_result_u16x8, result_u16x8);
494
+ result_u16x8 = vbslq_u16(is_overflow_u16x8, e4m3_max_u16x8, result_u16x8);
495
+ result_u16x8 = vbslq_u16(is_f16_special_u16x8, e4m3_max_u16x8,
496
+ result_u16x8); // F16 inf → E4M3 max (no inf in E4M3FN)
497
+ result_u16x8 = vbslq_u16(is_f16_nan_u16x8, e4m3_nan_u16x8, result_u16x8); // F16 nan → E4M3 nan
498
+ result_u16x8 = vbslq_u16(is_f16_zero_u16x8, e4m3_zero_u16x8, result_u16x8); // Preserve ±0
506
499
 
507
500
  return vmovn_u16(result_u16x8);
508
501
  }
@@ -515,7 +508,7 @@ NK_INTERNAL uint8x8_t nk_f16x8_to_e5m2x8_neon_(float16x8_t f16x8) {
515
508
 
516
509
  // Detect inf/nan (exp=31) - these should not be rounded, just truncated
517
510
  uint16x8_t exp_u16x8 = vandq_u16(vshrq_n_u16(bits_u16x8, 10), vdupq_n_u16(0x1F));
518
- uint16x8_t is_special_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(31));
511
+ uint16x8_t is_special_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(31));
519
512
 
520
513
  // RNE rounding: add (0x7F + lsb) where lsb = bit 8 of F16
521
514
  // This rounds the lower 8 bits correctly and may carry into exponent
@@ -524,7 +517,7 @@ NK_INTERNAL uint8x8_t nk_f16x8_to_e5m2x8_neon_(float16x8_t f16x8) {
524
517
  uint16x8_t rounded_bits_u16x8 = vaddq_u16(bits_u16x8, rounding_bias_u16x8);
525
518
 
526
519
  // For special values (inf/nan), use original bits without rounding
527
- uint16x8_t final_bits_u16x8 = vbslq_u16(is_special_mask, bits_u16x8, rounded_bits_u16x8);
520
+ uint16x8_t final_bits_u16x8 = vbslq_u16(is_special_mask_u16x8, bits_u16x8, rounded_bits_u16x8);
528
521
 
529
522
  // Shift right by 8 to get E5M2 format
530
523
  uint16x8_t e5m2_u16x8 = vshrq_n_u16(final_bits_u16x8, 8);
@@ -539,32 +532,6 @@ NK_INTERNAL float32x4_t nk_bf16x4_to_f32x4_neon_(uint16x4_t bf16_u16x4) {
539
532
  return vreinterpretq_f32_u32(bits_u32x4);
540
533
  }
541
534
 
542
- /** @brief Convert 4x f16 (as u16 bits) → f32x4 via integer bit manipulation (NEON).
543
- * F16 format: S EEEEE MMMMMMMMMM (bias=15, 5-bit exponent, 10-bit mantissa).
544
- * Works on ARMv8.0 without the FP16 arithmetic extension. Treats denormals as zero. */
545
- NK_INTERNAL float32x4_t nk_f16x4_to_f32x4_neon_(uint16x4_t half_u16x4) {
546
- // Widen u16 to u32
547
- uint32x4_t bits_u32x4 = vmovl_u16(half_u16x4);
548
- // Extract sign, exponent, mantissa
549
- uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x8000)), 16);
550
- uint32x4_t exponent_u32x4 = vandq_u32(bits_u32x4, vdupq_n_u32(0x7C00));
551
- uint32x4_t mantissa_u32x4 = vandq_u32(bits_u32x4, vdupq_n_u32(0x03FF));
552
- // Normal path: ((exponent + mantissa) << 13) + rebias(112 << 23 = 0x38000000)
553
- uint32x4_t exponent_mantissa_u32x4 = vandq_u32(bits_u32x4, vdupq_n_u32(0x7FFF));
554
- uint32x4_t normal_u32x4 = vaddq_u32(vshlq_n_u32(exponent_mantissa_u32x4, 13), vdupq_n_u32(0x38000000));
555
- // Inf/NaN path (exponent == 0x7C00): 0x7F800000 | (mantissa << 13)
556
- uint32x4_t inf_nan_u32x4 = vorrq_u32(vdupq_n_u32(0x7F800000), vshlq_n_u32(mantissa_u32x4, 13));
557
- // Select inf/NaN where exponent == 31 (0x7C00)
558
- uint32x4_t is_inf_nan_u32x4 = vceqq_u32(exponent_u32x4, vdupq_n_u32(0x7C00));
559
- uint32x4_t result_u32x4 = vbslq_u32(is_inf_nan_u32x4, inf_nan_u32x4, normal_u32x4);
560
- // Zero path (exponent == 0): treat denormals as zero for simplicity
561
- uint32x4_t is_zero_u32x4 = vceqq_u32(exponent_u32x4, vdupq_n_u32(0));
562
- result_u32x4 = vbslq_u32(is_zero_u32x4, vdupq_n_u32(0), result_u32x4);
563
- // OR sign back
564
- result_u32x4 = vorrq_u32(result_u32x4, sign_u32x4);
565
- return vreinterpretq_f32_u32(result_u32x4);
566
- }
567
-
568
535
  /** @brief Convert f32x4 → 4x bf16 with RNE rounding (NEON).
569
536
  * Round-to-nearest-even: add (0x7FFF + lsb) before truncation. */
570
537
  NK_INTERNAL uint16x4_t nk_f32x4_to_bf16x4_neon_(float32x4_t f32x4) {
@@ -592,19 +559,20 @@ NK_INTERNAL uint16x8_t nk_e4m3x8_to_bf16x8_neon_(uint8x8_t e4m3_u8x8) {
592
559
  // Subnormal path (exp=0): E4M3 subnormal = mant × 2⁻⁹ = mant ÷ 512 → BF16
593
560
  // Compute via f32: mant → f32 → multiply → truncate to bf16
594
561
  float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 512.0f);
595
- float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(mant_u16x8))), 1.0f / 512.0f);
562
+ float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 512.0f);
596
563
  uint16x8_t subnormal_abs_u16x8 = vcombine_u16(nk_f32x4_to_bf16x4_neon_(subnormal_low_f32x4),
597
564
  nk_f32x4_to_bf16x4_neon_(subnormal_high_f32x4));
598
565
  uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
599
566
 
600
567
  // NaN path: E4M3FN only has NaN when exp=15 AND mant=7 (0x7F or 0xFF)
601
568
  uint16x8_t nan_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7FC0)); // BF16 quiet NaN
602
- uint16x8_t is_nan_mask = vandq_u16(vceqq_u16(exp_u16x8, vdupq_n_u16(15)), vceqq_u16(mant_u16x8, vdupq_n_u16(7)));
569
+ uint16x8_t is_nan_mask_u16x8 = vandq_u16(vceqq_u16(exp_u16x8, vdupq_n_u16(15)),
570
+ vceqq_u16(mant_u16x8, vdupq_n_u16(7)));
603
571
 
604
572
  // Blend paths: subnormal when exp=0, NaN when exp=15 && mant=7, else normal
605
- uint16x8_t exp_zero_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
606
- uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask, subnormal_u16x8, normal_u16x8);
607
- result_u16x8 = vbslq_u16(is_nan_mask, nan_u16x8, result_u16x8);
573
+ uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
574
+ uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
575
+ result_u16x8 = vbslq_u16(is_nan_mask_u16x8, nan_u16x8, result_u16x8);
608
576
  return result_u16x8;
609
577
  }
610
578
 
@@ -625,8 +593,7 @@ NK_INTERNAL uint16x8_t nk_e5m2x8_to_bf16x8_neon_(uint8x8_t e5m2_u8x8) {
625
593
  // Subnormal path (exp=0): E5M2 subnormal = mant × 2⁻¹⁶ = mant ÷ 65536 → BF16
626
594
  // Compute via f32: mant → f32 → multiply → truncate to bf16
627
595
  float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 65536.0f);
628
- float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(mant_u16x8))),
629
- 1.0f / 65536.0f);
596
+ float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 65536.0f);
630
597
  uint16x8_t subnormal_abs_u16x8 = vcombine_u16(nk_f32x4_to_bf16x4_neon_(subnormal_low_f32x4),
631
598
  nk_f32x4_to_bf16x4_neon_(subnormal_high_f32x4));
632
599
  uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
@@ -634,14 +601,14 @@ NK_INTERNAL uint16x8_t nk_e5m2x8_to_bf16x8_neon_(uint8x8_t e5m2_u8x8) {
634
601
  // Special path (exp=31): inf (mant=0) or nan (mant≠0)
635
602
  uint16x8_t infinity_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7F80));
636
603
  uint16x8_t nan_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7FC0));
637
- uint16x8_t mant_zero_mask = vceqq_u16(mant_u16x8, vdupq_n_u16(0));
638
- uint16x8_t special_u16x8 = vbslq_u16(mant_zero_mask, infinity_u16x8, nan_u16x8);
604
+ uint16x8_t mant_zero_mask_u16x8 = vceqq_u16(mant_u16x8, vdupq_n_u16(0));
605
+ uint16x8_t special_u16x8 = vbslq_u16(mant_zero_mask_u16x8, infinity_u16x8, nan_u16x8);
639
606
 
640
607
  // Blend paths based on exponent value
641
- uint16x8_t exp_zero_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
642
- uint16x8_t exp_max_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(31));
643
- uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask, subnormal_u16x8, normal_u16x8);
644
- result_u16x8 = vbslq_u16(exp_max_mask, special_u16x8, result_u16x8);
608
+ uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
609
+ uint16x8_t exp_max_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(31));
610
+ uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
611
+ result_u16x8 = vbslq_u16(exp_max_mask_u16x8, special_u16x8, result_u16x8);
645
612
  return result_u16x8;
646
613
  }
647
614
 
@@ -678,21 +645,23 @@ NK_INTERNAL uint16x4_t nk_f32x4_to_u16x4_neon_(float32x4_t f32x4) {
678
645
  }
679
646
 
680
647
  /** @brief Convert f32x4 → 4x i8 with saturation (NEON). Convert to i32, narrow twice. */
681
- NK_INTERNAL void nk_f32x4_to_i8x4_neon_(float32x4_t f32x4, nk_i8_t *dst) {
648
+ NK_INTERNAL nk_b32_vec_t nk_f32x4_to_i8x4_neon_(float32x4_t f32x4) {
682
649
  int32x4_t i32x4 = vcvtnq_s32_f32(f32x4);
683
650
  int16x4_t i16x4 = vqmovn_s32(i32x4);
684
651
  int8x8_t i8x8 = vqmovn_s16(vcombine_s16(i16x4, i16x4));
685
- // Reinterpret as s32x2, store lane 0 (4 bytes in one instruction)
686
- vst1_lane_s32((int32_t *)dst, vreinterpret_s32_s8(i8x8), 0);
652
+ nk_b32_vec_t result_vec;
653
+ result_vec.u32 = vget_lane_u32(vreinterpret_u32_s8(i8x8), 0);
654
+ return result_vec;
687
655
  }
688
656
 
689
657
  /** @brief Convert f32x4 → 4x u8 with saturation (NEON). Convert to u32, narrow twice. */
690
- NK_INTERNAL void nk_f32x4_to_u8x4_neon_(float32x4_t f32x4, nk_u8_t *dst) {
658
+ NK_INTERNAL nk_b32_vec_t nk_f32x4_to_u8x4_neon_(float32x4_t f32x4) {
691
659
  uint32x4_t u32x4 = vcvtnq_u32_f32(f32x4);
692
660
  uint16x4_t u16x4 = vqmovn_u32(u32x4);
693
661
  uint8x8_t u8x8 = vqmovn_u16(vcombine_u16(u16x4, u16x4));
694
- // Reinterpret as u32x2, store lane 0 (4 bytes in one instruction)
695
- vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(u8x8), 0);
662
+ nk_b32_vec_t result_vec;
663
+ result_vec.u32 = vget_lane_u32(vreinterpret_u32_u8(u8x8), 0);
664
+ return result_vec;
696
665
  }
697
666
 
698
667
  /** @brief Convert f32x4 → 4x e4m3 via bit manipulation (NEON).
@@ -830,6 +799,8 @@ NK_INTERNAL float32x4_t nk_e2m3x4_to_f32x4_neon_(nk_b32_vec_t src) {
830
799
  uint8x8_t e2m3_u8x8 = vcreate_u8(src.u32);
831
800
  uint16x8_t e2m3_u16x8 = vmovl_u8(e2m3_u8x8);
832
801
  uint32x4_t e2m3_u32x4 = vmovl_u16(vget_low_u16(e2m3_u16x8));
802
+
803
+ // Extract sign: bit 5 → bit 31
833
804
  uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e2m3_u32x4, vdupq_n_u32(0x20)), 26);
834
805
  uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e2m3_u32x4, 3), vdupq_n_u32(0x03));
835
806
  uint32x4_t mant_u32x4 = vandq_u32(e2m3_u32x4, vdupq_n_u32(0x07));
@@ -844,8 +815,8 @@ NK_INTERNAL float32x4_t nk_e2m3x4_to_f32x4_neon_(nk_b32_vec_t src) {
844
815
  uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
845
816
 
846
817
  // Blend paths: subnormal when exp=0, else normal
847
- uint32x4_t exp_zero_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
848
- uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
818
+ uint32x4_t exp_zero_mask_u32x4 = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
819
+ uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask_u32x4, subnormal_u32x4, normal_u32x4);
849
820
  return vreinterpretq_f32_u32(result_u32x4);
850
821
  }
851
822
 
@@ -856,6 +827,8 @@ NK_INTERNAL float32x4_t nk_e3m2x4_to_f32x4_neon_(nk_b32_vec_t src) {
856
827
  uint8x8_t e3m2_u8x8 = vcreate_u8(src.u32);
857
828
  uint16x8_t e3m2_u16x8 = vmovl_u8(e3m2_u8x8);
858
829
  uint32x4_t e3m2_u32x4 = vmovl_u16(vget_low_u16(e3m2_u16x8));
830
+
831
+ // Extract sign: bit 5 → bit 31
859
832
  uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e3m2_u32x4, vdupq_n_u32(0x20)), 26);
860
833
  uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e3m2_u32x4, 2), vdupq_n_u32(0x07));
861
834
  uint32x4_t mant_u32x4 = vandq_u32(e3m2_u32x4, vdupq_n_u32(0x03));
@@ -870,8 +843,8 @@ NK_INTERNAL float32x4_t nk_e3m2x4_to_f32x4_neon_(nk_b32_vec_t src) {
870
843
  uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
871
844
 
872
845
  // Blend paths: subnormal when exp=0, else normal
873
- uint32x4_t exp_zero_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
874
- uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
846
+ uint32x4_t exp_zero_mask_u32x4 = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
847
+ uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask_u32x4, subnormal_u32x4, normal_u32x4);
875
848
  return vreinterpretq_f32_u32(result_u32x4);
876
849
  }
877
850
 
@@ -997,9 +970,9 @@ NK_INTERNAL nk_b32_vec_t nk_f32x4_to_e3m2x4_neon_(float32x4_t f32x4) {
997
970
  return result;
998
971
  }
999
972
 
1000
- #pragma endregion - Vectorized Conversions
973
+ #pragma endregion Vectorized Conversions
1001
974
 
1002
- #pragma region - Public API
975
+ #pragma region Public API
1003
976
 
1004
977
  NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
1005
978
  // Same-type fast path
@@ -1044,38 +1017,37 @@ NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n,
1044
1017
  nk_u8_t *to_ptr = (nk_u8_t *)to;
1045
1018
 
1046
1019
  for (nk_size_t idx = 0; idx < batches; ++idx, from_ptr += from_step, to_ptr += to_step) {
1020
+ nk_b128_vec_t hub_vec;
1021
+
1047
1022
  // Upcast to f16x8 hub
1048
- float16x8_t hub_f16x8;
1049
1023
  switch (from_type) {
1050
- case nk_e4m3_k: hub_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(from_ptr)); break;
1051
- case nk_e5m2_k: hub_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(from_ptr)); break;
1052
- case nk_e2m3_k: hub_f16x8 = nk_e2m3x8_to_f16x8_neon_(vld1_u8(from_ptr)); break;
1053
- case nk_e3m2_k: hub_f16x8 = nk_e3m2x8_to_f16x8_neon_(vld1_u8(from_ptr)); break;
1054
- case nk_f16_k: hub_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)from_ptr)); break;
1024
+ case nk_e4m3_k: hub_vec.u16x8 = vreinterpretq_u16_f16(nk_e4m3x8_to_f16x8_neon_(vld1_u8(from_ptr))); break;
1025
+ case nk_e5m2_k: hub_vec.u16x8 = vreinterpretq_u16_f16(nk_e5m2x8_to_f16x8_neon_(vld1_u8(from_ptr))); break;
1026
+ case nk_e2m3_k: hub_vec.u16x8 = vreinterpretq_u16_f16(nk_e2m3x8_to_f16x8_neon_(vld1_u8(from_ptr))); break;
1027
+ case nk_e3m2_k: hub_vec.u16x8 = vreinterpretq_u16_f16(nk_e3m2x8_to_f16x8_neon_(vld1_u8(from_ptr))); break;
1028
+ case nk_f16_k: hub_vec.u16x8 = vld1q_u16((nk_u16_t const *)from_ptr); break;
1055
1029
  case nk_bf16_k: {
1056
- uint16x4_t brain_low_u16x4 = vld1_u16((nk_u16_t const *)from_ptr);
1057
- uint16x4_t brain_high_u16x4 = vld1_u16((nk_u16_t const *)(from_ptr + 8));
1058
- float32x4_t ieee_low_f32x4 = nk_bf16x4_to_f32x4_neon_(brain_low_u16x4);
1059
- float32x4_t ieee_high_f32x4 = nk_bf16x4_to_f32x4_neon_(brain_high_u16x4);
1060
- hub_f16x8 = vcombine_f16(vcvt_f16_f32(ieee_low_f32x4), vcvt_f16_f32(ieee_high_f32x4));
1030
+ float32x4_t low_f32x4 = nk_bf16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr));
1031
+ float32x4_t high_f32x4 = nk_bf16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)(from_ptr + 8)));
1032
+ hub_vec.u16x8 = vreinterpretq_u16_f16(vcombine_f16(vcvt_f16_f32(low_f32x4), vcvt_f16_f32(high_f32x4)));
1061
1033
  } break;
1062
- default: hub_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0)); break;
1034
+ default: hub_vec.u16x8 = vdupq_n_u16(0); break;
1063
1035
  }
1064
1036
 
1065
1037
  // Downcast from f16x8 hub
1066
1038
  switch (to_type) {
1067
- case nk_e4m3_k: vst1_u8(to_ptr, nk_f16x8_to_e4m3x8_neon_(hub_f16x8)); break;
1068
- case nk_e5m2_k: vst1_u8(to_ptr, nk_f16x8_to_e5m2x8_neon_(hub_f16x8)); break;
1069
- case nk_f16_k: vst1q_u16((nk_u16_t *)to_ptr, vreinterpretq_u16_f16(hub_f16x8)); break;
1039
+ case nk_e4m3_k: vst1_u8(to_ptr, nk_f16x8_to_e4m3x8_neon_(vreinterpretq_f16_u16(hub_vec.u16x8))); break;
1040
+ case nk_e5m2_k: vst1_u8(to_ptr, nk_f16x8_to_e5m2x8_neon_(vreinterpretq_f16_u16(hub_vec.u16x8))); break;
1041
+ case nk_f16_k: vst1q_u16((nk_u16_t *)to_ptr, hub_vec.u16x8); break;
1070
1042
  case nk_bf16_k: {
1071
- float32x4_t ieee_low_f32x4 = vcvt_f32_f16(vget_low_f16(hub_f16x8));
1072
- float32x4_t ieee_high_f32x4 = vcvt_f32_f16(vget_high_f16(hub_f16x8));
1073
- vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(ieee_low_f32x4));
1074
- vst1_u16((nk_u16_t *)(to_ptr + 8), nk_f32x4_to_bf16x4_neon_(ieee_high_f32x4));
1043
+ float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(vreinterpretq_f16_u16(hub_vec.u16x8)));
1044
+ float32x4_t high_f32x4 = vcvt_high_f32_f16(vreinterpretq_f16_u16(hub_vec.u16x8));
1045
+ vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(low_f32x4));
1046
+ vst1_u16((nk_u16_t *)(to_ptr + 8), nk_f32x4_to_bf16x4_neon_(high_f32x4));
1075
1047
  } break;
1076
1048
  case nk_f32_k: {
1077
- vst1q_f32((nk_f32_t *)to_ptr, vcvt_f32_f16(vget_low_f16(hub_f16x8)));
1078
- vst1q_f32((nk_f32_t *)(to_ptr + 16), vcvt_f32_f16(vget_high_f16(hub_f16x8)));
1049
+ vst1q_f32((nk_f32_t *)to_ptr, vcvt_f32_f16(vget_low_f16(vreinterpretq_f16_u16(hub_vec.u16x8))));
1050
+ vst1q_f32((nk_f32_t *)(to_ptr + 16), vcvt_high_f32_f16(vreinterpretq_f16_u16(hub_vec.u16x8)));
1079
1051
  } break;
1080
1052
  default: break;
1081
1053
  }
@@ -1097,76 +1069,71 @@ NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n,
1097
1069
  nk_u8_t *to_ptr = (nk_u8_t *)to;
1098
1070
 
1099
1071
  for (nk_size_t idx = 0; idx < batches; ++idx, from_ptr += from_step, to_ptr += to_step) {
1100
- // Load and upcast to f32x4
1101
- float32x4_t hub_f32x4;
1072
+ nk_b128_vec_t hub_vec;
1073
+
1074
+ // Upcast to f32x4 hub
1102
1075
  switch (from_type) {
1103
- case nk_f32_k: hub_f32x4 = vld1q_f32((nk_f32_t const *)from_ptr); break;
1104
- case nk_f16_k: hub_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16((nk_u16_t const *)from_ptr))); break;
1105
- case nk_bf16_k: hub_f32x4 = nk_bf16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr)); break;
1106
- case nk_e4m3_k: {
1107
- nk_b32_vec_t in_vec;
1108
- nk_load_b32_serial_(from_ptr, &in_vec);
1109
- hub_f32x4 = nk_e4m3x4_to_f32x4_neon_(in_vec);
1110
- } break;
1111
- case nk_e5m2_k: {
1112
- nk_b32_vec_t in_vec;
1113
- nk_load_b32_serial_(from_ptr, &in_vec);
1114
- hub_f32x4 = nk_e5m2x4_to_f32x4_neon_(in_vec);
1115
- } break;
1116
- case nk_e2m3_k: {
1117
- nk_b32_vec_t in_vec;
1118
- nk_load_b32_serial_(from_ptr, &in_vec);
1119
- hub_f32x4 = nk_e2m3x4_to_f32x4_neon_(in_vec);
1120
- } break;
1121
- case nk_e3m2_k: {
1122
- nk_b32_vec_t in_vec;
1123
- nk_load_b32_serial_(from_ptr, &in_vec);
1124
- hub_f32x4 = nk_e3m2x4_to_f32x4_neon_(in_vec);
1125
- } break;
1126
- case nk_i32_k: hub_f32x4 = vcvtq_f32_s32(vld1q_s32((nk_i32_t const *)from_ptr)); break;
1127
- case nk_u32_k: hub_f32x4 = vcvtq_f32_u32(vld1q_u32((nk_u32_t const *)from_ptr)); break;
1128
- case nk_i16_k: hub_f32x4 = nk_i16x4_to_f32x4_neon_(vld1_s16((nk_i16_t const *)from_ptr)); break;
1129
- case nk_u16_k: hub_f32x4 = nk_u16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr)); break;
1130
- case nk_i8_k: {
1131
- nk_b32_vec_t in_vec;
1132
- nk_load_b32_serial_(from_ptr, &in_vec);
1133
- hub_f32x4 = nk_i8x4_to_f32x4_neon_(in_vec);
1134
- } break;
1135
- case nk_u8_k: {
1136
- nk_b32_vec_t in_vec;
1137
- nk_load_b32_serial_(from_ptr, &in_vec);
1138
- hub_f32x4 = nk_u8x4_to_f32x4_neon_(in_vec);
1139
- } break;
1140
- default: hub_f32x4 = vdupq_n_f32(0); break;
1076
+ case nk_f32_k: hub_vec.f32x4 = vld1q_f32((nk_f32_t const *)from_ptr); break;
1077
+ case nk_f16_k: hub_vec.f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16((nk_u16_t const *)from_ptr))); break;
1078
+ case nk_bf16_k: hub_vec.f32x4 = nk_bf16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr)); break;
1079
+ case nk_e4m3_k:
1080
+ hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
1081
+ hub_vec.f32x4 = nk_e4m3x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
1082
+ break;
1083
+ case nk_e5m2_k:
1084
+ hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
1085
+ hub_vec.f32x4 = nk_e5m2x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
1086
+ break;
1087
+ case nk_e2m3_k:
1088
+ hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
1089
+ hub_vec.f32x4 = nk_e2m3x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
1090
+ break;
1091
+ case nk_e3m2_k:
1092
+ hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
1093
+ hub_vec.f32x4 = nk_e3m2x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
1094
+ break;
1095
+ case nk_i32_k: hub_vec.f32x4 = vcvtq_f32_s32(vld1q_s32((nk_i32_t const *)from_ptr)); break;
1096
+ case nk_u32_k: hub_vec.f32x4 = vcvtq_f32_u32(vld1q_u32((nk_u32_t const *)from_ptr)); break;
1097
+ case nk_i16_k: hub_vec.f32x4 = nk_i16x4_to_f32x4_neon_(vld1_s16((nk_i16_t const *)from_ptr)); break;
1098
+ case nk_u16_k: hub_vec.f32x4 = nk_u16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr)); break;
1099
+ case nk_i8_k:
1100
+ hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
1101
+ hub_vec.f32x4 = nk_i8x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
1102
+ break;
1103
+ case nk_u8_k:
1104
+ hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
1105
+ hub_vec.f32x4 = nk_u8x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
1106
+ break;
1107
+ default: hub_vec.f32x4 = vdupq_n_f32(0); break;
1141
1108
  }
1142
1109
 
1143
- // Downcast from f32x4 and store
1110
+ // Downcast from f32x4 hub and store
1144
1111
  switch (to_type) {
1145
- case nk_f32_k: vst1q_f32((nk_f32_t *)to_ptr, hub_f32x4); break;
1146
- case nk_f16_k: vst1_u16((nk_u16_t *)to_ptr, vreinterpret_u16_f16(vcvt_f16_f32(hub_f32x4))); break;
1147
- case nk_bf16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(hub_f32x4)); break;
1148
- case nk_e4m3_k: {
1149
- nk_b32_vec_t out_vec = nk_f32x4_to_e4m3x4_neon_(hub_f32x4);
1150
- *(nk_u32_t *)to_ptr = out_vec.u32;
1151
- } break;
1152
- case nk_e5m2_k: {
1153
- nk_b32_vec_t out_vec = nk_f32x4_to_e5m2x4_neon_(hub_f32x4);
1154
- *(nk_u32_t *)to_ptr = out_vec.u32;
1155
- } break;
1156
- case nk_e2m3_k: {
1157
- nk_b32_vec_t out_vec = nk_f32x4_to_e2m3x4_neon_(hub_f32x4);
1158
- nk_copy_bytes_(to_ptr, &out_vec, sizeof(nk_b32_vec_t));
1159
- } break;
1160
- case nk_e3m2_k: {
1161
- nk_b32_vec_t out_vec = nk_f32x4_to_e3m2x4_neon_(hub_f32x4);
1162
- nk_copy_bytes_(to_ptr, &out_vec, sizeof(nk_b32_vec_t));
1163
- } break;
1164
- case nk_i32_k: vst1q_s32((nk_i32_t *)to_ptr, vcvtnq_s32_f32(hub_f32x4)); break;
1165
- case nk_u32_k: vst1q_u32((nk_u32_t *)to_ptr, vcvtnq_u32_f32(hub_f32x4)); break;
1166
- case nk_i16_k: vst1_s16((nk_i16_t *)to_ptr, nk_f32x4_to_i16x4_neon_(hub_f32x4)); break;
1167
- case nk_u16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_u16x4_neon_(hub_f32x4)); break;
1168
- case nk_i8_k: nk_f32x4_to_i8x4_neon_(hub_f32x4, (nk_i8_t *)to_ptr); break;
1169
- case nk_u8_k: nk_f32x4_to_u8x4_neon_(hub_f32x4, (nk_u8_t *)to_ptr); break;
1112
+ case nk_f32_k: vst1q_f32((nk_f32_t *)to_ptr, hub_vec.f32x4); break;
1113
+ case nk_f16_k: vst1_u16((nk_u16_t *)to_ptr, vreinterpret_u16_f16(vcvt_f16_f32(hub_vec.f32x4))); break;
1114
+ case nk_bf16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(hub_vec.f32x4)); break;
1115
+ case nk_e4m3_k:
1116
+ vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_e4m3x4_neon_(hub_vec.f32x4).u32), 0);
1117
+ break;
1118
+ case nk_e5m2_k:
1119
+ vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_e5m2x4_neon_(hub_vec.f32x4).u32), 0);
1120
+ break;
1121
+ case nk_e2m3_k:
1122
+ vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_e2m3x4_neon_(hub_vec.f32x4).u32), 0);
1123
+ break;
1124
+ case nk_e3m2_k:
1125
+ vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_e3m2x4_neon_(hub_vec.f32x4).u32), 0);
1126
+ break;
1127
+ case nk_i32_k: vst1q_s32((nk_i32_t *)to_ptr, vcvtnq_s32_f32(hub_vec.f32x4)); break;
1128
+ case nk_u32_k: vst1q_u32((nk_u32_t *)to_ptr, vcvtnq_u32_f32(hub_vec.f32x4)); break;
1129
+ case nk_i16_k: vst1_s16((nk_i16_t *)to_ptr, nk_f32x4_to_i16x4_neon_(hub_vec.f32x4)); break;
1130
+ case nk_u16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_u16x4_neon_(hub_vec.f32x4)); break;
1131
+ case nk_i8_k:
1132
+ vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_i8x4_neon_(hub_vec.f32x4).u32), 0);
1133
+ break;
1134
+ case nk_u8_k:
1135
+ vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_u8x4_neon_(hub_vec.f32x4).u32), 0);
1136
+ break;
1170
1137
  default: break;
1171
1138
  }
1172
1139
  }
@@ -1175,7 +1142,7 @@ NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n,
1175
1142
  if (tail) nk_cast_serial(from_ptr, from_type, tail, to_ptr, to_type);
1176
1143
  }
1177
1144
 
1178
- #pragma endregion - Public API
1145
+ #pragma endregion Public API
1179
1146
 
1180
1147
  #if defined(__clang__)
1181
1148
  #pragma clang attribute pop