numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1192 @@
1
+ /**
2
+ * @brief SIMD-accelerated Type Conversions for NEON.
3
+ * @file include/numkong/cast/neon.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/cast.h
8
+ *
9
+ * @section neon_cast_instructions ARM NEON Conversion Instructions
10
+ *
11
+ * Float ↔ integer conversions (Cortex-A76 class):
12
+ *
13
+ * Intrinsic Instruction Latency Throughput
14
+ * vcvtq_f32_s32 SCVTF (V.4S, V.4S) 3cy 2/cy
15
+ * vcvtq_f32_u32 UCVTF (V.4S, V.4S) 3cy 2/cy
16
+ * vcvtq_s32_f32 FCVTZS (V.4S, V.4S) 3cy 2/cy
17
+ * vcvtq_u32_f32 FCVTZU (V.4S, V.4S) 3cy 2/cy
18
+ *
19
+ * Float precision conversions:
20
+ *
21
+ * Intrinsic Instruction Latency Throughput
22
+ * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy
23
+ * vcvt_f16_f32 FCVTN (V.4H, V.4S) 3cy 2/cy
24
+ * vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy 2/cy
25
+ * vcvt_f32_f64 FCVTN (V.2S, V.2D) 3cy 2/cy
26
+ *
27
+ * Integer narrowing with saturation:
28
+ *
29
+ * Intrinsic Instruction Latency Throughput
30
+ * vqmovn_s32 SQXTN (V.4H, V.4S) 3cy 2/cy
31
+ * vqmovn_u32 UQXTN (V.4H, V.4S) 3cy 2/cy
32
+ * vqmovun_s32 SQXTUN (V.4H, V.4S) 3cy 2/cy
33
+ *
34
+ * BF16 support (ARMv8.6-A+):
35
+ *
36
+ * Intrinsic Instruction Latency Throughput
37
+ * vcvtq_low_bf16_f32 BFCVTN (V.4H, V.4S) 3cy 1/cy
38
+ * vcvtq_high_bf16_f32 BFCVTN2 (V.8H, V.4S) 3cy 1/cy
39
+ *
40
+ * BF16 conversions on baseline NEON (emulated via bit shifts):
41
+ * - bf16 → f32: vmovl_u16 + vshlq_n_u32 by 16
42
+ * - f32 → bf16: round-to-nearest + vshrn_n_u32 by 16
43
+ *
44
+ * FP8 (E4M3/E5M2) conversions use NEON bit manipulation:
45
+ * - Field extraction: vandq, vshrq, vshlq
46
+ * - Blending: vbslq for conditional selection
47
+ * - Subnormal handling: vmulq_n_f32 with scale factors (1/512 for E4M3, 1/65536 for E5M2; 1/8 and 1/16 for FP6 E2M3/E3M2)
48
+ */
49
+ #ifndef NK_CAST_NEON_H
50
+ #define NK_CAST_NEON_H
51
+
52
+ #if NK_TARGET_ARM_
53
+ #if NK_TARGET_NEON
54
+
55
+ #include "numkong/types.h"
56
+ #include "numkong/cast/serial.h" // `nk_cast_serial`, `nk_dtype_bits`
57
+ #include "numkong/reduce/serial.h" // `nk_reduce_moments_f32_serial`
58
+
59
+ #if defined(__cplusplus)
60
+ extern "C" {
61
+ #endif
62
+
63
+ #if defined(__clang__)
64
+ #pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
65
+ #elif defined(__GNUC__)
66
+ #pragma GCC push_options
67
+ #pragma GCC target("arch=armv8-a+simd")
68
+ #endif
69
+
70
+ NK_PUBLIC void nk_f16_to_f32_neon(nk_f16_t const *src, nk_f32_t *dest) {
71
+ float16x4_t f16vec = vreinterpret_f16_u16(vld1_dup_u16((nk_u16_t const *)src));
72
+ float32x4_t f32vec = vcvt_f32_f16(f16vec);
73
+ *dest = vgetq_lane_f32(f32vec, 0);
74
+ }
75
+
76
+ NK_PUBLIC void nk_f32_to_f16_neon(nk_f32_t const *src, nk_f16_t *dest) {
77
+ float32x4_t f32vec = vdupq_n_f32(*src);
78
+ float16x4_t f16vec = vcvt_f16_f32(f32vec);
79
+ vst1_lane_u16((nk_u16_t *)dest, vreinterpret_u16_f16(f16vec), 0);
80
+ }
81
+
82
+ #pragma region - Type Punned Loads and Stores
83
+
84
+ /** @brief Type-agnostic 128-bit full load (NEON). */
85
+ NK_INTERNAL void nk_load_b128_neon_(void const *src, nk_b128_vec_t *dst) {
86
+ dst->u8x16 = vld1q_u8((nk_u8_t const *)src);
87
+ }
88
+
89
+ /** @brief Type-agnostic 256-bit full load (NEON). */
90
+ NK_INTERNAL void nk_load_b256_neon_(void const *src, nk_b256_vec_t *dst) {
91
+ dst->u8x16s[0] = vld1q_u8((nk_u8_t const *)src);
92
+ dst->u8x16s[1] = vld1q_u8((nk_u8_t const *)src + 16);
93
+ }
94
+
95
+ /** @brief Type-agnostic 128-bit full store (NEON). */
96
+ NK_INTERNAL void nk_store_b128_neon_(nk_b128_vec_t const *src, void *dst) { vst1q_u8((nk_u8_t *)dst, src->u8x16); }
97
+
98
+ /** @brief Type-agnostic 256-bit full store (NEON). */
99
+ NK_INTERNAL void nk_store_b256_neon_(nk_b256_vec_t const *src, void *dst) {
100
+ vst1q_u8((nk_u8_t *)dst, src->u8x16s[0]);
101
+ vst1q_u8((nk_u8_t *)dst + 16, src->u8x16s[1]);
102
+ }
103
+
104
+ /** @brief Type-agnostic 64-bit full load (NEON). */
105
+ NK_INTERNAL void nk_load_b64_neon_(void const *src, nk_b64_vec_t *dst) { dst->u8x8 = vld1_u8((nk_u8_t const *)src); }
106
+
107
+ #pragma endregion - Type Punned Loads and Stores
108
+
109
+ #pragma region - Vectorized Conversions
110
+
111
+ /** @brief Convert 4x e4m3 → f32x4 via bit manipulation (NEON).
112
+ * E4M3FN format: S EEEE MMM (bias=7). No ∞ representation.
113
+ * Only exp=15, mant=7 (0x7F) is NaN; exp=15, mant ∈ [0,6] are valid normals (max=448). */
114
+ NK_INTERNAL float32x4_t nk_e4m3x4_to_f32x4_neon_(nk_b32_vec_t src) {
115
+ uint8x8_t e4m3_u8x8 = vcreate_u8(src.u32);
116
+ uint16x8_t e4m3_u16x8 = vmovl_u8(e4m3_u8x8);
117
+ uint32x4_t e4m3_u32x4 = vmovl_u16(vget_low_u16(e4m3_u16x8));
118
+ uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e4m3_u32x4, vdupq_n_u32(0x80)), 24);
119
+ uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e4m3_u32x4, 3), vdupq_n_u32(0x0F));
120
+ uint32x4_t mant_u32x4 = vandq_u32(e4m3_u32x4, vdupq_n_u32(0x07));
121
+
122
+ // Normal path: f32 = sign | ((exp+120)<<23) | (mant<<20)
123
+ uint32x4_t f32_exp_u32x4 = vshlq_n_u32(vaddq_u32(exp_u32x4, vdupq_n_u32(120)), 23);
124
+ uint32x4_t f32_mant_u32x4 = vshlq_n_u32(mant_u32x4, 20);
125
+ uint32x4_t normal_u32x4 = vorrq_u32(sign_u32x4, vorrq_u32(f32_exp_u32x4, f32_mant_u32x4));
126
+
127
+ // Subnormal path (exp=0, mant ≠ 0): value = ±mantissa × 2⁻⁹
128
+ float32x4_t subnormal_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mant_u32x4), 1.0f / 512.0f);
129
+ uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
130
+
131
+ // NaN path: E4M3FN only has NaN when exp=15 AND mant=7 (0x7F or 0xFF)
132
+ uint32x4_t nan_u32x4 = vorrq_u32(sign_u32x4, vdupq_n_u32(0x7FC00000));
133
+ uint32x4_t is_nan_mask = vandq_u32(vceqq_u32(exp_u32x4, vdupq_n_u32(15)), vceqq_u32(mant_u32x4, vdupq_n_u32(7)));
134
+
135
+ // Blend paths: subnormal when exp=0, NaN when exp=15 && mant=7, else normal
136
+ uint32x4_t exp_zero_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
137
+ uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
138
+ result_u32x4 = vbslq_u32(is_nan_mask, nan_u32x4, result_u32x4);
139
+ return vreinterpretq_f32_u32(result_u32x4);
140
+ }
141
+
142
+ /** @brief Convert 4x e5m2 → f32x4 via bit manipulation (NEON).
143
+ * E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mant<<21.
144
+ * Handles subnormals (exp=0, mant ≠ 0), inf (exp=31, mant=0), and nan (exp=31, mant ≠ 0). */
145
+ NK_INTERNAL float32x4_t nk_e5m2x4_to_f32x4_neon_(nk_b32_vec_t src) {
146
+ uint8x8_t e5m2_u8x8 = vcreate_u8(src.u32);
147
+ uint16x8_t e5m2_u16x8 = vmovl_u8(e5m2_u8x8);
148
+ uint32x4_t e5m2_u32x4 = vmovl_u16(vget_low_u16(e5m2_u16x8));
149
+ uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e5m2_u32x4, vdupq_n_u32(0x80)), 24);
150
+ uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e5m2_u32x4, 2), vdupq_n_u32(0x1F));
151
+ uint32x4_t mant_u32x4 = vandq_u32(e5m2_u32x4, vdupq_n_u32(0x03));
152
+
153
+ // Normal path: f32 = sign | ((exp+112)<<23) | (mant<<21)
154
+ uint32x4_t f32_exp_u32x4 = vshlq_n_u32(vaddq_u32(exp_u32x4, vdupq_n_u32(112)), 23);
155
+ uint32x4_t f32_mant_u32x4 = vshlq_n_u32(mant_u32x4, 21);
156
+ uint32x4_t normal_u32x4 = vorrq_u32(sign_u32x4, vorrq_u32(f32_exp_u32x4, f32_mant_u32x4));
157
+
158
+ // Subnormal path (exp=0, mant ≠ 0): value = ±mantissa × 2⁻¹⁶
159
+ float32x4_t subnormal_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mant_u32x4), 1.0f / 65536.0f);
160
+ uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
161
+
162
+ // Special path (exp=31): inf (mant=0) or nan (mant≠0)
163
+ uint32x4_t infinity_u32x4 = vorrq_u32(sign_u32x4, vdupq_n_u32(0x7F800000));
164
+ uint32x4_t nan_u32x4 = vorrq_u32(sign_u32x4, vdupq_n_u32(0x7FC00000));
165
+ uint32x4_t mant_zero_mask = vceqq_u32(mant_u32x4, vdupq_n_u32(0));
166
+ uint32x4_t special_u32x4 = vbslq_u32(mant_zero_mask, infinity_u32x4, nan_u32x4);
167
+
168
+ // Blend paths based on exponent value
169
+ uint32x4_t exp_zero_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
170
+ uint32x4_t exp_max_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(31));
171
+ uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
172
+ result_u32x4 = vbslq_u32(exp_max_mask, special_u32x4, result_u32x4);
173
+ return vreinterpretq_f32_u32(result_u32x4);
174
+ }
175
+
176
+ /** @brief Convert 8x e4m3 → f16x8 via bit manipulation (NEON).
177
+ * E4M3FN format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
178
+ * E4M3FN has no ∞; only exp=15, mant=7 is NaN. exp=15, mant ∈ [0,6] are valid normals. */
179
+ NK_INTERNAL float16x8_t nk_e4m3x8_to_f16x8_neon_(uint8x8_t e4m3_u8x8) {
180
+ uint16x8_t e4m3_u16x8 = vmovl_u8(e4m3_u8x8);
181
+ uint16x8_t sign_u16x8 = vshlq_n_u16(vandq_u16(e4m3_u16x8, vdupq_n_u16(0x80)), 8); // sign << 15
182
+ uint16x8_t exp_u16x8 = vandq_u16(vshrq_n_u16(e4m3_u16x8, 3), vdupq_n_u16(0x0F));
183
+ uint16x8_t mant_u16x8 = vandq_u16(e4m3_u16x8, vdupq_n_u16(0x07));
184
+
185
+ // Normal path: F16_exp = E4M3_exp + 8, F16_mant = E4M3_mant << 7
186
+ uint16x8_t f16_exp_u16x8 = vshlq_n_u16(vaddq_u16(exp_u16x8, vdupq_n_u16(8)), 10);
187
+ uint16x8_t f16_mant_u16x8 = vshlq_n_u16(mant_u16x8, 7);
188
+ uint16x8_t normal_u16x8 = vorrq_u16(sign_u16x8, vorrq_u16(f16_exp_u16x8, f16_mant_u16x8));
189
+
190
+ // Subnormal path (exp=0, mant ≠ 0): E4M3 subnormal value = mant × 2⁻⁹ = mant ÷ 512
191
+ // Compute arithmetically: mant → f32 → multiply → f16
192
+ float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 512.0f);
193
+ float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(mant_u16x8))), 1.0f / 512.0f);
194
+ uint16x8_t subnormal_abs_u16x8 = vreinterpretq_u16_f16(
195
+ vcombine_f16(vcvt_f16_f32(subnormal_low_f32x4), vcvt_f16_f32(subnormal_high_f32x4)));
196
+ uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
197
+
198
+ // NaN path: E4M3FN only has NaN when exp=15 AND mant=7 (0x7F or 0xFF)
199
+ uint16x8_t nan_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7E00)); // F16 quiet NaN
200
+ uint16x8_t is_nan_mask = vandq_u16(vceqq_u16(exp_u16x8, vdupq_n_u16(15)), vceqq_u16(mant_u16x8, vdupq_n_u16(7)));
201
+
202
+ // Blend paths: subnormal when exp=0, NaN when exp=15 && mant=7, else normal
203
+ uint16x8_t exp_zero_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
204
+ uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask, subnormal_u16x8, normal_u16x8);
205
+ result_u16x8 = vbslq_u16(is_nan_mask, nan_u16x8, result_u16x8);
206
+ return vreinterpretq_f16_u16(result_u16x8);
207
+ }
208
+
209
+ /** @brief Convert 16x e4m3 → 2x f16x8 via TBL lookup (NEON).
210
+ * E4M3FN format: S EEEE MMM (bias=7) → F16: S EEEEE MMMMMMMMMM (bias=15).
211
+ * Uses sign symmetry: negative LUT entries = positive XOR 0x80, so we strip sign,
212
+ * lookup 7-bit absolute value in 2× VQTBL4 (128 bytes), then OR sign back.
213
+ * Arithmetic for the low byte: (has_exp && !nan) ? (lsb << 7) : 0.
214
+ * Exact for all 256 input values including subnormals and NaN.
215
+ *
216
+ * Performance (per 16 elements): ~10-12 instructions vs ~40 for 2× nk_e4m3x8_to_f16x8_neon_ */
217
+ NK_INTERNAL void nk_e4m3x16_to_f16x8x2_neon_(uint8x16_t input_u8x16, float16x8_t *result_low_f16x8,
218
+ float16x8_t *result_high_f16x8) {
219
+ // Precomputed LUT: F16 high byte for unsigned 7-bit e4m3 values (sign handled separately).
220
+ // Quadrant 0: |e4m3| bytes 0x00..0x3F (exp=0..7)
221
+ static nk_u8_t const table_q0_u8x64[64] = {
222
+ 0x00, 0x18, 0x1C, 0x1E, 0x20, 0x21, 0x22, 0x23, 0x24, 0x24, 0x25, 0x25, 0x26, 0x26, 0x27, 0x27,
223
+ 0x28, 0x28, 0x29, 0x29, 0x2A, 0x2A, 0x2B, 0x2B, 0x2C, 0x2C, 0x2D, 0x2D, 0x2E, 0x2E, 0x2F, 0x2F,
224
+ 0x30, 0x30, 0x31, 0x31, 0x32, 0x32, 0x33, 0x33, 0x34, 0x34, 0x35, 0x35, 0x36, 0x36, 0x37, 0x37,
225
+ 0x38, 0x38, 0x39, 0x39, 0x3A, 0x3A, 0x3B, 0x3B, 0x3C, 0x3C, 0x3D, 0x3D, 0x3E, 0x3E, 0x3F, 0x3F,
226
+ };
227
+ // Quadrant 1: |e4m3| bytes 0x40..0x7F (exp=8..15)
228
+ static nk_u8_t const table_q1_u8x64[64] = {
229
+ 0x40, 0x40, 0x41, 0x41, 0x42, 0x42, 0x43, 0x43, 0x44, 0x44, 0x45, 0x45, 0x46, 0x46, 0x47, 0x47,
230
+ 0x48, 0x48, 0x49, 0x49, 0x4A, 0x4A, 0x4B, 0x4B, 0x4C, 0x4C, 0x4D, 0x4D, 0x4E, 0x4E, 0x4F, 0x4F,
231
+ 0x50, 0x50, 0x51, 0x51, 0x52, 0x52, 0x53, 0x53, 0x54, 0x54, 0x55, 0x55, 0x56, 0x56, 0x57, 0x57,
232
+ 0x58, 0x58, 0x59, 0x59, 0x5A, 0x5A, 0x5B, 0x5B, 0x5C, 0x5C, 0x5D, 0x5D, 0x5E, 0x5E, 0x5F, 0x7E,
233
+ };
234
+
235
+ uint8x16x4_t lut_q0 = vld1q_u8_x4(table_q0_u8x64);
236
+ uint8x16x4_t lut_q1 = vld1q_u8_x4(table_q1_u8x64);
237
+
238
+ // Strip sign bit, work with 7-bit absolute value
239
+ uint8x16_t sign_u8x16 = vandq_u8(input_u8x16, vdupq_n_u8(0x80));
240
+ uint8x16_t abs_u8x16 = vandq_u8(input_u8x16, vdupq_n_u8(0x7F));
241
+
242
+ // High byte via 2× VQTBL4 on unsigned index, then OR sign back.
243
+ // VQTBL4 returns 0 for out-of-range indices (>= 64), so results OR together cleanly.
244
+ uint8x16_t high_q0_u8x16 = vqtbl4q_u8(lut_q0, abs_u8x16);
245
+ uint8x16_t offset_q1_u8x16 = vsubq_u8(abs_u8x16, vdupq_n_u8(64));
246
+ uint8x16_t high_q1_u8x16 = vqtbl4q_u8(lut_q1, offset_q1_u8x16);
247
+ uint8x16_t high_bytes_u8x16 = vorrq_u8(vorrq_u8(high_q0_u8x16, high_q1_u8x16), sign_u8x16);
248
+
249
+ // Low byte: (lsb << 7), masked to 0 for subnormals (exp=0) and NaN (exp=15, mant=7)
250
+ uint8x16_t lsb_shifted_u8x16 = vshlq_n_u8(input_u8x16, 7);
251
+ uint8x16_t exponent_bits_u8x16 = vandq_u8(input_u8x16, vdupq_n_u8(0x78));
252
+ uint8x16_t has_exp_u8x16 = vcgtq_u8(exponent_bits_u8x16, vdupq_n_u8(0));
253
+ // Exclude NaN: |input| == 0x7F means exp=15, mant=7
254
+ uint8x16_t abs_input_u8x16 = vandq_u8(input_u8x16, vdupq_n_u8(0x7F));
255
+ uint8x16_t not_nan_u8x16 = vmvnq_u8(vceqq_u8(abs_input_u8x16, vdupq_n_u8(0x7F)));
256
+ uint8x16_t low_bytes_u8x16 = vandq_u8(lsb_shifted_u8x16, vandq_u8(has_exp_u8x16, not_nan_u8x16));
257
+
258
+ // ZIP to interleave: [l0,l1..l15] + [h0,h1..h15] → [l0,h0,l1,h1..] (little-endian f16)
259
+ uint8x16x2_t interleaved_u8x16x2 = vzipq_u8(low_bytes_u8x16, high_bytes_u8x16);
260
+ *result_low_f16x8 = vreinterpretq_f16_u8(interleaved_u8x16x2.val[0]);
261
+ *result_high_f16x8 = vreinterpretq_f16_u8(interleaved_u8x16x2.val[1]);
262
+ }
263
+
264
+ /** @brief Convert 8x e5m2 → f16x8 via bit shift (NEON).
265
+ * E5M2 (bias=15) and F16 (bias=15) share the same exponent bias, so conversion is trivial.
266
+ * E5M2: S EEEEE MM → F16: S EEEEE MM 00000000. Works for all: zero, subnormal, normal, inf, nan. */
267
+ NK_INTERNAL float16x8_t nk_e5m2x8_to_f16x8_neon_(uint8x8_t e5m2_u8x8) {
268
+ uint16x8_t e5m2_u16x8 = vmovl_u8(e5m2_u8x8);
269
+ return vreinterpretq_f16_u16(vshlq_n_u16(e5m2_u16x8, 8));
270
+ }
271
+
272
/** @brief Convert 8x e2m3 → f16x8 via direct bit manipulation (NEON).
 *  E2M3FN (FP6): S EE MMM (bias=1) → F16: S EEEEE MMMMMMMMMM (bias=15). E2M3FN has no Inf/NaN.
 *  Codes with a non-zero exponent are re-encoded with integer shifts. Subnormal codes (exp=0) are
 *  evaluated as mant / 8 in f32 and then narrowed to f16. Only the low 6 bits of each input byte are read. */
NK_INTERNAL float16x8_t nk_e2m3x8_to_f16x8_neon_(uint8x8_t e2m3_u8x8) {
    uint16x8_t codes_u16x8 = vmovl_u8(e2m3_u8x8);

    // Split the 6-bit code 0b00SEEMMM into fields; the sign bit goes straight to F16 bit 15.
    uint16x8_t mantissa_u16x8 = vandq_u16(codes_u16x8, vdupq_n_u16(0x07));
    uint16x8_t exponent_u16x8 = vandq_u16(vshrq_n_u16(codes_u16x8, 3), vdupq_n_u16(0x03));
    uint16x8_t sign_bit_u16x8 = vshlq_n_u16(vandq_u16(codes_u16x8, vdupq_n_u16(0x20)), 10);
    uint16x8_t is_subnormal_u16x8 = vceqq_u16(exponent_u16x8, vdupq_n_u16(0));

    // exp != 0: F16 exponent = exp + (15 - 1), mantissa left-aligned into F16 bits 9..7.
    uint16x8_t f16_exponent_u16x8 = vshlq_n_u16(vaddq_u16(exponent_u16x8, vdupq_n_u16(14)), 10);
    uint16x8_t f16_mantissa_u16x8 = vshlq_n_u16(mantissa_u16x8, 7);
    uint16x8_t normal_bits_u16x8 = vorrq_u16(vorrq_u16(f16_exponent_u16x8, f16_mantissa_u16x8), sign_bit_u16x8);

    // exp == 0: value = mant / 8, computed in f32 and narrowed to f16.
    float32x4_t tiny_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mantissa_u16x8))), 0.125f);
    float32x4_t tiny_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(mantissa_u16x8))), 0.125f);
    float16x8_t tiny_f16x8 = vcombine_f16(vcvt_f16_f32(tiny_low_f32x4), vcvt_f16_f32(tiny_high_f32x4));
    uint16x8_t subnormal_bits_u16x8 = vorrq_u16(vreinterpretq_u16_f16(tiny_f16x8), sign_bit_u16x8);

    return vreinterpretq_f16_u16(vbslq_u16(is_subnormal_u16x8, subnormal_bits_u16x8, normal_bits_u16x8));
}
304
+
305
/** @brief Convert 8x e3m2 → f16x8 via direct bit manipulation (NEON).
 *  E3M2FN (FP6): S EEE MM (bias=3) → F16: S EEEEE MMMMMMMMMM (bias=15). E3M2FN has no Inf/NaN.
 *  Codes with a non-zero exponent are re-encoded with integer shifts. Subnormal codes (exp=0) are
 *  evaluated as mant / 16 in f32 and then narrowed to f16. Only the low 6 bits of each input byte are read. */
NK_INTERNAL float16x8_t nk_e3m2x8_to_f16x8_neon_(uint8x8_t e3m2_u8x8) {
    uint16x8_t codes_u16x8 = vmovl_u8(e3m2_u8x8);

    // Split the 6-bit code 0b00SEEEMM into fields; the sign bit goes straight to F16 bit 15.
    uint16x8_t mantissa_u16x8 = vandq_u16(codes_u16x8, vdupq_n_u16(0x03));
    uint16x8_t exponent_u16x8 = vandq_u16(vshrq_n_u16(codes_u16x8, 2), vdupq_n_u16(0x07));
    uint16x8_t sign_bit_u16x8 = vshlq_n_u16(vandq_u16(codes_u16x8, vdupq_n_u16(0x20)), 10);
    uint16x8_t is_subnormal_u16x8 = vceqq_u16(exponent_u16x8, vdupq_n_u16(0));

    // exp != 0: F16 exponent = exp + (15 - 3), mantissa left-aligned into F16 bits 9..8.
    uint16x8_t f16_exponent_u16x8 = vshlq_n_u16(vaddq_u16(exponent_u16x8, vdupq_n_u16(12)), 10);
    uint16x8_t f16_mantissa_u16x8 = vshlq_n_u16(mantissa_u16x8, 8);
    uint16x8_t normal_bits_u16x8 = vorrq_u16(vorrq_u16(f16_exponent_u16x8, f16_mantissa_u16x8), sign_bit_u16x8);

    // exp == 0: value = mant × 2^(1-3) × 2^-2 = mant / 16, computed in f32 and narrowed to f16.
    float32x4_t tiny_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mantissa_u16x8))), 0.0625f);
    float32x4_t tiny_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(mantissa_u16x8))), 0.0625f);
    float16x8_t tiny_f16x8 = vcombine_f16(vcvt_f16_f32(tiny_low_f32x4), vcvt_f16_f32(tiny_high_f32x4));
    uint16x8_t subnormal_bits_u16x8 = vorrq_u16(vreinterpretq_u16_f16(tiny_f16x8), sign_bit_u16x8);

    return vreinterpretq_f16_u16(vbslq_u16(is_subnormal_u16x8, subnormal_bits_u16x8, normal_bits_u16x8));
}
337
+
338
/** @brief Convert 16x e2m3 → 2x f16x8 via TBL lookup (NEON).
 *  E2M3FN (FP6): S EE MMM (bias=1) → F16: S EEEEE MMMMMMMMMM (bias=15). E2M3FN has no Inf/NaN.
 *  The F16 high byte comes from a 32-entry table indexed by the 5-bit magnitude (one VQTBL2 over
 *  two registers), with the sign OR-ed back in. The F16 low byte is computed arithmetically.
 *  VZIP then interleaves low and high bytes into 16-bit lanes. Only the low 6 bits of each input
 *  byte are read.
 *  @param[in]  input_u8x16       16 E2M3 codes, one per byte.
 *  @param[out] result_low_f16x8  Converted elements 0-7.
 *  @param[out] result_high_f16x8 Converted elements 8-15. */
NK_INTERNAL void nk_e2m3x16_to_f16x8x2_neon_(uint8x16_t input_u8x16, float16x8_t *result_low_f16x8,
                                             float16x8_t *result_high_f16x8) {
    // Normal (exp != 0):   f16 = (sign << 15) | ((exp + 14) << 10) | (mant << 7)
    // Subnormal (exp = 0): f16 = encoding of mant / 8; its low byte is always 0x00
    // The 3 mantissa bits occupy F16 bits 9-7. So the high byte holds sign, exponent and mantissa
    // bits 2-1, and the low byte of a normal is just (mant & 1) << 7, i.e. 0x00 or 0x80.
    static nk_u8_t const table_high_u8x32[32] = {
        0x00, 0x30, 0x34, 0x36, 0x38, 0x39, 0x3A, 0x3B, // exp=0 (subnormals): 0, 1/8, ..., 7/8
        0x3C, 0x3C, 0x3D, 0x3D, 0x3E, 0x3E, 0x3F, 0x3F, // exp=1 → f16_exp=15 (0x3C-0x3F)
        0x40, 0x40, 0x41, 0x41, 0x42, 0x42, 0x43, 0x43, // exp=2 → f16_exp=16 (0x40-0x43)
        0x44, 0x44, 0x45, 0x45, 0x46, 0x46, 0x47, 0x47, // exp=3 → f16_exp=17 (0x44-0x47)
    };
    uint8x16x2_t table_high_u8x16x2 = vld1q_u8_x2(table_high_u8x32);

    // Sign moves from bit 5 to bit 7 (the F16 sign inside the high byte); the magnitude is the table index.
    uint8x16_t sign_u8x16 = vshlq_n_u8(vandq_u8(input_u8x16, vdupq_n_u8(0x20)), 2);
    uint8x16_t abs_indices_u8x16 = vandq_u8(input_u8x16, vdupq_n_u8(0x1F));
    uint8x16_t high_bytes_u8x16 = vorrq_u8(vqtbl2q_u8(table_high_u8x16x2, abs_indices_u8x16), sign_u8x16);

    // Low byte: mantissa bit 0 moved to bit 7, kept only where exponent bits 4-3 are non-zero.
    // VTST yields 0xFF exactly where (input & 0x18) != 0, replacing a separate AND + compare.
    uint8x16_t shifted_u8x16 = vshlq_n_u8(input_u8x16, 7);
    uint8x16_t is_normal_u8x16 = vtstq_u8(input_u8x16, vdupq_n_u8(0x18));
    uint8x16_t low_bytes_u8x16 = vandq_u8(shifted_u8x16, is_normal_u8x16);

    // Interleave bytes into u16 lanes: [l0,l1...l15] + [h0,h1...h15] → [l0,h0,l1,h1...]
    uint8x16x2_t interleaved_u8x16x2 = vzipq_u8(low_bytes_u8x16, high_bytes_u8x16);
    *result_low_f16x8 = vreinterpretq_f16_u8(interleaved_u8x16x2.val[0]);  // elements 0-7
    *result_high_f16x8 = vreinterpretq_f16_u8(interleaved_u8x16x2.val[1]); // elements 8-15
}
386
+
387
/** @brief Convert 16x e3m2 → 2x f16x8 via TBL lookup (NEON).
 *  E3M2FN (FP6): S EEE MM (bias=3) → F16: S EEEEE MMMMMMMMMM (bias=15). E3M2FN has no Inf/NaN.
 *  The 2 mantissa bits land in F16 bits 9-8, so the low byte of every result is 0x00. Only the high
 *  byte is looked up, from a 32-entry table indexed by the 5-bit magnitude (one VQTBL2 over two
 *  registers), with the sign OR-ed back in. It is then VZIP-interleaved with a zero vector.
 *  Per 16 elements: one 2-register table load (VLD1Q_X2), a few AND/shift/OR ops, one VQTBL2 and one
 *  VZIP. The lookup and the interleave form the critical path. Only the low 6 bits of each input byte
 *  are read.
 *  @param[in]  input_u8x16       16 E3M2 codes, one per byte.
 *  @param[out] result_low_f16x8  Converted elements 0-7.
 *  @param[out] result_high_f16x8 Converted elements 8-15. */
NK_INTERNAL void nk_e3m2x16_to_f16x8x2_neon_(uint8x16_t input_u8x16, float16x8_t *result_low_f16x8,
                                             float16x8_t *result_high_f16x8) {
    // Normal (exp != 0):   f16 = (sign << 15) | ((exp + 12) << 10) | (mant << 8)
    // Subnormal (exp = 0): f16 = encoding of mant / 16
    // Negative codes differ from positive ones only by the F16 sign bit (0x80 in the high byte).
    static nk_u8_t const table_high_u8x32[32] = {
        0x00, 0x2C, 0x30, 0x32, // exp=0 (subnormals): 0, 1/16, 2/16, 3/16
        0x34, 0x35, 0x36, 0x37, // exp=1 → f16_exp=13 (0x34-0x37)
        0x38, 0x39, 0x3A, 0x3B, // exp=2 → f16_exp=14 (0x38-0x3B)
        0x3C, 0x3D, 0x3E, 0x3F, // exp=3 → f16_exp=15 (0x3C-0x3F)
        0x40, 0x41, 0x42, 0x43, // exp=4 → f16_exp=16 (0x40-0x43)
        0x44, 0x45, 0x46, 0x47, // exp=5 → f16_exp=17 (0x44-0x47)
        0x48, 0x49, 0x4A, 0x4B, // exp=6 → f16_exp=18 (0x48-0x4B)
        0x4C, 0x4D, 0x4E, 0x4F, // exp=7 → f16_exp=19 (0x4C-0x4F)
    };
    uint8x16x2_t table_high_u8x16x2 = vld1q_u8_x2(table_high_u8x32);

    // Sign moves from bit 5 to bit 7 (the F16 sign inside the high byte); the magnitude is the table index.
    uint8x16_t sign_u8x16 = vshlq_n_u8(vandq_u8(input_u8x16, vdupq_n_u8(0x20)), 2);
    uint8x16_t abs_indices_u8x16 = vandq_u8(input_u8x16, vdupq_n_u8(0x1F));
    uint8x16_t high_bytes_u8x16 = vorrq_u8(vqtbl2q_u8(table_high_u8x16x2, abs_indices_u8x16), sign_u8x16);

    // Interleave zero low bytes with the high bytes: [0,0...0] + [h0,h1...h15] → [0,h0,0,h1...]
    uint8x16x2_t interleaved_u8x16x2 = vzipq_u8(vdupq_n_u8(0), high_bytes_u8x16);
    *result_low_f16x8 = vreinterpretq_f16_u8(interleaved_u8x16x2.val[0]);  // elements 0-7
    *result_high_f16x8 = vreinterpretq_f16_u8(interleaved_u8x16x2.val[1]); // elements 8-15
}
434
+
435
/** @brief Convert f16x8 → 8x e4m3 with RNE rounding (NEON).
 *  F16: S EEEEE MMMMMMMMMM (bias=15) → E4M3FN: S EEEE MMM (bias=7).
 *  - Normals: mantissa rounded RNE from 10 to 3 bits; a rounding carry bumps the exponent.
 *  - |x| < 2⁻⁶ (F16 exp < 9): E4M3 subnormal (mant × 2⁻⁹). When rounding reaches 8, the result
 *    becomes the smallest normal 0x08.
 *  - Overflow beyond the E4M3 normal range and F16 ±inf: saturate to ±448 (0x7E), since E4M3FN has no inf.
 *  - F16 NaN: ±NaN (0x7F). ±0 keeps its sign. */
NK_INTERNAL uint8x8_t nk_f16x8_to_e4m3x8_neon_(float16x8_t f16x8) {
    uint16x8_t bits_u16x8 = vreinterpretq_u16_f16(f16x8);
    uint16x8_t sign_byte_u16x8 = vshrq_n_u16(vandq_u16(bits_u16x8, vdupq_n_u16(0x8000)), 8);
    uint16x8_t f16_exp_u16x8 = vandq_u16(vshrq_n_u16(bits_u16x8, 10), vdupq_n_u16(0x1F));
    uint16x8_t f16_mant_u16x8 = vandq_u16(bits_u16x8, vdupq_n_u16(0x03FF));

    // Rebias exponent: F16 bias=15 → E4M3 bias=7, subtract 8
    int16x8_t e4m3_exp_s16x8 = vsubq_s16(vreinterpretq_s16_u16(f16_exp_u16x8), vdupq_n_s16(8));

    // Classify inputs
    uint16x8_t is_f16_zero = vceqq_u16(vandq_u16(bits_u16x8, vdupq_n_u16(0x7FFF)), vdupq_n_u16(0));
    uint16x8_t is_f16_special = vceqq_u16(f16_exp_u16x8, vdupq_n_u16(31)); // inf or nan
    uint16x8_t is_f16_nan = vandq_u16(is_f16_special, vcgtq_u16(f16_mant_u16x8, vdupq_n_u16(0)));
    uint16x8_t is_underflow = vcltq_s16(e4m3_exp_s16x8, vdupq_n_s16(1)); // exp < 1 → subnormal/zero
    uint16x8_t is_overflow = vcgtq_s16(e4m3_exp_s16x8, vdupq_n_s16(15)); // exp > 15 → overflow

    // Normal path with RNE: add (0x3F + lsb), where lsb = mantissa bit 7 (the surviving LSB)
    uint16x8_t lsb_u16x8 = vandq_u16(vshrq_n_u16(f16_mant_u16x8, 7), vdupq_n_u16(1));
    uint16x8_t rounded_mant_u16x8 = vaddq_u16(f16_mant_u16x8, vaddq_u16(vdupq_n_u16(0x3F), lsb_u16x8));
    uint16x8_t carry_u16x8 = vshrq_n_u16(rounded_mant_u16x8, 10); // mantissa overflow → carry into exponent
    e4m3_exp_s16x8 = vaddq_s16(e4m3_exp_s16x8, vreinterpretq_s16_u16(carry_u16x8));
    uint16x8_t e4m3_mant_u16x8 = vandq_u16(vshrq_n_u16(rounded_mant_u16x8, 7), vdupq_n_u16(0x07));
    e4m3_mant_u16x8 = vbicq_u16(e4m3_mant_u16x8, vceqq_u16(carry_u16x8, vdupq_n_u16(1))); // clear mant on carry

    // The carry may have pushed the exponent past the top of the range
    is_overflow = vorrq_u16(is_overflow, vcgtq_s16(e4m3_exp_s16x8, vdupq_n_s16(15)));

    int16x8_t clamped_exp_s16x8 = vminq_s16(vmaxq_s16(e4m3_exp_s16x8, vdupq_n_s16(1)), vdupq_n_s16(15));

    // E4M3FN: exp=15, mant=7 is NaN, so the largest finite mantissa at exp=15 is 6
    uint16x8_t is_max_exp = vceqq_s16(clamped_exp_s16x8, vdupq_n_s16(15));
    e4m3_mant_u16x8 = vbslq_u16(is_max_exp, vminq_u16(e4m3_mant_u16x8, vdupq_n_u16(6)), e4m3_mant_u16x8);

    uint16x8_t normal_result_u16x8 = vorrq_u16(
        sign_byte_u16x8, vorrq_u16(vshlq_n_u16(vreinterpretq_u16_s16(clamped_exp_s16x8), 3), e4m3_mant_u16x8));

    // Subnormal path: E4M3 subnormal = mant × 2⁻⁹, so mant = RNE(|x| × 512), computed in f32.
    // Inputs on this path are below 2⁻⁶, so the rounded value is at most 8. The bit pattern 8 (0x08) is
    // exp=1, mant=0, i.e. the smallest normal 2⁻⁶. Clamping to 8 rather than 7 therefore rounds
    // [7.5×2⁻⁹, 2⁻⁶) up correctly instead of truncating it to the largest subnormal.
    float32x4_t abs_low_f32x4 = vabsq_f32(vcvt_f32_f16(vget_low_f16(f16x8)));
    float32x4_t abs_high_f32x4 = vabsq_f32(vcvt_f32_f16(vget_high_f16(f16x8)));
    int32x4_t subnormal_mantissa_low_i32x4 = vcvtnq_s32_f32(vmulq_n_f32(abs_low_f32x4, 512.0f));
    int32x4_t subnormal_mantissa_high_i32x4 = vcvtnq_s32_f32(vmulq_n_f32(abs_high_f32x4, 512.0f));
    subnormal_mantissa_low_i32x4 = vmaxq_s32(vminq_s32(subnormal_mantissa_low_i32x4, vdupq_n_s32(8)), vdupq_n_s32(0));
    subnormal_mantissa_high_i32x4 = vmaxq_s32(vminq_s32(subnormal_mantissa_high_i32x4, vdupq_n_s32(8)),
                                              vdupq_n_s32(0));
    uint16x8_t subnormal_mant_u16x8 = vreinterpretq_u16_s16(
        vcombine_s16(vmovn_s32(subnormal_mantissa_low_i32x4), vmovn_s32(subnormal_mantissa_high_i32x4)));
    uint16x8_t subnormal_result_u16x8 = vorrq_u16(sign_byte_u16x8, subnormal_mant_u16x8);

    // Special values: E4M3FN has no ∞; max normal = 0x7E (exp=15, mant=6 = 448)
    uint16x8_t e4m3_max = vorrq_u16(sign_byte_u16x8, vdupq_n_u16(0x7E)); // ±448
    uint16x8_t e4m3_nan = vorrq_u16(sign_byte_u16x8, vdupq_n_u16(0x7F)); // ±NaN
    uint16x8_t e4m3_zero = sign_byte_u16x8;                              // ±0

    // Blend results (order matters: later conditions override earlier ones)
    uint16x8_t result_u16x8 = normal_result_u16x8;
    result_u16x8 = vbslq_u16(is_underflow, subnormal_result_u16x8, result_u16x8);
    result_u16x8 = vbslq_u16(is_overflow, e4m3_max, result_u16x8);
    result_u16x8 = vbslq_u16(is_f16_special, e4m3_max, result_u16x8); // F16 inf → E4M3 max
    result_u16x8 = vbslq_u16(is_f16_nan, e4m3_nan, result_u16x8);     // F16 nan → E4M3 nan
    result_u16x8 = vbslq_u16(is_f16_zero, e4m3_zero, result_u16x8);   // preserve ±0

    return vmovn_u16(result_u16x8);
}
509
+
510
/** @brief Convert f16x8 → 8x e5m2 with RNE rounding (NEON).
 *  F16 (bias=15) and E5M2 (bias=15) share the same bias, so conversion is RNE-rounded truncation of
 *  the low 8 bits: S EEEEE MMMMMMMMMM → S EEEEE MM. A mantissa overflow carries into the exponent, so
 *  finite values that round past 57344 become ±inf. F16 ±inf stays ±inf. F16 NaN stays NaN: the quiet
 *  bit is forced on, so a payload held only in the discarded low byte cannot truncate to infinity. */
NK_INTERNAL uint8x8_t nk_f16x8_to_e5m2x8_neon_(float16x8_t f16x8) {
    uint16x8_t bits_u16x8 = vreinterpretq_u16_f16(f16x8);

    // Detect inf/nan (exp=31): these are not rounded
    uint16x8_t exp_u16x8 = vandq_u16(vshrq_n_u16(bits_u16x8, 10), vdupq_n_u16(0x1F));
    uint16x8_t is_special_mask = vceqq_u16(exp_u16x8, vdupq_n_u16(31));
    uint16x8_t has_mantissa_mask = vtstq_u16(bits_u16x8, vdupq_n_u16(0x03FF));
    uint16x8_t is_nan_mask = vandq_u16(is_special_mask, has_mantissa_mask);

    // RNE rounding: add (0x7F + lsb), where lsb = bit 8 (the surviving LSB); may carry into the exponent
    uint16x8_t lsb_u16x8 = vandq_u16(vshrq_n_u16(bits_u16x8, 8), vdupq_n_u16(1));
    uint16x8_t rounding_bias_u16x8 = vaddq_u16(vdupq_n_u16(0x7F), lsb_u16x8);
    uint16x8_t rounded_bits_u16x8 = vaddq_u16(bits_u16x8, rounding_bias_u16x8);

    // Specials pass through unrounded. For NaN, set F16 bit 9 (which becomes E5M2 mantissa bit 1) so
    // that e.g. 0x7C01 maps to 0x7E (NaN) instead of 0x7C (inf).
    uint16x8_t quiet_bit_u16x8 = vandq_u16(is_nan_mask, vdupq_n_u16(0x0200));
    uint16x8_t special_bits_u16x8 = vorrq_u16(bits_u16x8, quiet_bit_u16x8);
    uint16x8_t final_bits_u16x8 = vbslq_u16(is_special_mask, special_bits_u16x8, rounded_bits_u16x8);

    // The high byte is the E5M2 code
    uint16x8_t e5m2_u16x8 = vshrq_n_u16(final_bits_u16x8, 8);
    return vmovn_u16(e5m2_u16x8);
}
533
+
534
/** @brief Convert 4x bf16 → f32x4 via bit shift (NEON).
 *  A BF16 pattern is the upper half of an F32 (same sign, 8-bit exponent with bias 127, top 7 mantissa
 *  bits). Placing each 16-bit pattern in the high half of a 32-bit lane, with zeros below, is therefore
 *  exact for every value. */
NK_INTERNAL float32x4_t nk_bf16x4_to_f32x4_neon_(uint16x4_t bf16_u16x4) {
    uint32x4_t widened_u32x4 = vmovl_u16(bf16_u16x4);
    uint32x4_t f32_bits_u32x4 = vshlq_n_u32(widened_u32x4, 16);
    return vreinterpretq_f32_u32(f32_bits_u32x4);
}
541
+
542
/** @brief Convert 4x f16 (as u16 bits) → f32x4 using integer operations only (NEON).
 *  F16 layout: S EEEEE MMMMMMMMMM (bias=15). Only integer NEON instructions are used, so ARMv8.0 cores
 *  without the FP16 extension can run it. Inputs with exponent field 0 (zeros and denormals) produce ±0;
 *  exponent field 31 keeps its mantissa payload under an all-ones F32 exponent. */
NK_INTERNAL float32x4_t nk_f16x4_to_f32x4_neon_(uint16x4_t half_u16x4) {
    uint32x4_t word_u32x4 = vmovl_u16(half_u16x4);
    uint32x4_t exponent_field_u32x4 = vandq_u32(word_u32x4, vdupq_n_u32(0x7C00));
    uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(word_u32x4, vdupq_n_u32(0x8000)), 16);

    // Finite, non-zero exponent: move the 15-bit magnitude up by 13 and add (127 - 15) << 23 to rebias.
    uint32x4_t magnitude_u32x4 = vandq_u32(word_u32x4, vdupq_n_u32(0x7FFF));
    uint32x4_t finite_u32x4 = vaddq_u32(vshlq_n_u32(magnitude_u32x4, 13), vdupq_n_u32(0x38000000));

    // Exponent field 31: all-ones F32 exponent with the mantissa carried over.
    uint32x4_t mantissa_u32x4 = vandq_u32(word_u32x4, vdupq_n_u32(0x03FF));
    uint32x4_t special_u32x4 = vorrq_u32(vshlq_n_u32(mantissa_u32x4, 13), vdupq_n_u32(0x7F800000));

    uint32x4_t is_special_u32x4 = vceqq_u32(exponent_field_u32x4, vdupq_n_u32(0x7C00));
    uint32x4_t is_zero_exponent_u32x4 = vceqq_u32(exponent_field_u32x4, vdupq_n_u32(0));

    uint32x4_t magnitude_bits_u32x4 = vbslq_u32(is_special_u32x4, special_u32x4, finite_u32x4);
    magnitude_bits_u32x4 = vbicq_u32(magnitude_bits_u32x4, is_zero_exponent_u32x4); // flush exp==0 to zero
    return vreinterpretq_f32_u32(vorrq_u32(magnitude_bits_u32x4, sign_u32x4));
}
567
+
568
/** @brief Convert f32x4 → 4x bf16 with RNE rounding (NEON).
 *  Finite values and ±inf: round-to-nearest-even by adding (0x7FFF + lsb) before truncating to the
 *  upper 16 bits; finite values that round past the BF16 range become ±inf.
 *  NaN: truncated without rounding and with the quiet bit (F32 bit 22 → BF16 bit 6) forced on.
 *  Plain RNE on a NaN would otherwise carry into the exponent (0x7F800001 → 0x7F80, +inf) or wrap into
 *  the sign bit (0x7FFFFFFF → 0x8000, −0). */
NK_INTERNAL uint16x4_t nk_f32x4_to_bf16x4_neon_(float32x4_t f32x4) {
    uint32x4_t bits_u32x4 = vreinterpretq_u32_f32(f32x4);

    uint32x4_t lsb_u32x4 = vandq_u32(vshrq_n_u32(bits_u32x4, 16), vdupq_n_u32(1));
    uint32x4_t rounding_u32x4 = vaddq_u32(vdupq_n_u32(0x7FFF), lsb_u32x4);
    uint32x4_t rounded_u32x4 = vaddq_u32(bits_u32x4, rounding_u32x4);

    // NaN ⇔ |bits| > 0x7F800000 (all-ones exponent, non-zero mantissa); integer test avoids FP exceptions
    uint32x4_t is_nan_u32x4 = vcgtq_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x7FFFFFFF)), vdupq_n_u32(0x7F800000));
    uint32x4_t quiet_nan_u32x4 = vorrq_u32(bits_u32x4, vdupq_n_u32(0x00400000));
    uint32x4_t selected_u32x4 = vbslq_u32(is_nan_u32x4, quiet_nan_u32x4, rounded_u32x4);

    return vmovn_u32(vshrq_n_u32(selected_u32x4, 16));
}
577
+
578
/** @brief Convert 8x e4m3 → bf16x8 via direct bit manipulation (NEON).
 *  E4M3FN: S EEEE MMM (bias=7) → BF16: S EEEEEEEE MMMMMMM (bias=127); the BF16 bits are returned as u16.
 *  Normal codes are re-encoded directly, with no F16 or F32 intermediate, to keep hot loops cheap.
 *  Subnormal codes (exp=0) go through f32 as mant × 2⁻⁹. The NaN codes 0x7F/0xFF (exp=15, mant=7)
 *  map to a signed BF16 quiet NaN (0x7FC0). */
NK_INTERNAL uint16x8_t nk_e4m3x8_to_bf16x8_neon_(uint8x8_t e4m3_u8x8) {
    uint16x8_t codes_u16x8 = vmovl_u8(e4m3_u8x8);
    uint16x8_t sign_bit_u16x8 = vshlq_n_u16(vandq_u16(codes_u16x8, vdupq_n_u16(0x80)), 8);
    uint16x8_t exponent_u16x8 = vandq_u16(vshrq_n_u16(codes_u16x8, 3), vdupq_n_u16(0x0F));
    uint16x8_t mantissa_u16x8 = vandq_u16(codes_u16x8, vdupq_n_u16(0x07));

    uint16x8_t is_subnormal_u16x8 = vceqq_u16(exponent_u16x8, vdupq_n_u16(0));
    uint16x8_t is_nan_u16x8 = vandq_u16(vceqq_u16(exponent_u16x8, vdupq_n_u16(15)),
                                        vceqq_u16(mantissa_u16x8, vdupq_n_u16(7)));

    // exp != 0: rebias by +120 (127 - 7), left-align the mantissa into BF16 bits 6..4.
    uint16x8_t rebiased_u16x8 = vshlq_n_u16(vaddq_u16(exponent_u16x8, vdupq_n_u16(120)), 7);
    uint16x8_t normal_u16x8 = vorrq_u16(vorrq_u16(rebiased_u16x8, vshlq_n_u16(mantissa_u16x8, 4)), sign_bit_u16x8);

    // exp == 0: value = mant / 512, computed in f32 and narrowed to BF16.
    uint32x4_t mantissa_low_u32x4 = vmovl_u16(vget_low_u16(mantissa_u16x8));
    uint32x4_t mantissa_high_u32x4 = vmovl_u16(vget_high_u16(mantissa_u16x8));
    float32x4_t tiny_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mantissa_low_u32x4), 1.0f / 512.0f);
    float32x4_t tiny_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mantissa_high_u32x4), 1.0f / 512.0f);
    uint16x8_t tiny_bits_u16x8 = vcombine_u16(nk_f32x4_to_bf16x4_neon_(tiny_low_f32x4),
                                              nk_f32x4_to_bf16x4_neon_(tiny_high_f32x4));
    uint16x8_t subnormal_u16x8 = vorrq_u16(tiny_bits_u16x8, sign_bit_u16x8);

    uint16x8_t quiet_nan_u16x8 = vorrq_u16(sign_bit_u16x8, vdupq_n_u16(0x7FC0));

    uint16x8_t merged_u16x8 = vbslq_u16(is_subnormal_u16x8, subnormal_u16x8, normal_u16x8);
    return vbslq_u16(is_nan_u16x8, quiet_nan_u16x8, merged_u16x8);
}
610
+
611
/** @brief Convert 8x e5m2 → bf16x8 via direct bit manipulation (NEON).
 *  E5M2: S EEEEE MM (bias=15) → BF16: S EEEEEEEE MMMMMMM (bias=127); the BF16 bits are returned as u16.
 *  Normal codes are re-encoded directly, with no F16 or F32 intermediate. Subnormal codes (exp=0) go
 *  through f32 as mant × 2⁻¹⁶. exp=31 maps to ±inf (mant=0) or a signed BF16 quiet NaN (mant≠0). */
NK_INTERNAL uint16x8_t nk_e5m2x8_to_bf16x8_neon_(uint8x8_t e5m2_u8x8) {
    uint16x8_t codes_u16x8 = vmovl_u8(e5m2_u8x8);
    uint16x8_t sign_bit_u16x8 = vshlq_n_u16(vandq_u16(codes_u16x8, vdupq_n_u16(0x80)), 8);
    uint16x8_t exponent_u16x8 = vandq_u16(vshrq_n_u16(codes_u16x8, 2), vdupq_n_u16(0x1F));
    uint16x8_t mantissa_u16x8 = vandq_u16(codes_u16x8, vdupq_n_u16(0x03));

    uint16x8_t is_subnormal_u16x8 = vceqq_u16(exponent_u16x8, vdupq_n_u16(0));
    uint16x8_t is_special_u16x8 = vceqq_u16(exponent_u16x8, vdupq_n_u16(31));

    // exp in 1..30: rebias by +112 (127 - 15), left-align the mantissa into BF16 bits 6..5.
    uint16x8_t rebiased_u16x8 = vshlq_n_u16(vaddq_u16(exponent_u16x8, vdupq_n_u16(112)), 7);
    uint16x8_t normal_u16x8 = vorrq_u16(vorrq_u16(rebiased_u16x8, vshlq_n_u16(mantissa_u16x8, 5)), sign_bit_u16x8);

    // exp == 0: value = mant / 65536, computed in f32 and narrowed to BF16.
    uint32x4_t mantissa_low_u32x4 = vmovl_u16(vget_low_u16(mantissa_u16x8));
    uint32x4_t mantissa_high_u32x4 = vmovl_u16(vget_high_u16(mantissa_u16x8));
    float32x4_t tiny_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mantissa_low_u32x4), 1.0f / 65536.0f);
    float32x4_t tiny_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mantissa_high_u32x4), 1.0f / 65536.0f);
    uint16x8_t tiny_bits_u16x8 = vcombine_u16(nk_f32x4_to_bf16x4_neon_(tiny_low_f32x4),
                                              nk_f32x4_to_bf16x4_neon_(tiny_high_f32x4));
    uint16x8_t subnormal_u16x8 = vorrq_u16(tiny_bits_u16x8, sign_bit_u16x8);

    // exp == 31: all-ones BF16 exponent (0x7F80), plus the quiet bit 0x40 whenever the mantissa is non-zero.
    uint16x8_t quiet_bit_u16x8 = vandq_u16(vtstq_u16(mantissa_u16x8, mantissa_u16x8), vdupq_n_u16(0x0040));
    uint16x8_t special_u16x8 = vorrq_u16(vorrq_u16(sign_bit_u16x8, vdupq_n_u16(0x7F80)), quiet_bit_u16x8);

    uint16x8_t merged_u16x8 = vbslq_u16(is_subnormal_u16x8, subnormal_u16x8, normal_u16x8);
    return vbslq_u16(is_special_u16x8, special_u16x8, merged_u16x8);
}
647
+
648
/** @brief Convert 4x i16 → f32x4 (NEON): sign-extend each lane to i32, then convert to f32. */
NK_INTERNAL float32x4_t nk_i16x4_to_f32x4_neon_(int16x4_t i16x4) {
    int32x4_t widened_i32x4 = vmovl_s16(i16x4);
    return vcvtq_f32_s32(widened_i32x4);
}
650
+
651
/** @brief Convert 4x u16 → f32x4 (NEON): zero-extend each lane to u32, then convert to f32. */
NK_INTERNAL float32x4_t nk_u16x4_to_f32x4_neon_(uint16x4_t u16x4) {
    uint32x4_t widened_u32x4 = vmovl_u16(u16x4);
    return vcvtq_f32_u32(widened_u32x4);
}
653
+
654
/** @brief Convert 4x i8 → f32x4 (NEON).
 *  The 4 bytes arrive by value in nk_b32_vec_t, so no memory past them is ever read. They are placed in
 *  the low 32 bits of a 64-bit vector; only those 4 lanes are sign-extended (i8 → i16 → i32) and converted. */
NK_INTERNAL float32x4_t nk_i8x4_to_f32x4_neon_(nk_b32_vec_t in_vec) {
    nk_u64_t packed_bytes = (nk_u64_t)in_vec.u32; // upper 4 bytes are zero and never used
    int16x4_t first_four_i16x4 = vget_low_s16(vmovl_s8(vcreate_s8(packed_bytes)));
    return vcvtq_f32_s32(vmovl_s16(first_four_i16x4));
}
660
+
661
/** @brief Convert 4x u8 → f32x4 (NEON).
 *  The 4 bytes arrive by value in nk_b32_vec_t, so no memory past them is ever read. They are placed in
 *  the low 32 bits of a 64-bit vector; only those 4 lanes are zero-extended (u8 → u16 → u32) and converted. */
NK_INTERNAL float32x4_t nk_u8x4_to_f32x4_neon_(nk_b32_vec_t in_vec) {
    nk_u64_t packed_bytes = (nk_u64_t)in_vec.u32; // upper 4 bytes are zero and never used
    uint16x4_t first_four_u16x4 = vget_low_u16(vmovl_u8(vcreate_u8(packed_bytes)));
    return vcvtq_f32_u32(vmovl_u16(first_four_u16x4));
}
667
+
668
/** @brief Convert f32x4 → 4x i16 with saturation (NEON).
 *  Rounds to the nearest i32 (ties to even), then narrows with signed saturation to the i16 range. */
NK_INTERNAL int16x4_t nk_f32x4_to_i16x4_neon_(float32x4_t f32x4) {
    return vqmovn_s32(vcvtnq_s32_f32(f32x4));
}
673
+
674
/** @brief Convert f32x4 → 4x u16 with saturation (NEON).
 *  Rounds to the nearest u32 (ties to even), then narrows with unsigned saturation to the u16 range. */
NK_INTERNAL uint16x4_t nk_f32x4_to_u16x4_neon_(float32x4_t f32x4) {
    return vqmovn_u32(vcvtnq_u32_f32(f32x4));
}
679
+
680
/** @brief Convert f32x4 → 4x i8 with saturation (NEON), written to dst[0..3].
 *  Rounds to the nearest i32 (ties to even), saturates i32 → i16 → i8, then stores exactly 4 bytes. */
NK_INTERNAL void nk_f32x4_to_i8x4_neon_(float32x4_t f32x4, nk_i8_t *dst) {
    int16x4_t halves_i16x4 = vqmovn_s32(vcvtnq_s32_f32(f32x4));
    // vqmovn_s16 narrows 8 lanes; duplicate the 4 valid ones and keep only bytes 0-3.
    int8x8_t bytes_i8x8 = vqmovn_s16(vcombine_s16(halves_i16x4, halves_i16x4));
    // One 32-bit lane store writes the 4 result bytes and nothing else.
    vst1_lane_s32((int32_t *)dst, vreinterpret_s32_s8(bytes_i8x8), 0);
}
688
+
689
/** @brief Convert f32x4 → 4x u8 with saturation (NEON), written to dst[0..3].
 *  Rounds to the nearest u32 (ties to even), saturates u32 → u16 → u8, then stores exactly 4 bytes. */
NK_INTERNAL void nk_f32x4_to_u8x4_neon_(float32x4_t f32x4, nk_u8_t *dst) {
    uint16x4_t halves_u16x4 = vqmovn_u32(vcvtnq_u32_f32(f32x4));
    // vqmovn_u16 narrows 8 lanes; duplicate the 4 valid ones and keep only bytes 0-3.
    uint8x8_t bytes_u8x8 = vqmovn_u16(vcombine_u16(halves_u16x4, halves_u16x4));
    // One 32-bit lane store writes the 4 result bytes and nothing else.
    vst1_lane_u32((uint32_t *)dst, vreinterpret_u32_u8(bytes_u8x8), 0);
}
697
+
698
/** @brief Convert f32x4 → 4x e4m3 via bit manipulation (NEON).
 *  E4M3FN format: S EEEE MMM (bias=7). The mantissa uses RNE (round to nearest, ties to even).
 *  - Normals: rounded to 3 mantissa bits; a rounding carry bumps the exponent.
 *  - Rounded exponent < 1: E4M3 subnormal (mant × 2⁻⁹). When rounding reaches 8, the result is
 *    promoted to the smallest normal 0x08.
 *  - Finite overflow and ±inf: saturate to ±448 (0x7E), since E4M3FN has no infinity.
 *  - NaN: ±NaN (0x7F / 0xFF), matching the f16 → e4m3 path. Without the explicit check, NaN would
 *    take the overflow path and become ±448.
 *  Returns the 4 packed bytes in nk_b32_vec_t. */
NK_INTERNAL nk_b32_vec_t nk_f32x4_to_e4m3x4_neon_(float32x4_t f32x4) {
    uint32x4_t bits_u32x4 = vreinterpretq_u32_f32(f32x4);
    uint32x4_t sign_u32x4 = vshrq_n_u32(bits_u32x4, 31);
    uint32x4_t f32_exp_u32x4 = vandq_u32(vshrq_n_u32(bits_u32x4, 23), vdupq_n_u32(0xFF));

    // Round the significand from 23 to 3 mantissa bits with RNE: add (half - 1 + lsb), where lsb is
    // the bit that becomes the new LSB after the shift.
    uint32x4_t significand_u32x4 = vorrq_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x007FFFFF)),
                                             vdupq_n_u32(0x00800000)); // add implicit 1 bit
    uint32x4_t lsb_u32x4 = vandq_u32(vshrq_n_u32(significand_u32x4, 20), vdupq_n_u32(1));
    uint32x4_t rounding_bias_u32x4 = vaddq_u32(vdupq_n_u32(0x0007FFFF), lsb_u32x4);
    uint32x4_t rounded_sig_u32x4 = vaddq_u32(significand_u32x4, rounding_bias_u32x4);
    uint32x4_t carry_u32x4 = vshrq_n_u32(rounded_sig_u32x4, 24); // carry into exponent if bit 24 set
    uint32x4_t f32_mantissa_u32x4 = vandq_u32(vshrq_n_u32(rounded_sig_u32x4, 20), vdupq_n_u32(0x07));
    // On carry the value rounded up to the next power of two, so the mantissa becomes 0
    uint32x4_t carry_mask_u32x4 = vceqq_u32(carry_u32x4, vdupq_n_u32(1));
    f32_mantissa_u32x4 = vbicq_u32(f32_mantissa_u32x4, carry_mask_u32x4);

    // Rebias exponent: f32 bias 127 → e4m3 bias 7 (subtract 120)
    int32x4_t e4m3_exp_i32x4 = vsubq_s32(
        vaddq_s32(vreinterpretq_s32_u32(f32_exp_u32x4), vreinterpretq_s32_u32(carry_u32x4)), vdupq_n_s32(120));

    // Underflow (exp <= 0 → subnormal/zero) and overflow (exp > 15)
    uint32x4_t is_subnormal_u32x4 = vcltq_s32(e4m3_exp_i32x4, vdupq_n_s32(1));
    uint32x4_t overflow_u32x4 = vcgtq_s32(e4m3_exp_i32x4, vdupq_n_s32(15));

    // Normal path: clamp exp to [1,15]. E4M3FN: exp=15 with mant=7 is NaN, so cap the mantissa at 6 there.
    int32x4_t clamped_exp_i32x4 = vminq_s32(vmaxq_s32(e4m3_exp_i32x4, vdupq_n_s32(1)), vdupq_n_s32(15));
    uint32x4_t is_max_exp_u32x4 = vceqq_s32(clamped_exp_i32x4, vdupq_n_s32(15));
    uint32x4_t max_mantissa_u32x4 = vbslq_u32(is_max_exp_u32x4, vdupq_n_u32(6), vdupq_n_u32(7));
    uint32x4_t normal_mantissa_u32x4 = vminq_u32(f32_mantissa_u32x4, max_mantissa_u32x4);
    normal_mantissa_u32x4 = vbslq_u32(overflow_u32x4, vdupq_n_u32(0x06), normal_mantissa_u32x4);
    uint32x4_t normal_e4m3_u32x4 = vorrq_u32(
        vshlq_n_u32(sign_u32x4, 7),
        vorrq_u32(vshlq_n_u32(vreinterpretq_u32_s32(clamped_exp_i32x4), 3), normal_mantissa_u32x4));

    // Subnormal path: mantissa = RNE(|x| × 512); a result of 8 or more promotes to the first normal (0x08)
    float32x4_t scaled_f32x4 = vmulq_n_f32(vabsq_f32(f32x4), 512.0f);
    int32x4_t subnormal_mantissa_i32x4 = vcvtnq_s32_f32(scaled_f32x4);
    uint32x4_t promotes_to_normal_u32x4 = vcgtq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(7));
    subnormal_mantissa_i32x4 = vmaxq_s32(vminq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(7)), vdupq_n_s32(0));
    uint32x4_t subnormal_e4m3_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 7),
                                                vreinterpretq_u32_s32(subnormal_mantissa_i32x4));
    uint32x4_t first_normal_e4m3_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 7), vdupq_n_u32(0x08));
    subnormal_e4m3_u32x4 = vbslq_u32(promotes_to_normal_u32x4, first_normal_e4m3_u32x4, subnormal_e4m3_u32x4);

    uint32x4_t e4m3_u32x4 = vbslq_u32(is_subnormal_u32x4, subnormal_e4m3_u32x4, normal_e4m3_u32x4);

    // NaN ⇔ |bits| > 0x7F800000; map to E4M3FN NaN (exp=15, mant=7) keeping the sign
    uint32x4_t is_nan_u32x4 = vcgtq_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x7FFFFFFF)), vdupq_n_u32(0x7F800000));
    uint32x4_t nan_e4m3_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 7), vdupq_n_u32(0x7F));
    e4m3_u32x4 = vbslq_u32(is_nan_u32x4, nan_e4m3_u32x4, e4m3_u32x4);

    // Pack 4 u32 lanes into 4 bytes
    uint16x4_t e4m3_u16x4 = vmovn_u32(e4m3_u32x4);
    uint8x8_t e4m3_u8x8 = vmovn_u16(vcombine_u16(e4m3_u16x4, e4m3_u16x4));
    nk_b32_vec_t result;
    result.u32 = vget_lane_u32(vreinterpret_u32_u8(e4m3_u8x8), 0);
    return result;
}
763
+
764
/** @brief Convert f32x4 → 4x e5m2 via bit manipulation (NEON).
 *  E5M2 format: S EEEEE MM (bias=15). The mantissa uses RNE (round to nearest, ties to even).
 *  - Normals: exponent field 1..30; a rounding carry bumps the exponent.
 *  - Rounded exponent < 1: E5M2 subnormal (mant × 2⁻¹⁶). When rounding reaches 4, the result is
 *    promoted to the smallest normal 0x04.
 *  - Rounded exponent >= 31 and ±inf: ±inf (0x7C / 0xFC). Exponent field 31 is reserved for inf/NaN,
 *    so 31 itself already overflows; otherwise finite inputs in [65536, 131072) would be emitted as
 *    NaN/inf bit patterns.
 *  - NaN: quiet ±NaN (0x7E / 0xFE).
 *  Returns the 4 packed bytes in nk_b32_vec_t. */
NK_INTERNAL nk_b32_vec_t nk_f32x4_to_e5m2x4_neon_(float32x4_t f32x4) {
    uint32x4_t bits_u32x4 = vreinterpretq_u32_f32(f32x4);
    uint32x4_t sign_u32x4 = vshrq_n_u32(bits_u32x4, 31);
    uint32x4_t f32_exp_u32x4 = vandq_u32(vshrq_n_u32(bits_u32x4, 23), vdupq_n_u32(0xFF));

    // Round the significand from 23 to 2 mantissa bits with RNE: add (half - 1 + lsb), where lsb is
    // the bit that becomes the new LSB after the shift.
    uint32x4_t significand_u32x4 = vorrq_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x007FFFFF)),
                                             vdupq_n_u32(0x00800000)); // add implicit 1 bit
    uint32x4_t lsb_u32x4 = vandq_u32(vshrq_n_u32(significand_u32x4, 21), vdupq_n_u32(1));
    uint32x4_t rounding_bias_u32x4 = vaddq_u32(vdupq_n_u32(0x000FFFFF), lsb_u32x4); // half = 0x100000
    uint32x4_t rounded_sig_u32x4 = vaddq_u32(significand_u32x4, rounding_bias_u32x4);
    uint32x4_t carry_u32x4 = vshrq_n_u32(rounded_sig_u32x4, 24); // carry into exponent if bit 24 set
    uint32x4_t f32_mantissa_u32x4 = vandq_u32(vshrq_n_u32(rounded_sig_u32x4, 21), vdupq_n_u32(0x03));
    // On carry the value rounded up to the next power of two, so the mantissa becomes 0
    uint32x4_t carry_mask_u32x4 = vceqq_u32(carry_u32x4, vdupq_n_u32(1));
    f32_mantissa_u32x4 = vbicq_u32(f32_mantissa_u32x4, carry_mask_u32x4);

    // Rebias exponent: f32 bias 127 → e5m2 bias 15 (subtract 112)
    int32x4_t e5m2_exp_i32x4 = vsubq_s32(
        vaddq_s32(vreinterpretq_s32_u32(f32_exp_u32x4), vreinterpretq_s32_u32(carry_u32x4)), vdupq_n_s32(112));

    // Underflow (exp <= 0) and overflow (exp >= 31: field 31 is reserved for inf/NaN)
    uint32x4_t is_subnormal_u32x4 = vcltq_s32(e5m2_exp_i32x4, vdupq_n_s32(1));
    uint32x4_t overflow_u32x4 = vcgtq_s32(e5m2_exp_i32x4, vdupq_n_s32(30));

    // Normal path: clamp exp to [1,31]. Only overflowing lanes reach 31, and they get mantissa 0,
    // which encodes ±inf (0x7C).
    int32x4_t clamped_exp_i32x4 = vminq_s32(vmaxq_s32(e5m2_exp_i32x4, vdupq_n_s32(1)), vdupq_n_s32(31));
    uint32x4_t normal_mantissa_u32x4 = vbslq_u32(overflow_u32x4, vdupq_n_u32(0), f32_mantissa_u32x4);
    uint32x4_t normal_e5m2_u32x4 = vorrq_u32(
        vshlq_n_u32(sign_u32x4, 7),
        vorrq_u32(vshlq_n_u32(vreinterpretq_u32_s32(clamped_exp_i32x4), 2), normal_mantissa_u32x4));

    // Subnormal path: mantissa = RNE(|x| × 65536); a result of 4 or more promotes to the first normal (0x04)
    float32x4_t scaled_f32x4 = vmulq_n_f32(vabsq_f32(f32x4), 65536.0f);
    int32x4_t subnormal_mantissa_i32x4 = vcvtnq_s32_f32(scaled_f32x4);
    uint32x4_t promotes_to_normal_u32x4 = vcgtq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(3));
    subnormal_mantissa_i32x4 = vmaxq_s32(vminq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(3)), vdupq_n_s32(0));
    uint32x4_t subnormal_e5m2_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 7),
                                                vreinterpretq_u32_s32(subnormal_mantissa_i32x4));
    uint32x4_t first_normal_e5m2_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 7), vdupq_n_u32(0x04));
    subnormal_e5m2_u32x4 = vbslq_u32(promotes_to_normal_u32x4, first_normal_e5m2_u32x4, subnormal_e5m2_u32x4);

    uint32x4_t e5m2_u32x4 = vbslq_u32(is_subnormal_u32x4, subnormal_e5m2_u32x4, normal_e5m2_u32x4);

    // NaN ⇔ |bits| > 0x7F800000; map to quiet E5M2 NaN (exp=31, mant=2) keeping the sign
    uint32x4_t is_nan_u32x4 = vcgtq_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x7FFFFFFF)), vdupq_n_u32(0x7F800000));
    uint32x4_t nan_e5m2_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 7), vdupq_n_u32(0x7E));
    e5m2_u32x4 = vbslq_u32(is_nan_u32x4, nan_e5m2_u32x4, e5m2_u32x4);

    // Pack 4 u32 lanes into 4 bytes
    uint16x4_t e5m2_u16x4 = vmovn_u32(e5m2_u32x4);
    uint8x8_t e5m2_u8x8 = vmovn_u16(vcombine_u16(e5m2_u16x4, e5m2_u16x4));
    nk_b32_vec_t result;
    result.u32 = vget_lane_u32(vreinterpret_u32_u8(e5m2_u8x8), 0);
    return result;
}
825
+
826
+ /** @brief Convert 4x e2m3 → f32x4 via bit manipulation (NEON).
827
+ * E2M3 format: S EE MMM (bias=1). F32: sign<<31, (exp+126)<<23, mantissa<<20.
828
+ * Handles subnormals (exp=0, mant ≠ 0). */
829
+ NK_INTERNAL float32x4_t nk_e2m3x4_to_f32x4_neon_(nk_b32_vec_t src) {
830
+ uint8x8_t e2m3_u8x8 = vcreate_u8(src.u32);
831
+ uint16x8_t e2m3_u16x8 = vmovl_u8(e2m3_u8x8);
832
+ uint32x4_t e2m3_u32x4 = vmovl_u16(vget_low_u16(e2m3_u16x8));
833
+ uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e2m3_u32x4, vdupq_n_u32(0x20)), 26);
834
+ uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e2m3_u32x4, 3), vdupq_n_u32(0x03));
835
+ uint32x4_t mant_u32x4 = vandq_u32(e2m3_u32x4, vdupq_n_u32(0x07));
836
+
837
+ // Normal path: f32 = sign | ((exp+126)<<23) | (mant<<20)
838
+ uint32x4_t f32_exp_u32x4 = vshlq_n_u32(vaddq_u32(exp_u32x4, vdupq_n_u32(126)), 23);
839
+ uint32x4_t f32_mant_u32x4 = vshlq_n_u32(mant_u32x4, 20);
840
+ uint32x4_t normal_u32x4 = vorrq_u32(sign_u32x4, vorrq_u32(f32_exp_u32x4, f32_mant_u32x4));
841
+
842
+ // Subnormal path (exp=0, mant ≠ 0): value = ±mantissa × 2⁻³
843
+ float32x4_t subnormal_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mant_u32x4), 1.0f / 8.0f);
844
+ uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
845
+
846
+ // Blend paths: subnormal when exp=0, else normal
847
+ uint32x4_t exp_zero_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
848
+ uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
849
+ return vreinterpretq_f32_u32(result_u32x4);
850
+ }
851
+
852
+ /** @brief Convert 4x e3m2 → f32x4 via bit manipulation (NEON).
853
+ * E3M2 format: S EEE MM (bias=3). F32: sign<<31, (exp+124)<<23, mantissa<<21.
854
+ * Handles subnormals (exp=0, mant ≠ 0). */
855
+ NK_INTERNAL float32x4_t nk_e3m2x4_to_f32x4_neon_(nk_b32_vec_t src) {
856
+ uint8x8_t e3m2_u8x8 = vcreate_u8(src.u32);
857
+ uint16x8_t e3m2_u16x8 = vmovl_u8(e3m2_u8x8);
858
+ uint32x4_t e3m2_u32x4 = vmovl_u16(vget_low_u16(e3m2_u16x8));
859
+ uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e3m2_u32x4, vdupq_n_u32(0x20)), 26);
860
+ uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e3m2_u32x4, 2), vdupq_n_u32(0x07));
861
+ uint32x4_t mant_u32x4 = vandq_u32(e3m2_u32x4, vdupq_n_u32(0x03));
862
+
863
+ // Normal path: f32 = sign | ((exp+124)<<23) | (mant<<21)
864
+ uint32x4_t f32_exp_u32x4 = vshlq_n_u32(vaddq_u32(exp_u32x4, vdupq_n_u32(124)), 23);
865
+ uint32x4_t f32_mant_u32x4 = vshlq_n_u32(mant_u32x4, 21);
866
+ uint32x4_t normal_u32x4 = vorrq_u32(sign_u32x4, vorrq_u32(f32_exp_u32x4, f32_mant_u32x4));
867
+
868
+ // Subnormal path (exp=0, mant ≠ 0): value = ±mantissa × 2⁻⁴
869
+ float32x4_t subnormal_f32x4 = vmulq_n_f32(vcvtq_f32_u32(mant_u32x4), 1.0f / 16.0f);
870
+ uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
871
+
872
+ // Blend paths: subnormal when exp=0, else normal
873
+ uint32x4_t exp_zero_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
874
+ uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
875
+ return vreinterpretq_f32_u32(result_u32x4);
876
+ }
877
+
878
+ /** @brief Convert f32x4 → 4x e2m3 via bit manipulation (NEON).
879
+ * E2M3 format: S EE MMM (bias=1). Handles normal, subnormal, and overflow cases.
880
+ * Uses RNE (round to nearest even) for mantissa rounding. Returns packed result in nk_b32_vec_t. */
881
+ NK_INTERNAL nk_b32_vec_t nk_f32x4_to_e2m3x4_neon_(float32x4_t f32x4) {
882
+ uint32x4_t bits_u32x4 = vreinterpretq_u32_f32(f32x4);
883
+ uint32x4_t sign_u32x4 = vshrq_n_u32(bits_u32x4, 31);
884
+ uint32x4_t f32_exp_u32x4 = vandq_u32(vshrq_n_u32(bits_u32x4, 23), vdupq_n_u32(0xFF));
885
+
886
+ // Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
887
+ uint32x4_t significand_u32x4 = vorrq_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x007FFFFF)),
888
+ vdupq_n_u32(0x00800000)); // Add implicit 1 bit
889
+ uint32x4_t lsb_u32x4 = vandq_u32(vshrq_n_u32(significand_u32x4, 20), vdupq_n_u32(1));
890
+ uint32x4_t rounding_bias_u32x4 = vaddq_u32(vdupq_n_u32(0x0007FFFF), lsb_u32x4);
891
+ uint32x4_t rounded_sig_u32x4 = vaddq_u32(significand_u32x4, rounding_bias_u32x4);
892
+ uint32x4_t carry_u32x4 = vshrq_n_u32(rounded_sig_u32x4, 24); // Carry into exponent if bit 24 set
893
+ uint32x4_t f32_mantissa_u32x4 = vandq_u32(vshrq_n_u32(rounded_sig_u32x4, 20), vdupq_n_u32(0x07));
894
+ // If carry, mantissa becomes 0 (we rounded up to next power of 2)
895
+ uint32x4_t carry_mask_u32x4 = vceqq_u32(carry_u32x4, vdupq_n_u32(1));
896
+ f32_mantissa_u32x4 = vbicq_u32(f32_mantissa_u32x4, carry_mask_u32x4);
897
+
898
+ // Rebias exponent: f32 bias 127 → e2m3 bias 1 (subtract 126)
899
+ int32x4_t e2m3_exp_i32x4 = vsubq_s32(
900
+ vaddq_s32(vreinterpretq_s32_u32(f32_exp_u32x4), vreinterpretq_s32_u32(carry_u32x4)), vdupq_n_s32(126));
901
+
902
+ // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 3)
903
+ uint32x4_t is_subnormal_u32x4 = vcltq_s32(e2m3_exp_i32x4, vdupq_n_s32(1));
904
+ uint32x4_t overflow_u32x4 = vcgtq_s32(e2m3_exp_i32x4, vdupq_n_s32(3));
905
+
906
+ // Normal path: clamp exp to [1,3], extract mantissa bits
907
+ int32x4_t clamped_exp_i32x4 = vmaxq_s32(e2m3_exp_i32x4, vdupq_n_s32(1));
908
+ clamped_exp_i32x4 = vminq_s32(clamped_exp_i32x4, vdupq_n_s32(3));
909
+ uint32x4_t normal_mantissa_u32x4 = vbslq_u32(overflow_u32x4, vdupq_n_u32(0x07), f32_mantissa_u32x4);
910
+ uint32x4_t normal_e2m3_u32x4 = vorrq_u32(
911
+ vshlq_n_u32(sign_u32x4, 5),
912
+ vorrq_u32(vshlq_n_u32(vreinterpretq_u32_s32(clamped_exp_i32x4), 3), normal_mantissa_u32x4));
913
+
914
+ // Subnormal path: mantissa = round(abs_f32 * 8)
915
+ // If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
916
+ float32x4_t abs_f32x4 = vabsq_f32(f32x4);
917
+ float32x4_t scaled_f32x4 = vmulq_n_f32(abs_f32x4, 8.0f);
918
+ int32x4_t subnormal_mantissa_i32x4 = vcvtnq_s32_f32(scaled_f32x4);
919
+ uint32x4_t promotes_to_normal_u32x4 = vcgtq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(7));
920
+ subnormal_mantissa_i32x4 = vminq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(7));
921
+ subnormal_mantissa_i32x4 = vmaxq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(0));
922
+ uint32x4_t subnormal_e2m3_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 5),
923
+ vreinterpretq_u32_s32(subnormal_mantissa_i32x4));
924
+ // When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
925
+ uint32x4_t first_normal_e2m3_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 5), vdupq_n_u32(0x08));
926
+ subnormal_e2m3_u32x4 = vbslq_u32(promotes_to_normal_u32x4, first_normal_e2m3_u32x4, subnormal_e2m3_u32x4);
927
+
928
+ // Blend: use subnormal result when exp <= 0, else normal
929
+ uint32x4_t e2m3_u32x4 = vbslq_u32(is_subnormal_u32x4, subnormal_e2m3_u32x4, normal_e2m3_u32x4);
930
+
931
+ // Pack 4 u32s to 4 u8s
932
+ uint16x4_t e2m3_u16x4 = vmovn_u32(e2m3_u32x4);
933
+ uint8x8_t e2m3_u8x8 = vmovn_u16(vcombine_u16(e2m3_u16x4, e2m3_u16x4));
934
+ nk_b32_vec_t result;
935
+ result.u32 = vget_lane_u32(vreinterpret_u32_u8(e2m3_u8x8), 0);
936
+ return result;
937
+ }
938
+
939
+ /** @brief Convert f32x4 → 4x e3m2 via bit manipulation (NEON).
940
+ * E3M2 format: S EEE MM (bias=3). Handles normal, subnormal, and overflow cases.
941
+ * Uses RNE (round to nearest even) for mantissa rounding. Returns packed result in nk_b32_vec_t. */
942
+ NK_INTERNAL nk_b32_vec_t nk_f32x4_to_e3m2x4_neon_(float32x4_t f32x4) {
943
+ uint32x4_t bits_u32x4 = vreinterpretq_u32_f32(f32x4);
944
+ uint32x4_t sign_u32x4 = vshrq_n_u32(bits_u32x4, 31);
945
+ uint32x4_t f32_exp_u32x4 = vandq_u32(vshrq_n_u32(bits_u32x4, 23), vdupq_n_u32(0xFF));
946
+
947
+ // Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
948
+ uint32x4_t significand_u32x4 = vorrq_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x007FFFFF)),
949
+ vdupq_n_u32(0x00800000)); // Add implicit 1 bit
950
+ uint32x4_t lsb_u32x4 = vandq_u32(vshrq_n_u32(significand_u32x4, 21), vdupq_n_u32(1));
951
+ uint32x4_t rounding_bias_u32x4 = vaddq_u32(vdupq_n_u32(0x000FFFFF), lsb_u32x4);
952
+ uint32x4_t rounded_sig_u32x4 = vaddq_u32(significand_u32x4, rounding_bias_u32x4);
953
+ uint32x4_t carry_u32x4 = vshrq_n_u32(rounded_sig_u32x4, 24); // Carry into exponent if bit 24 set
954
+ uint32x4_t f32_mantissa_u32x4 = vandq_u32(vshrq_n_u32(rounded_sig_u32x4, 21), vdupq_n_u32(0x03));
955
+ // If carry, mantissa becomes 0 (we rounded up to next power of 2)
956
+ uint32x4_t carry_mask_u32x4 = vceqq_u32(carry_u32x4, vdupq_n_u32(1));
957
+ f32_mantissa_u32x4 = vbicq_u32(f32_mantissa_u32x4, carry_mask_u32x4);
958
+
959
+ // Rebias exponent: f32 bias 127 → e3m2 bias 3 (subtract 124)
960
+ int32x4_t e3m2_exp_i32x4 = vsubq_s32(
961
+ vaddq_s32(vreinterpretq_s32_u32(f32_exp_u32x4), vreinterpretq_s32_u32(carry_u32x4)), vdupq_n_s32(124));
962
+
963
+ // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 7)
964
+ uint32x4_t is_subnormal_u32x4 = vcltq_s32(e3m2_exp_i32x4, vdupq_n_s32(1));
965
+ uint32x4_t overflow_u32x4 = vcgtq_s32(e3m2_exp_i32x4, vdupq_n_s32(7));
966
+
967
+ // Normal path: clamp exp to [1,7], extract mantissa bits
968
+ int32x4_t clamped_exp_i32x4 = vmaxq_s32(e3m2_exp_i32x4, vdupq_n_s32(1));
969
+ clamped_exp_i32x4 = vminq_s32(clamped_exp_i32x4, vdupq_n_s32(7));
970
+ uint32x4_t normal_mantissa_u32x4 = vbslq_u32(overflow_u32x4, vdupq_n_u32(0x03), f32_mantissa_u32x4);
971
+ uint32x4_t normal_e3m2_u32x4 = vorrq_u32(
972
+ vshlq_n_u32(sign_u32x4, 5),
973
+ vorrq_u32(vshlq_n_u32(vreinterpretq_u32_s32(clamped_exp_i32x4), 2), normal_mantissa_u32x4));
974
+
975
+ // Subnormal path: mantissa = round(abs_f32 * 16)
976
+ // If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
977
+ float32x4_t abs_f32x4 = vabsq_f32(f32x4);
978
+ float32x4_t scaled_f32x4 = vmulq_n_f32(abs_f32x4, 16.0f);
979
+ int32x4_t subnormal_mantissa_i32x4 = vcvtnq_s32_f32(scaled_f32x4);
980
+ uint32x4_t promotes_to_normal_u32x4 = vcgtq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(3));
981
+ subnormal_mantissa_i32x4 = vminq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(3));
982
+ subnormal_mantissa_i32x4 = vmaxq_s32(subnormal_mantissa_i32x4, vdupq_n_s32(0));
983
+ uint32x4_t subnormal_e3m2_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 5),
984
+ vreinterpretq_u32_s32(subnormal_mantissa_i32x4));
985
+ // When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
986
+ uint32x4_t first_normal_e3m2_u32x4 = vorrq_u32(vshlq_n_u32(sign_u32x4, 5), vdupq_n_u32(0x04));
987
+ subnormal_e3m2_u32x4 = vbslq_u32(promotes_to_normal_u32x4, first_normal_e3m2_u32x4, subnormal_e3m2_u32x4);
988
+
989
+ // Blend: use subnormal result when exp <= 0
990
+ uint32x4_t e3m2_u32x4 = vbslq_u32(is_subnormal_u32x4, subnormal_e3m2_u32x4, normal_e3m2_u32x4);
991
+
992
+ // Pack 4 u32s to 4 u8s
993
+ uint16x4_t e3m2_u16x4 = vmovn_u32(e3m2_u32x4);
994
+ uint8x8_t e3m2_u8x8 = vmovn_u16(vcombine_u16(e3m2_u16x4, e3m2_u16x4));
995
+ nk_b32_vec_t result;
996
+ result.u32 = vget_lane_u32(vreinterpret_u32_u8(e3m2_u8x8), 0);
997
+ return result;
998
+ }
999
+
1000
+ #pragma endregion - Vectorized Conversions
1001
+
1002
+ #pragma region - Public API
1003
+
1004
+ NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
1005
+ // Same-type fast path
1006
+ if (from_type == to_type) {
1007
+ nk_size_t size_bits = nk_dtype_bits(from_type);
1008
+ if (size_bits > 0) nk_copy_bytes_(to, from, nk_size_divide_round_up_(n * size_bits, 8));
1009
+ return;
1010
+ }
1011
+
1012
+ // Validate supported types (f32 and smaller)
1013
+ int from_ok = (from_type == nk_f32_k || from_type == nk_f16_k || from_type == nk_bf16_k || from_type == nk_e4m3_k ||
1014
+ from_type == nk_e5m2_k || from_type == nk_e2m3_k || from_type == nk_e3m2_k || from_type == nk_i8_k ||
1015
+ from_type == nk_u8_k || from_type == nk_i16_k || from_type == nk_u16_k || from_type == nk_i32_k ||
1016
+ from_type == nk_u32_k);
1017
+ int to_ok = (to_type == nk_f32_k || to_type == nk_f16_k || to_type == nk_bf16_k || to_type == nk_e4m3_k ||
1018
+ to_type == nk_e5m2_k || to_type == nk_e2m3_k || to_type == nk_e3m2_k || to_type == nk_i8_k ||
1019
+ to_type == nk_u8_k || to_type == nk_i16_k || to_type == nk_u16_k || to_type == nk_i32_k ||
1020
+ to_type == nk_u32_k);
1021
+
1022
+ // Fall back to serial for unsupported or i32<->u32 (loses precision through f32)
1023
+ if (!from_ok || !to_ok || (from_type == nk_i32_k && to_type == nk_u32_k) ||
1024
+ (from_type == nk_u32_k && to_type == nk_i32_k)) {
1025
+ nk_cast_serial(from, from_type, n, to, to_type);
1026
+ return;
1027
+ }
1028
+
1029
+ // Check if F16 hub is applicable (FP8/F16/BF16 conversions, 8 elements/iter)
1030
+ // Exception: BF16 ↔ F16 skips F16 hub since it needs F32 intermediate anyway
1031
+ int from_f16_hub = (from_type == nk_e4m3_k || from_type == nk_e5m2_k || from_type == nk_e2m3_k ||
1032
+ from_type == nk_e3m2_k || from_type == nk_f16_k || from_type == nk_bf16_k);
1033
+ int to_f16_hub = (to_type == nk_e4m3_k || to_type == nk_e5m2_k || to_type == nk_f16_k || to_type == nk_bf16_k ||
1034
+ to_type == nk_f32_k);
1035
+ int is_bf16_f16 = (from_type == nk_bf16_k && to_type == nk_f16_k) ||
1036
+ (from_type == nk_f16_k && to_type == nk_bf16_k);
1037
+
1038
+ if (from_f16_hub && to_f16_hub && !is_bf16_f16) {
1039
+ // F16 hub: 8 elements per iteration (float16x8_t intermediate)
1040
+ nk_size_t batches = n / 8;
1041
+ nk_size_t from_step = 8 * nk_dtype_bits(from_type) / 8;
1042
+ nk_size_t to_step = 8 * nk_dtype_bits(to_type) / 8;
1043
+ nk_u8_t const *from_ptr = (nk_u8_t const *)from;
1044
+ nk_u8_t *to_ptr = (nk_u8_t *)to;
1045
+
1046
+ for (nk_size_t idx = 0; idx < batches; ++idx, from_ptr += from_step, to_ptr += to_step) {
1047
+ // Upcast to f16x8 hub
1048
+ float16x8_t hub_f16x8;
1049
+ switch (from_type) {
1050
+ case nk_e4m3_k: hub_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(from_ptr)); break;
1051
+ case nk_e5m2_k: hub_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(from_ptr)); break;
1052
+ case nk_e2m3_k: hub_f16x8 = nk_e2m3x8_to_f16x8_neon_(vld1_u8(from_ptr)); break;
1053
+ case nk_e3m2_k: hub_f16x8 = nk_e3m2x8_to_f16x8_neon_(vld1_u8(from_ptr)); break;
1054
+ case nk_f16_k: hub_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)from_ptr)); break;
1055
+ case nk_bf16_k: {
1056
+ uint16x4_t brain_low_u16x4 = vld1_u16((nk_u16_t const *)from_ptr);
1057
+ uint16x4_t brain_high_u16x4 = vld1_u16((nk_u16_t const *)(from_ptr + 8));
1058
+ float32x4_t ieee_low_f32x4 = nk_bf16x4_to_f32x4_neon_(brain_low_u16x4);
1059
+ float32x4_t ieee_high_f32x4 = nk_bf16x4_to_f32x4_neon_(brain_high_u16x4);
1060
+ hub_f16x8 = vcombine_f16(vcvt_f16_f32(ieee_low_f32x4), vcvt_f16_f32(ieee_high_f32x4));
1061
+ } break;
1062
+ default: hub_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0)); break;
1063
+ }
1064
+
1065
+ // Downcast from f16x8 hub
1066
+ switch (to_type) {
1067
+ case nk_e4m3_k: vst1_u8(to_ptr, nk_f16x8_to_e4m3x8_neon_(hub_f16x8)); break;
1068
+ case nk_e5m2_k: vst1_u8(to_ptr, nk_f16x8_to_e5m2x8_neon_(hub_f16x8)); break;
1069
+ case nk_f16_k: vst1q_u16((nk_u16_t *)to_ptr, vreinterpretq_u16_f16(hub_f16x8)); break;
1070
+ case nk_bf16_k: {
1071
+ float32x4_t ieee_low_f32x4 = vcvt_f32_f16(vget_low_f16(hub_f16x8));
1072
+ float32x4_t ieee_high_f32x4 = vcvt_f32_f16(vget_high_f16(hub_f16x8));
1073
+ vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(ieee_low_f32x4));
1074
+ vst1_u16((nk_u16_t *)(to_ptr + 8), nk_f32x4_to_bf16x4_neon_(ieee_high_f32x4));
1075
+ } break;
1076
+ case nk_f32_k: {
1077
+ vst1q_f32((nk_f32_t *)to_ptr, vcvt_f32_f16(vget_low_f16(hub_f16x8)));
1078
+ vst1q_f32((nk_f32_t *)(to_ptr + 16), vcvt_f32_f16(vget_high_f16(hub_f16x8)));
1079
+ } break;
1080
+ default: break;
1081
+ }
1082
+ }
1083
+
1084
+ // Handle remaining elements (0-7) with F32 hub or serial
1085
+ n = n % 8;
1086
+ from = from_ptr;
1087
+ to = to_ptr;
1088
+ if (n == 0) return;
1089
+ }
1090
+
1091
+ // F32 hub: 4 elements per iteration (f32x4 intermediate)
1092
+ nk_size_t batches = n / 4;
1093
+ nk_size_t tail = n % 4;
1094
+ nk_size_t from_step = 4 * nk_dtype_bits(from_type) / 8;
1095
+ nk_size_t to_step = 4 * nk_dtype_bits(to_type) / 8;
1096
+ nk_u8_t const *from_ptr = (nk_u8_t const *)from;
1097
+ nk_u8_t *to_ptr = (nk_u8_t *)to;
1098
+
1099
+ for (nk_size_t idx = 0; idx < batches; ++idx, from_ptr += from_step, to_ptr += to_step) {
1100
+ // Load and upcast to f32x4
1101
+ float32x4_t hub_f32x4;
1102
+ switch (from_type) {
1103
+ case nk_f32_k: hub_f32x4 = vld1q_f32((nk_f32_t const *)from_ptr); break;
1104
+ case nk_f16_k: hub_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16((nk_u16_t const *)from_ptr))); break;
1105
+ case nk_bf16_k: hub_f32x4 = nk_bf16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr)); break;
1106
+ case nk_e4m3_k: {
1107
+ nk_b32_vec_t in_vec;
1108
+ nk_load_b32_serial_(from_ptr, &in_vec);
1109
+ hub_f32x4 = nk_e4m3x4_to_f32x4_neon_(in_vec);
1110
+ } break;
1111
+ case nk_e5m2_k: {
1112
+ nk_b32_vec_t in_vec;
1113
+ nk_load_b32_serial_(from_ptr, &in_vec);
1114
+ hub_f32x4 = nk_e5m2x4_to_f32x4_neon_(in_vec);
1115
+ } break;
1116
+ case nk_e2m3_k: {
1117
+ nk_b32_vec_t in_vec;
1118
+ nk_load_b32_serial_(from_ptr, &in_vec);
1119
+ hub_f32x4 = nk_e2m3x4_to_f32x4_neon_(in_vec);
1120
+ } break;
1121
+ case nk_e3m2_k: {
1122
+ nk_b32_vec_t in_vec;
1123
+ nk_load_b32_serial_(from_ptr, &in_vec);
1124
+ hub_f32x4 = nk_e3m2x4_to_f32x4_neon_(in_vec);
1125
+ } break;
1126
+ case nk_i32_k: hub_f32x4 = vcvtq_f32_s32(vld1q_s32((nk_i32_t const *)from_ptr)); break;
1127
+ case nk_u32_k: hub_f32x4 = vcvtq_f32_u32(vld1q_u32((nk_u32_t const *)from_ptr)); break;
1128
+ case nk_i16_k: hub_f32x4 = nk_i16x4_to_f32x4_neon_(vld1_s16((nk_i16_t const *)from_ptr)); break;
1129
+ case nk_u16_k: hub_f32x4 = nk_u16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr)); break;
1130
+ case nk_i8_k: {
1131
+ nk_b32_vec_t in_vec;
1132
+ nk_load_b32_serial_(from_ptr, &in_vec);
1133
+ hub_f32x4 = nk_i8x4_to_f32x4_neon_(in_vec);
1134
+ } break;
1135
+ case nk_u8_k: {
1136
+ nk_b32_vec_t in_vec;
1137
+ nk_load_b32_serial_(from_ptr, &in_vec);
1138
+ hub_f32x4 = nk_u8x4_to_f32x4_neon_(in_vec);
1139
+ } break;
1140
+ default: hub_f32x4 = vdupq_n_f32(0); break;
1141
+ }
1142
+
1143
+ // Downcast from f32x4 and store
1144
+ switch (to_type) {
1145
+ case nk_f32_k: vst1q_f32((nk_f32_t *)to_ptr, hub_f32x4); break;
1146
+ case nk_f16_k: vst1_u16((nk_u16_t *)to_ptr, vreinterpret_u16_f16(vcvt_f16_f32(hub_f32x4))); break;
1147
+ case nk_bf16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(hub_f32x4)); break;
1148
+ case nk_e4m3_k: {
1149
+ nk_b32_vec_t out_vec = nk_f32x4_to_e4m3x4_neon_(hub_f32x4);
1150
+ *(nk_u32_t *)to_ptr = out_vec.u32;
1151
+ } break;
1152
+ case nk_e5m2_k: {
1153
+ nk_b32_vec_t out_vec = nk_f32x4_to_e5m2x4_neon_(hub_f32x4);
1154
+ *(nk_u32_t *)to_ptr = out_vec.u32;
1155
+ } break;
1156
+ case nk_e2m3_k: {
1157
+ nk_b32_vec_t out_vec = nk_f32x4_to_e2m3x4_neon_(hub_f32x4);
1158
+ nk_copy_bytes_(to_ptr, &out_vec, sizeof(nk_b32_vec_t));
1159
+ } break;
1160
+ case nk_e3m2_k: {
1161
+ nk_b32_vec_t out_vec = nk_f32x4_to_e3m2x4_neon_(hub_f32x4);
1162
+ nk_copy_bytes_(to_ptr, &out_vec, sizeof(nk_b32_vec_t));
1163
+ } break;
1164
+ case nk_i32_k: vst1q_s32((nk_i32_t *)to_ptr, vcvtnq_s32_f32(hub_f32x4)); break;
1165
+ case nk_u32_k: vst1q_u32((nk_u32_t *)to_ptr, vcvtnq_u32_f32(hub_f32x4)); break;
1166
+ case nk_i16_k: vst1_s16((nk_i16_t *)to_ptr, nk_f32x4_to_i16x4_neon_(hub_f32x4)); break;
1167
+ case nk_u16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_u16x4_neon_(hub_f32x4)); break;
1168
+ case nk_i8_k: nk_f32x4_to_i8x4_neon_(hub_f32x4, (nk_i8_t *)to_ptr); break;
1169
+ case nk_u8_k: nk_f32x4_to_u8x4_neon_(hub_f32x4, (nk_u8_t *)to_ptr); break;
1170
+ default: break;
1171
+ }
1172
+ }
1173
+
1174
+ // Handle tail elements with serial fallback
1175
+ if (tail) nk_cast_serial(from_ptr, from_type, tail, to_ptr, to_type);
1176
+ }
1177
+
1178
+ #pragma endregion - Public API
1179
+
1180
+ #if defined(__clang__)
1181
+ #pragma clang attribute pop
1182
+ #elif defined(__GNUC__)
1183
+ #pragma GCC pop_options
1184
+ #endif
1185
+
1186
+ #if defined(__cplusplus)
1187
+ } // extern "C"
1188
+ #endif
1189
+
1190
+ #endif // NK_TARGET_NEON
1191
+ #endif // NK_TARGET_ARM_
1192
+ #endif // NK_CAST_NEON_H