numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -13,14 +13,32 @@
13
13
  extern "C" {
14
14
  #endif
15
15
 
16
- #pragma region - Type Punned Loads and Stores
16
+ #pragma region Type Punned Loads and Stores
17
17
 
18
18
  /** @brief Type-agnostic 32-bit full load (scalar). */
19
19
  NK_INTERNAL void nk_load_b32_serial_(void const *src, nk_b32_vec_t *dst) { dst->u32 = *(nk_u32_t const *)src; }
20
20
 
21
+ /** @brief Type-agnostic 64-bit full load. */
22
+ NK_INTERNAL void nk_load_b64_serial_(void const *src, nk_b64_vec_t *dst) { dst->u64 = *(nk_u64_t const *)src; }
23
+
24
+ /** @brief Type-agnostic 128-bit full load. */
25
+ NK_INTERNAL void nk_load_b128_serial_(void const *src, nk_b128_vec_t *dst) {
26
+ nk_u64_t const *s = (nk_u64_t const *)src;
27
+ dst->u64s[0] = s[0], dst->u64s[1] = s[1];
28
+ }
29
+
30
+ /** @brief Type-agnostic 256-bit full load. */
31
+ NK_INTERNAL void nk_load_b256_serial_(void const *src, nk_b256_vec_t *dst) {
32
+ nk_u64_t const *s = (nk_u64_t const *)src;
33
+ dst->u64s[0] = s[0], dst->u64s[1] = s[1], dst->u64s[2] = s[2], dst->u64s[3] = s[3];
34
+ }
35
+
21
36
  /** @brief Type-agnostic 32-bit full store (scalar). */
22
37
  NK_INTERNAL void nk_store_b32_serial_(nk_b32_vec_t const *src, void *dst) { *(nk_u32_t *)dst = src->u32; }
23
38
 
39
+ /** @brief Type-agnostic 64-bit full store (scalar). */
40
+ NK_INTERNAL void nk_store_b64_serial_(nk_b64_vec_t const *src, void *dst) { *(nk_u64_t *)dst = src->u64; }
41
+
24
42
  /** @brief Type-agnostic 128-bit store (serial, word-by-word). */
25
43
  NK_INTERNAL void nk_store_b128_serial_(nk_b128_vec_t const *src, void *dst) {
26
44
  nk_u64_t *d = (nk_u64_t *)dst;
@@ -37,164 +55,681 @@ NK_INTERNAL void nk_store_b256_serial_(nk_b256_vec_t const *src, void *dst) {
37
55
  d[3] = src->u64s[3];
38
56
  }
39
57
 
40
- #pragma endregion - Type Punned Loads and Stores
41
-
42
- /**
43
- * @brief Expands an `f16` (IEEE-754 16-bit) to a `float`.
44
- *
45
- * Handles all IEEE-754 edge cases:
46
- *
47
- * Input F16 Hex F32 Hex Description
48
- * +0 0x0000 0x00000000 Positive zero
49
- * -0 0x8000 0x80000000 Negative zero
50
- * +inf 0x7C00 0x7F800000 Positive infinity
51
- * -inf 0xFC00 0xFF800000 Negative infinity
52
- * NaN 0x7E00 0x7FC00000 Quiet NaN (payload preserved)
53
- * Min normal 0x0400 0x38800000 2⁻¹⁴
54
- * Max normal 0x7BFF 0x477FE000 65504
55
- * Min denorm 0x0001 0x33800000 2⁻²⁴
56
- * Max denorm 0x03FF 0x387FC000 2⁻¹⁴ - 2⁻²⁴
57
- *
58
- * https://stackoverflow.com/a/60047308
59
- * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
60
- * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
61
- */
62
- NK_PUBLIC void nk_f16_to_f32_serial(nk_f16_t const *src, nk_f32_t *dest) {
63
- #if NK_NATIVE_F16
64
- *dest = (nk_f32_t)(*src);
65
- #else
66
- unsigned short x;
67
- nk_copy_bytes_(&x, src, 2);
68
-
69
- unsigned int sign = (x >> 15) & 1;
70
- unsigned int exponent = (x >> 10) & 0x1F;
71
- unsigned int mantissa = x & 0x03FF;
72
-
73
- nk_fui32_t conv;
74
-
75
- if (exponent == 0) {
76
- if (mantissa == 0) {
77
- // Zero (preserve sign)
78
- conv.u = sign << 31;
79
- }
80
- else {
81
- // Denormal: value = mantissa × 2⁻²⁴
82
- // Use FPU normalization, then subtract 24 from exponent
83
- nk_fui32_t temp;
84
- temp.f = (float)mantissa;
85
- conv.u = (sign << 31) | (temp.u - 0x0C000000);
86
- }
87
- }
88
- else if (exponent == 31) {
89
- // Infinity (mantissa=0) or NaN (mantissa!=0)
90
- conv.u = (sign << 31) | 0x7F800000 | (mantissa << 13);
91
- }
92
- else {
93
- // Normal: rebias exponent (127-15=112), shift mantissa
94
- conv.u = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13);
58
+ /** @brief Type-agnostic partial load for 64-bit elements (4 elements max) into 256-bit vector. */
59
+ NK_INTERNAL void nk_partial_load_b64x4_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
60
+ nk_u64_t const *s = (nk_u64_t const *)src;
61
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
62
+ switch (n) {
63
+ default:
64
+ case 4: dst->u64s[3] = s[3]; // fallthrough
65
+ case 3: dst->u64s[2] = s[2]; // fallthrough
66
+ case 2: dst->u64s[1] = s[1]; // fallthrough
67
+ case 1: dst->u64s[0] = s[0]; // fallthrough
68
+ case 0: break;
95
69
  }
96
-
97
- *dest = conv.f;
98
- #endif
99
70
  }
100
71
 
101
- /**
102
- * @brief Compresses a `float` to an `f16` (IEEE-754 16-bit).
103
- *
104
- * Handles all IEEE-754 edge cases with round-to-nearest:
105
- *
106
- * Input F32 Hex F16 Hex Description
107
- * +0 0x00000000 0x0000 Positive zero
108
- * -0 0x80000000 0x8000 Negative zero
109
- * +inf 0x7F800000 0x7C00 Positive infinity
110
- * -inf 0xFF800000 0xFC00 Negative infinity
111
- * NaN 0x7FC00000 0x7E00 Quiet NaN (payload truncated)
112
- * 1.0 0x3F800000 0x3C00 Normal number
113
- * 65504 0x477FE000 0x7BFF Max f16 normal
114
- * 65520+ >0x477FE000 0x7C00 Overflow → infinity
115
- * 2⁻¹⁴ 0x38800000 0x0400 Min f16 normal
116
- * 2⁻²⁴ 0x33800000 0x0001 Min f16 denormal
117
- * <2⁻²⁵ <0x33000000 0x0000 Underflow → zero
118
- *
119
- * https://stackoverflow.com/a/60047308
120
- * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
121
- * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
122
- */
123
- NK_PUBLIC void nk_f32_to_f16_serial(nk_f32_t const *src, nk_f16_t *dest) {
124
- #if NK_NATIVE_F16
125
- *dest = (nk_f16_t)(*src);
126
- #else
127
- nk_fui32_t conv;
128
- conv.f = *src;
129
-
130
- unsigned int sign = (conv.u >> 31) & 1;
131
- unsigned int exponent = (conv.u >> 23) & 0xFF;
132
- unsigned int mantissa = conv.u & 0x007FFFFF;
133
-
134
- unsigned short result;
135
-
136
- if (exponent == 0) {
137
- // Zero or f32 denormal → f16 zero
138
- result = (unsigned short)(sign << 15);
139
- }
140
- else if (exponent == 255) {
141
- // Infinity or NaN
142
- unsigned short payload = (unsigned short)(mantissa >> 13);
143
- if (mantissa != 0 && payload == 0) payload = 1; // Preserve NaN-ness
144
- result = (unsigned short)((sign << 15) | 0x7C00 | payload);
145
- }
146
- else if (exponent <= 102) {
147
- // Below or at f16 denormal threshold
148
- // exp=102 with mant=0 is exactly 2^-25 (tie point, rounds to 0 per round-to-even)
149
- // exp=102 with mant>0 is above tie point (rounds to smallest denormal 0x0001)
150
- if (exponent == 102 && mantissa > 0) result = (unsigned short)((sign << 15) | 0x0001);
151
- else result = (unsigned short)(sign << 15);
72
+ /** @brief Type-agnostic partial store for 64-bit elements (4 elements max) from 256-bit vector. */
73
+ NK_INTERNAL void nk_partial_store_b64x4_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
74
+ nk_u64_t *d = (nk_u64_t *)dst;
75
+ switch (n) {
76
+ default:
77
+ case 4: d[3] = src->u64s[3]; // fallthrough
78
+ case 3: d[2] = src->u64s[2]; // fallthrough
79
+ case 2: d[1] = src->u64s[1]; // fallthrough
80
+ case 1: d[0] = src->u64s[0]; // fallthrough
81
+ case 0: break;
152
82
  }
153
- else if (exponent < 113) {
154
- // F16 denormal range (exp 103-112) with IEEE 754 round-to-nearest-even
155
- unsigned int shift = 113 - exponent;
156
- unsigned int shift_amount = shift + 13;
157
- unsigned long long full_mant = 0x00800000ULL | mantissa;
158
-
159
- // Extract result before rounding
160
- unsigned int mant = (unsigned int)(full_mant >> shift_amount);
161
-
162
- // IEEE 754 round-to-nearest-even: round up if round_bit is set AND
163
- // (sticky_bits are nonzero OR result is odd)
164
- unsigned int round_bit = (full_mant >> (shift_amount - 1)) & 1;
165
- unsigned long long sticky_bits = full_mant & ((1ULL << (shift_amount - 1)) - 1);
166
-
167
- if (round_bit && (sticky_bits || (mant & 1))) mant++;
83
+ }
168
84
 
169
- result = (unsigned short)((sign << 15) | mant);
85
+ NK_INTERNAL void nk_partial_load_b64x2_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
86
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
87
+ nk_u64_t const *s = (nk_u64_t const *)src;
88
+ switch (n) {
89
+ default:
90
+ case 2: dst->u64s[1] = s[1]; // fallthrough
91
+ case 1: dst->u64s[0] = s[0]; // fallthrough
92
+ case 0: break;
170
93
  }
171
- else if (exponent < 143) {
172
- // Normal f16 range with IEEE 754 round-to-nearest-even
173
- unsigned int f16_exp = exponent - 112;
174
- unsigned int f16_mant = mantissa >> 13;
94
+ }
175
95
 
176
- // IEEE 754 rounding: check round bit (bit 12) and sticky bits (bits 0-11)
177
- unsigned int round_bit = (mantissa >> 12) & 1;
178
- unsigned int sticky_bits = mantissa & 0xFFF;
96
+ /** @brief Type-agnostic partial store for 64-bit elements (2 elements max) from 128-bit vector. */
97
+ NK_INTERNAL void nk_partial_store_b64x2_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
98
+ nk_u64_t *d = (nk_u64_t *)dst;
99
+ switch (n) {
100
+ default:
101
+ case 2: d[1] = src->u64s[1]; // fallthrough
102
+ case 1: d[0] = src->u64s[0]; // fallthrough
103
+ case 0: break;
104
+ }
105
+ }
179
106
 
180
- if (round_bit && (sticky_bits || (f16_mant & 1))) {
181
- f16_mant++;
182
- if (f16_mant > 0x3FF) f16_mant = 0, f16_exp++;
183
- }
107
+ /** @brief Type-agnostic partial load for 32-bit elements (8 elements max) into 256-bit vector. */
108
+ NK_INTERNAL void nk_partial_load_b32x8_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
109
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
110
+ nk_u32_t const *s = (nk_u32_t const *)src;
111
+ switch (n) {
112
+ default:
113
+ case 8: dst->u32s[7] = s[7]; // fallthrough
114
+ case 7: dst->u32s[6] = s[6]; // fallthrough
115
+ case 6: dst->u32s[5] = s[5]; // fallthrough
116
+ case 5: dst->u32s[4] = s[4]; // fallthrough
117
+ case 4: dst->u32s[3] = s[3]; // fallthrough
118
+ case 3: dst->u32s[2] = s[2]; // fallthrough
119
+ case 2: dst->u32s[1] = s[1]; // fallthrough
120
+ case 1: dst->u32s[0] = s[0]; // fallthrough
121
+ case 0: break;
122
+ }
123
+ }
184
124
 
185
- if (f16_exp > 30) result = (unsigned short)((sign << 15) | 0x7C00);
186
- else result = (unsigned short)((sign << 15) | (f16_exp << 10) | f16_mant);
125
+ /** @brief Type-agnostic partial store for 32-bit elements (8 elements max) from 256-bit vector. */
126
+ NK_INTERNAL void nk_partial_store_b32x8_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
127
+ nk_u32_t *d = (nk_u32_t *)dst;
128
+ switch (n) {
129
+ default:
130
+ case 8: d[7] = src->u32s[7]; // fallthrough
131
+ case 7: d[6] = src->u32s[6]; // fallthrough
132
+ case 6: d[5] = src->u32s[5]; // fallthrough
133
+ case 5: d[4] = src->u32s[4]; // fallthrough
134
+ case 4: d[3] = src->u32s[3]; // fallthrough
135
+ case 3: d[2] = src->u32s[2]; // fallthrough
136
+ case 2: d[1] = src->u32s[1]; // fallthrough
137
+ case 1: d[0] = src->u32s[0]; // fallthrough
138
+ case 0: break;
187
139
  }
188
- else {
189
- // Overflow → infinity
190
- result = (unsigned short)((sign << 15) | 0x7C00);
140
+ }
141
+
142
+ /** @brief Type-agnostic partial load for 32-bit elements (4 elements max) into 128-bit vector. */
143
+ NK_INTERNAL void nk_partial_load_b32x4_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
144
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
145
+ nk_u32_t const *s = (nk_u32_t const *)src;
146
+ switch (n) {
147
+ default:
148
+ case 4: dst->u32s[3] = s[3]; // fallthrough
149
+ case 3: dst->u32s[2] = s[2]; // fallthrough
150
+ case 2: dst->u32s[1] = s[1]; // fallthrough
151
+ case 1: dst->u32s[0] = s[0]; // fallthrough
152
+ case 0: break;
191
153
  }
154
+ }
192
155
 
193
- nk_copy_bytes_(dest, &result, 2);
194
- #endif
156
+ /** @brief Type-agnostic partial store for 32-bit elements (4 elements max) from 128-bit vector. */
157
+ NK_INTERNAL void nk_partial_store_b32x4_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
158
+ nk_u32_t *d = (nk_u32_t *)dst;
159
+ switch (n) {
160
+ default:
161
+ case 4: d[3] = src->u32s[3]; // fallthrough
162
+ case 3: d[2] = src->u32s[2]; // fallthrough
163
+ case 2: d[1] = src->u32s[1]; // fallthrough
164
+ case 1: d[0] = src->u32s[0]; // fallthrough
165
+ case 0: break;
166
+ }
195
167
  }
196
168
 
197
- /**
169
+ /** @brief Type-agnostic partial load for 32-bit elements (2 elements max) into 64-bit vector. */
170
+ NK_INTERNAL void nk_partial_load_b32x2_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
171
+ dst->u64 = 0;
172
+ nk_u32_t const *s = (nk_u32_t const *)src;
173
+ switch (n) {
174
+ default:
175
+ case 2: dst->u32s[1] = s[1]; // fallthrough
176
+ case 1: dst->u32s[0] = s[0]; // fallthrough
177
+ case 0: break;
178
+ }
179
+ }
180
+
181
+ /** @brief Type-agnostic partial load for 16-bit elements (8 elements max) into 128-bit vector. */
182
+ NK_INTERNAL void nk_partial_load_b16x8_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
183
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
184
+ nk_u16_t const *s = (nk_u16_t const *)src;
185
+ switch (n) {
186
+ default:
187
+ case 8: dst->u16s[7] = s[7]; // fallthrough
188
+ case 7: dst->u16s[6] = s[6]; // fallthrough
189
+ case 6: dst->u16s[5] = s[5]; // fallthrough
190
+ case 5: dst->u16s[4] = s[4]; // fallthrough
191
+ case 4: dst->u16s[3] = s[3]; // fallthrough
192
+ case 3: dst->u16s[2] = s[2]; // fallthrough
193
+ case 2: dst->u16s[1] = s[1]; // fallthrough
194
+ case 1: dst->u16s[0] = s[0]; // fallthrough
195
+ case 0: break;
196
+ }
197
+ }
198
+
199
+ /** @brief Type-agnostic partial store for 16-bit elements (8 elements max) from 128-bit vector. */
200
+ NK_INTERNAL void nk_partial_store_b16x8_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
201
+ nk_u16_t *d = (nk_u16_t *)dst;
202
+ switch (n) {
203
+ default:
204
+ case 8: d[7] = src->u16s[7]; // fallthrough
205
+ case 7: d[6] = src->u16s[6]; // fallthrough
206
+ case 6: d[5] = src->u16s[5]; // fallthrough
207
+ case 5: d[4] = src->u16s[4]; // fallthrough
208
+ case 4: d[3] = src->u16s[3]; // fallthrough
209
+ case 3: d[2] = src->u16s[2]; // fallthrough
210
+ case 2: d[1] = src->u16s[1]; // fallthrough
211
+ case 1: d[0] = src->u16s[0]; // fallthrough
212
+ case 0: break;
213
+ }
214
+ }
215
+
216
+ /** @brief Type-agnostic partial load for 16-bit elements (16 elements max) into 256-bit vector. */
217
+ NK_INTERNAL void nk_partial_load_b16x16_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
218
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
219
+ nk_u16_t const *s = (nk_u16_t const *)src;
220
+ switch (n) {
221
+ default:
222
+ case 16: dst->u16s[15] = s[15]; // fallthrough
223
+ case 15: dst->u16s[14] = s[14]; // fallthrough
224
+ case 14: dst->u16s[13] = s[13]; // fallthrough
225
+ case 13: dst->u16s[12] = s[12]; // fallthrough
226
+ case 12: dst->u16s[11] = s[11]; // fallthrough
227
+ case 11: dst->u16s[10] = s[10]; // fallthrough
228
+ case 10: dst->u16s[9] = s[9]; // fallthrough
229
+ case 9: dst->u16s[8] = s[8]; // fallthrough
230
+ case 8: dst->u16s[7] = s[7]; // fallthrough
231
+ case 7: dst->u16s[6] = s[6]; // fallthrough
232
+ case 6: dst->u16s[5] = s[5]; // fallthrough
233
+ case 5: dst->u16s[4] = s[4]; // fallthrough
234
+ case 4: dst->u16s[3] = s[3]; // fallthrough
235
+ case 3: dst->u16s[2] = s[2]; // fallthrough
236
+ case 2: dst->u16s[1] = s[1]; // fallthrough
237
+ case 1: dst->u16s[0] = s[0]; // fallthrough
238
+ case 0: break;
239
+ }
240
+ }
241
+
242
+ /** @brief Type-agnostic partial store for 16-bit elements (16 elements max) from 256-bit vector. */
243
+ NK_INTERNAL void nk_partial_store_b16x16_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
244
+ nk_u16_t *d = (nk_u16_t *)dst;
245
+ switch (n) {
246
+ default:
247
+ case 16: d[15] = src->u16s[15]; // fallthrough
248
+ case 15: d[14] = src->u16s[14]; // fallthrough
249
+ case 14: d[13] = src->u16s[13]; // fallthrough
250
+ case 13: d[12] = src->u16s[12]; // fallthrough
251
+ case 12: d[11] = src->u16s[11]; // fallthrough
252
+ case 11: d[10] = src->u16s[10]; // fallthrough
253
+ case 10: d[9] = src->u16s[9]; // fallthrough
254
+ case 9: d[8] = src->u16s[8]; // fallthrough
255
+ case 8: d[7] = src->u16s[7]; // fallthrough
256
+ case 7: d[6] = src->u16s[6]; // fallthrough
257
+ case 6: d[5] = src->u16s[5]; // fallthrough
258
+ case 5: d[4] = src->u16s[4]; // fallthrough
259
+ case 4: d[3] = src->u16s[3]; // fallthrough
260
+ case 3: d[2] = src->u16s[2]; // fallthrough
261
+ case 2: d[1] = src->u16s[1]; // fallthrough
262
+ case 1: d[0] = src->u16s[0]; // fallthrough
263
+ case 0: break;
264
+ }
265
+ }
266
+
267
+ /** @brief Type-agnostic partial load for 16-bit elements (4 elements max) into 64-bit vector. */
268
+ NK_INTERNAL void nk_partial_load_b16x4_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
269
+ dst->u64 = 0;
270
+ nk_u16_t const *s = (nk_u16_t const *)src;
271
+ switch (n) {
272
+ default:
273
+ case 4: dst->u16s[3] = s[3]; // fallthrough
274
+ case 3: dst->u16s[2] = s[2]; // fallthrough
275
+ case 2: dst->u16s[1] = s[1]; // fallthrough
276
+ case 1: dst->u16s[0] = s[0]; // fallthrough
277
+ case 0: break;
278
+ }
279
+ }
280
+
281
+ /** @brief Type-agnostic partial store for 16-bit elements (4 elements max) from 64-bit vector. */
282
+ NK_INTERNAL void nk_partial_store_b16x4_serial_(void *dst, nk_b64_vec_t const *src, nk_size_t n) {
283
+ nk_u16_t *d = (nk_u16_t *)dst;
284
+ switch (n) {
285
+ default:
286
+ case 4: d[3] = src->u16s[3]; // fallthrough
287
+ case 3: d[2] = src->u16s[2]; // fallthrough
288
+ case 2: d[1] = src->u16s[1]; // fallthrough
289
+ case 1: d[0] = src->u16s[0]; // fallthrough
290
+ case 0: break;
291
+ }
292
+ }
293
+
294
+ /** @brief Type-agnostic partial load for 8-bit elements (8 elements max) into 64-bit vector. */
295
+ NK_INTERNAL void nk_partial_load_b8x8_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
296
+ dst->u64 = 0;
297
+ nk_u8_t const *s = (nk_u8_t const *)src;
298
+ switch (n) {
299
+ default:
300
+ case 8: dst->u8s[7] = s[7]; // fallthrough
301
+ case 7: dst->u8s[6] = s[6]; // fallthrough
302
+ case 6: dst->u8s[5] = s[5]; // fallthrough
303
+ case 5: dst->u8s[4] = s[4]; // fallthrough
304
+ case 4: dst->u8s[3] = s[3]; // fallthrough
305
+ case 3: dst->u8s[2] = s[2]; // fallthrough
306
+ case 2: dst->u8s[1] = s[1]; // fallthrough
307
+ case 1: dst->u8s[0] = s[0]; // fallthrough
308
+ case 0: break;
309
+ }
310
+ }
311
+
312
+ /** @brief Type-agnostic partial store for 8-bit elements (8 elements max) from 64-bit vector. */
313
+ NK_INTERNAL void nk_partial_store_b8x8_serial_(nk_b64_vec_t const *src, void *dst, nk_size_t n) {
314
+ nk_u8_t *d = (nk_u8_t *)dst;
315
+ switch (n) {
316
+ default:
317
+ case 8: d[7] = src->u8s[7]; // fallthrough
318
+ case 7: d[6] = src->u8s[6]; // fallthrough
319
+ case 6: d[5] = src->u8s[5]; // fallthrough
320
+ case 5: d[4] = src->u8s[4]; // fallthrough
321
+ case 4: d[3] = src->u8s[3]; // fallthrough
322
+ case 3: d[2] = src->u8s[2]; // fallthrough
323
+ case 2: d[1] = src->u8s[1]; // fallthrough
324
+ case 1: d[0] = src->u8s[0]; // fallthrough
325
+ case 0: break;
326
+ }
327
+ }
328
+
329
+ /** @brief Type-agnostic partial store for 8-bit elements (16 elements max) from 128-bit vector. */
330
+ NK_INTERNAL void nk_partial_store_b8x16_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
331
+ nk_u8_t *d = (nk_u8_t *)dst;
332
+ switch (n) {
333
+ default:
334
+ case 16: d[15] = src->u8s[15]; // fallthrough
335
+ case 15: d[14] = src->u8s[14]; // fallthrough
336
+ case 14: d[13] = src->u8s[13]; // fallthrough
337
+ case 13: d[12] = src->u8s[12]; // fallthrough
338
+ case 12: d[11] = src->u8s[11]; // fallthrough
339
+ case 11: d[10] = src->u8s[10]; // fallthrough
340
+ case 10: d[9] = src->u8s[9]; // fallthrough
341
+ case 9: d[8] = src->u8s[8]; // fallthrough
342
+ case 8: d[7] = src->u8s[7]; // fallthrough
343
+ case 7: d[6] = src->u8s[6]; // fallthrough
344
+ case 6: d[5] = src->u8s[5]; // fallthrough
345
+ case 5: d[4] = src->u8s[4]; // fallthrough
346
+ case 4: d[3] = src->u8s[3]; // fallthrough
347
+ case 3: d[2] = src->u8s[2]; // fallthrough
348
+ case 2: d[1] = src->u8s[1]; // fallthrough
349
+ case 1: d[0] = src->u8s[0]; // fallthrough
350
+ case 0: break;
351
+ }
352
+ }
353
+
354
+ /** @brief Type-agnostic partial store for 8-bit elements (32 elements max) from 256-bit vector. */
355
+ NK_INTERNAL void nk_partial_store_b8x32_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
356
+ nk_u8_t *d = (nk_u8_t *)dst;
357
+ switch (n) {
358
+ default:
359
+ case 32: d[31] = src->u8s[31]; // fallthrough
360
+ case 31: d[30] = src->u8s[30]; // fallthrough
361
+ case 30: d[29] = src->u8s[29]; // fallthrough
362
+ case 29: d[28] = src->u8s[28]; // fallthrough
363
+ case 28: d[27] = src->u8s[27]; // fallthrough
364
+ case 27: d[26] = src->u8s[26]; // fallthrough
365
+ case 26: d[25] = src->u8s[25]; // fallthrough
366
+ case 25: d[24] = src->u8s[24]; // fallthrough
367
+ case 24: d[23] = src->u8s[23]; // fallthrough
368
+ case 23: d[22] = src->u8s[22]; // fallthrough
369
+ case 22: d[21] = src->u8s[21]; // fallthrough
370
+ case 21: d[20] = src->u8s[20]; // fallthrough
371
+ case 20: d[19] = src->u8s[19]; // fallthrough
372
+ case 19: d[18] = src->u8s[18]; // fallthrough
373
+ case 18: d[17] = src->u8s[17]; // fallthrough
374
+ case 17: d[16] = src->u8s[16]; // fallthrough
375
+ case 16: d[15] = src->u8s[15]; // fallthrough
376
+ case 15: d[14] = src->u8s[14]; // fallthrough
377
+ case 14: d[13] = src->u8s[13]; // fallthrough
378
+ case 13: d[12] = src->u8s[12]; // fallthrough
379
+ case 12: d[11] = src->u8s[11]; // fallthrough
380
+ case 11: d[10] = src->u8s[10]; // fallthrough
381
+ case 10: d[9] = src->u8s[9]; // fallthrough
382
+ case 9: d[8] = src->u8s[8]; // fallthrough
383
+ case 8: d[7] = src->u8s[7]; // fallthrough
384
+ case 7: d[6] = src->u8s[6]; // fallthrough
385
+ case 6: d[5] = src->u8s[5]; // fallthrough
386
+ case 5: d[4] = src->u8s[4]; // fallthrough
387
+ case 4: d[3] = src->u8s[3]; // fallthrough
388
+ case 3: d[2] = src->u8s[2]; // fallthrough
389
+ case 2: d[1] = src->u8s[1]; // fallthrough
390
+ case 1: d[0] = src->u8s[0]; // fallthrough
391
+ case 0: break;
392
+ }
393
+ }
394
+
395
+ /** @brief Type-agnostic partial load for 8-bit elements (16 elements max) into 128-bit vector. */
396
+ NK_INTERNAL void nk_partial_load_b8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
397
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
398
+ nk_u8_t const *s = (nk_u8_t const *)src;
399
+ switch (n) {
400
+ default:
401
+ case 16: dst->u8s[15] = s[15]; // fallthrough
402
+ case 15: dst->u8s[14] = s[14]; // fallthrough
403
+ case 14: dst->u8s[13] = s[13]; // fallthrough
404
+ case 13: dst->u8s[12] = s[12]; // fallthrough
405
+ case 12: dst->u8s[11] = s[11]; // fallthrough
406
+ case 11: dst->u8s[10] = s[10]; // fallthrough
407
+ case 10: dst->u8s[9] = s[9]; // fallthrough
408
+ case 9: dst->u8s[8] = s[8]; // fallthrough
409
+ case 8: dst->u8s[7] = s[7]; // fallthrough
410
+ case 7: dst->u8s[6] = s[6]; // fallthrough
411
+ case 6: dst->u8s[5] = s[5]; // fallthrough
412
+ case 5: dst->u8s[4] = s[4]; // fallthrough
413
+ case 4: dst->u8s[3] = s[3]; // fallthrough
414
+ case 3: dst->u8s[2] = s[2]; // fallthrough
415
+ case 2: dst->u8s[1] = s[1]; // fallthrough
416
+ case 1: dst->u8s[0] = s[0]; // fallthrough
417
+ case 0: break;
418
+ }
419
+ }
420
+
421
+ /** @brief Type-agnostic partial load for 8-bit elements (4 elements max) into 32-bit vector. */
422
+ NK_INTERNAL nk_b32_vec_t nk_partial_load_b8x4_serial_(void const *src, nk_size_t n) {
423
+ nk_b32_vec_t dst = {0};
424
+ nk_u8_t const *s = (nk_u8_t const *)src;
425
+ switch (n) {
426
+ default:
427
+ case 4: dst.u8s[3] = s[3]; // fallthrough
428
+ case 3: dst.u8s[2] = s[2]; // fallthrough
429
+ case 2: dst.u8s[1] = s[1]; // fallthrough
430
+ case 1: dst.u8s[0] = s[0]; // fallthrough
431
+ case 0: break;
432
+ }
433
+ return dst;
434
+ }
435
+
436
+ /** @brief Partial store for 8-bit elements (up to 4) from nk_b32_vec_t. */
437
+ NK_INTERNAL void nk_partial_store_b8x4_serial_(nk_b32_vec_t const *src, void *dst, nk_size_t n) {
438
+ nk_u8_t *d = (nk_u8_t *)dst;
439
+ switch (n) {
440
+ default:
441
+ case 4: d[3] = src->u8s[3]; // fallthrough
442
+ case 3: d[2] = src->u8s[2]; // fallthrough
443
+ case 2: d[1] = src->u8s[1]; // fallthrough
444
+ case 1: d[0] = src->u8s[0]; // fallthrough
445
+ case 0: break;
446
+ }
447
+ }
448
+
449
+ /** @brief Partial load for 8-bit elements (32 max) into 256-bit vector (zeros in remaining slots). */
450
+ NK_INTERNAL void nk_partial_load_b8x32_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
451
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
452
+ nk_u8_t const *s = (nk_u8_t const *)src;
453
+ switch (n) {
454
+ default:
455
+ case 32: dst->u8s[31] = s[31]; // fallthrough
456
+ case 31: dst->u8s[30] = s[30]; // fallthrough
457
+ case 30: dst->u8s[29] = s[29]; // fallthrough
458
+ case 29: dst->u8s[28] = s[28]; // fallthrough
459
+ case 28: dst->u8s[27] = s[27]; // fallthrough
460
+ case 27: dst->u8s[26] = s[26]; // fallthrough
461
+ case 26: dst->u8s[25] = s[25]; // fallthrough
462
+ case 25: dst->u8s[24] = s[24]; // fallthrough
463
+ case 24: dst->u8s[23] = s[23]; // fallthrough
464
+ case 23: dst->u8s[22] = s[22]; // fallthrough
465
+ case 22: dst->u8s[21] = s[21]; // fallthrough
466
+ case 21: dst->u8s[20] = s[20]; // fallthrough
467
+ case 20: dst->u8s[19] = s[19]; // fallthrough
468
+ case 19: dst->u8s[18] = s[18]; // fallthrough
469
+ case 18: dst->u8s[17] = s[17]; // fallthrough
470
+ case 17: dst->u8s[16] = s[16]; // fallthrough
471
+ case 16: dst->u8s[15] = s[15]; // fallthrough
472
+ case 15: dst->u8s[14] = s[14]; // fallthrough
473
+ case 14: dst->u8s[13] = s[13]; // fallthrough
474
+ case 13: dst->u8s[12] = s[12]; // fallthrough
475
+ case 12: dst->u8s[11] = s[11]; // fallthrough
476
+ case 11: dst->u8s[10] = s[10]; // fallthrough
477
+ case 10: dst->u8s[9] = s[9]; // fallthrough
478
+ case 9: dst->u8s[8] = s[8]; // fallthrough
479
+ case 8: dst->u8s[7] = s[7]; // fallthrough
480
+ case 7: dst->u8s[6] = s[6]; // fallthrough
481
+ case 6: dst->u8s[5] = s[5]; // fallthrough
482
+ case 5: dst->u8s[4] = s[4]; // fallthrough
483
+ case 4: dst->u8s[3] = s[3]; // fallthrough
484
+ case 3: dst->u8s[2] = s[2]; // fallthrough
485
+ case 2: dst->u8s[1] = s[1]; // fallthrough
486
+ case 1: dst->u8s[0] = s[0]; // fallthrough
487
+ case 0: break;
488
+ }
489
+ }
490
+
491
+ /** @brief Partial load for 4-bit nibbles (64 max = 32 bytes) into 256-bit vector (zeros in remaining slots). */
492
+ NK_INTERNAL void nk_partial_load_b4x64_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
493
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
494
+ nk_u8_t const *s = (nk_u8_t const *)src;
495
+ nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
496
+ for (nk_size_t i = 0; i < n_bytes && i < 32; i++) dst->u8s[i] = s[i];
497
+ }
498
+
499
+ /** @brief Partial load for 4-bit nibbles (32 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
500
+ NK_INTERNAL void nk_partial_load_b4x32_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
501
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
502
+ nk_u8_t const *s = (nk_u8_t const *)src;
503
+ nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
504
+ for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
505
+ }
506
+
507
+ /** @brief Partial load for 1-bit elements (128 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
508
+ NK_INTERNAL void nk_partial_load_b1x128_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n_bits) {
509
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
510
+ nk_u8_t const *s = (nk_u8_t const *)src;
511
+ nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8);
512
+ for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
513
+ }
514
+
515
+ /** @brief Partial load for binary (u1) data into 256-bit vector, converting n_bits → n_bytes. */
516
+ NK_INTERNAL void nk_partial_load_b1x256_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n_bits) {
517
+ nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8);
518
+ nk_partial_load_b8x32_serial_(src, dst, n_bytes);
519
+ }
520
+
521
+ /** @brief Partial load for 4-bit nibbles (16 max = 8 bytes) into 64-bit vector (zeros in remaining slots). */
522
+ NK_INTERNAL void nk_partial_load_b4x16_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
523
+ dst->u64 = 0;
524
+ nk_u8_t const *s = (nk_u8_t const *)src;
525
+ nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
526
+ for (nk_size_t i = 0; i < n_bytes && i < 8; i++) ((nk_u8_t *)&dst->u64)[i] = s[i];
527
+ }
528
+
529
+ /** @brief Strided partial load for 32-bit elements (4 max) into 128-bit vector. */
530
+ NK_INTERNAL void nk_strided_load_b32x4_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
531
+ nk_size_t n) {
532
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
533
+ nk_u32_t const *s = (nk_u32_t const *)src;
534
+ for (nk_size_t i = 0; i < n && i < 4; ++i) dst->u32s[i] = s[i * stride_elements];
535
+ }
536
+
537
+ /** @brief Strided partial load for 16-bit elements (8 max) into 128-bit vector. */
538
+ NK_INTERNAL void nk_strided_load_b16x8_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
539
+ nk_size_t n) {
540
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
541
+ nk_u16_t const *s = (nk_u16_t const *)src;
542
+ for (nk_size_t i = 0; i < n && i < 8; ++i) dst->u16s[i] = s[i * stride_elements];
543
+ }
544
+
545
+ /** @brief Strided partial load for 8-bit elements (16 max) into 128-bit vector. */
546
+ NK_INTERNAL void nk_strided_load_b8x16_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
547
+ nk_size_t n) {
548
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
549
+ nk_u8_t const *s = (nk_u8_t const *)src;
550
+ for (nk_size_t i = 0; i < n && i < 16; ++i) dst->u8s[i] = s[i * stride_elements];
551
+ }
552
+
553
+ #pragma endregion Type Punned Loads and Stores
554
+
555
+ /**
556
+ * @brief Expands an `f16` (IEEE-754 16-bit) to a `float`.
557
+ *
558
+ * Handles all IEEE-754 edge cases:
559
+ *
560
+ * Input F16 Hex F32 Hex Description
561
+ * +0 0x0000 0x00000000 Positive zero
562
+ * -0 0x8000 0x80000000 Negative zero
563
+ * +inf 0x7C00 0x7F800000 Positive infinity
564
+ * -inf 0xFC00 0xFF800000 Negative infinity
565
+ * NaN 0x7E00 0x7FC00000 Quiet NaN (payload preserved)
566
+ * Min normal 0x0400 0x38800000 2⁻¹⁴
567
+ * Max normal 0x7BFF 0x477FE000 65504
568
+ * Min denorm 0x0001 0x33800000 2⁻²⁴
569
+ * Max denorm 0x03FF 0x387FC000 2⁻¹⁴ - 2⁻²⁴
570
+ *
571
+ * https://stackoverflow.com/a/60047308
572
+ * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
573
+ * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
574
+ */
575
+ NK_PUBLIC void nk_f16_to_f32_serial(nk_f16_t const *src, nk_f32_t *dest) {
576
+ #if NK_NATIVE_F16
577
+ *dest = (nk_f32_t)(*src);
578
+ #else
579
+ unsigned short x;
580
+ nk_copy_bytes_(&x, src, 2);
581
+
582
+ unsigned int sign = (x >> 15) & 1;
583
+ unsigned int exponent = (x >> 10) & 0x1F;
584
+ unsigned int mantissa = x & 0x03FF;
585
+
586
+ nk_fui32_t conv;
587
+
588
+ if (exponent == 0) {
589
+ if (mantissa == 0) {
590
+ // Zero (preserve sign)
591
+ conv.u = sign << 31;
592
+ }
593
+ else {
594
+ // Denormal: value = mantissa × 2⁻²⁴
595
+ // Use FPU normalization, then subtract 24 from exponent
596
+ nk_fui32_t temp;
597
+ temp.f = (float)mantissa;
598
+ conv.u = (sign << 31) | (temp.u - 0x0C000000);
599
+ }
600
+ }
601
+ else if (exponent == 31) {
602
+ // Infinity (mantissa=0) or NaN (mantissa!=0)
603
+ conv.u = (sign << 31) | 0x7F800000 | (mantissa << 13);
604
+ }
605
+ else {
606
+ // Normal: rebias exponent (127-15=112), shift mantissa
607
+ conv.u = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13);
608
+ }
609
+
610
+ *dest = conv.f;
611
+ #endif
612
+ }
613
+
614
+ /** @brief Load 4 × f16 from memory and upcast them to 4 × f32. */
615
+ NK_INTERNAL void nk_load_f16x4_to_f32x4_serial_(void const *src, nk_b128_vec_t *dst) {
616
+ nk_f16_t const *scalars = (nk_f16_t const *)src;
617
+ nk_f16_to_f32_serial(scalars + 0, dst->f32s + 0);
618
+ nk_f16_to_f32_serial(scalars + 1, dst->f32s + 1);
619
+ nk_f16_to_f32_serial(scalars + 2, dst->f32s + 2);
620
+ nk_f16_to_f32_serial(scalars + 3, dst->f32s + 3);
621
+ }
622
+
623
+ /** @brief Partial load for up to 4 × f16 with upcast to 4 × f32. */
624
+ NK_INTERNAL void nk_partial_load_f16x4_to_f32x4_serial_(nk_f16_t const *src, nk_b128_vec_t *dst, nk_size_t n) {
625
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
626
+ switch (n) {
627
+ default:
628
+ case 4: nk_f16_to_f32_serial(src + 3, dst->f32s + 3); // fallthrough
629
+ case 3: nk_f16_to_f32_serial(src + 2, dst->f32s + 2); // fallthrough
630
+ case 2: nk_f16_to_f32_serial(src + 1, dst->f32s + 1); // fallthrough
631
+ case 1: nk_f16_to_f32_serial(src + 0, dst->f32s + 0); // fallthrough
632
+ case 0: break;
633
+ }
634
+ }
635
+
636
/**
 * @brief Compresses a `float` to an `f16` (IEEE-754 16-bit).
 *
 * Handles all IEEE-754 edge cases with round-to-nearest:
 *
 *    Input        F32 Hex        F16 Hex   Description
 *    +0           0x00000000     0x0000    Positive zero
 *    -0           0x80000000     0x8000    Negative zero
 *    +inf         0x7F800000     0x7C00    Positive infinity
 *    -inf         0xFF800000     0xFC00    Negative infinity
 *    NaN          0x7FC00000     0x7E00    Quiet NaN (payload truncated)
 *    1.0          0x3F800000     0x3C00    Normal number
 *    65504        0x477FE000     0x7BFF    Max f16 normal
 *    65520+       >0x477FE000    0x7C00    Overflow → infinity
 *    2⁻¹⁴         0x38800000     0x0400    Min f16 normal
 *    2⁻²⁴         0x33800000     0x0001    Min f16 denormal
 *    <2⁻²⁵        <0x33000000    0x0000    Underflow → zero
 *
 *    https://stackoverflow.com/a/60047308
 *    https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
 *    https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
 */
NK_PUBLIC void nk_f32_to_f16_serial(nk_f32_t const *src, nk_f16_t *dest) {
#if NK_NATIVE_F16
    *dest = (nk_f16_t)(*src);
#else
    // Reinterpret the float's bits through a union to avoid strict-aliasing UB.
    nk_fui32_t conv;
    conv.f = *src;

    unsigned int sign = (conv.u >> 31) & 1;
    unsigned int exponent = (conv.u >> 23) & 0xFF;
    unsigned int mantissa = conv.u & 0x007FFFFF;

    unsigned short result;

    if (exponent == 0) {
        // Zero or f32 denormal → f16 zero
        // (f32 denormals are < 2⁻¹²⁶, far below the smallest f16 denormal 2⁻²⁴).
        result = (unsigned short)(sign << 15);
    }
    else if (exponent == 255) {
        // Infinity or NaN
        unsigned short payload = (unsigned short)(mantissa >> 13);
        if (mantissa != 0 && payload == 0) payload = 1; // Preserve NaN-ness
        result = (unsigned short)((sign << 15) | 0x7C00 | payload);
    }
    else if (exponent <= 102) {
        // Below or at f16 denormal threshold
        // exp=102 with mant=0 is exactly 2^-25 (tie point, rounds to 0 per round-to-even)
        // exp=102 with mant>0 is above tie point (rounds to smallest denormal 0x0001)
        if (exponent == 102 && mantissa > 0) result = (unsigned short)((sign << 15) | 0x0001);
        else result = (unsigned short)(sign << 15);
    }
    else if (exponent < 113) {
        // F16 denormal range (exp 103-112) with IEEE 754 round-to-nearest-even
        unsigned int shift = 113 - exponent;
        unsigned int shift_amount = shift + 13; // 13 extra bits dropped going f32→f16 mantissa
        // Prepend the implicit leading 1 of the (normal) f32 mantissa.
        unsigned long long full_mant = 0x00800000ULL | mantissa;

        // Extract result before rounding
        unsigned int mant = (unsigned int)(full_mant >> shift_amount);

        // IEEE 754 round-to-nearest-even: round up if round_bit is set AND
        // (sticky_bits are nonzero OR result is odd)
        unsigned int round_bit = (full_mant >> (shift_amount - 1)) & 1;
        unsigned long long sticky_bits = full_mant & ((1ULL << (shift_amount - 1)) - 1);

        if (round_bit && (sticky_bits || (mant & 1))) mant++;

        result = (unsigned short)((sign << 15) | mant);
    }
    else if (exponent < 143) {
        // Normal f16 range with IEEE 754 round-to-nearest-even
        unsigned int f16_exp = exponent - 112; // rebias: 127 - 15 = 112
        unsigned int f16_mant = mantissa >> 13;

        // IEEE 754 rounding: check round bit (bit 12) and sticky bits (bits 0-11)
        unsigned int round_bit = (mantissa >> 12) & 1;
        unsigned int sticky_bits = mantissa & 0xFFF;

        if (round_bit && (sticky_bits || (f16_mant & 1))) {
            f16_mant++;
            // Mantissa carry-out bumps the exponent (e.g. 0x3FF + 1 → next binade).
            if (f16_mant > 0x3FF) f16_mant = 0, f16_exp++;
        }

        // Rounding may push exp 30 → 31, which overflows to infinity.
        if (f16_exp > 30) result = (unsigned short)((sign << 15) | 0x7C00);
        else result = (unsigned short)((sign << 15) | (f16_exp << 10) | f16_mant);
    }
    else {
        // Overflow → infinity
        result = (unsigned short)((sign << 15) | 0x7C00);
    }

    // Write the packed bits out byte-wise to sidestep alignment/aliasing concerns.
    nk_copy_bytes_(dest, &result, 2);
#endif
}
731
+
732
+ /**
198
733
  * @brief For compilers that don't natively support the `__bf16` type,
199
734
  * upcasts contents into a more conventional `float`.
200
735
  *
@@ -309,8 +844,8 @@ NK_PUBLIC void nk_e4m3_to_f32_serial(nk_e4m3_t const *src, nk_f32_t *dest) {
309
844
  * NaN 0x7FC00000 0x7F Quiet NaN
310
845
  * 1.0 0x3F800000 0x38 Normal (exp=7, mant=0)
311
846
  * 448+ >0x43E00000 0x7E Overflow → max
312
- * 2⁻⁶ 0x3E800000 0x08 Min normal
313
- * <2⁻¹² × ⁵ <0x39800000 0x00 Underflow → zero (RNE boundary)
847
+ * 2⁻⁶ 0x3E800000 0x08 Min normal
848
+ * 2⁻¹⁰ ≤0x3A800000 0x00 Underflow → zero (RNE boundary)
314
849
  *
315
850
  * References:
316
851
  * https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
@@ -552,8 +1087,8 @@ NK_PUBLIC void nk_e5m2_to_f32_serial(nk_e5m2_t const *src, nk_f32_t *dest) {
552
1087
  * NaN 0x7FC00000 0x7D Quiet NaN
553
1088
  * 1.0 0x3F800000 0x3C Normal (exp=15, mant=0)
554
1089
  * 57344+ >0x47600000 0x7C Overflow → infinity
555
- * 2⁻¹⁴ 0x38800000 0x04 Min normal
556
- * <2⁻¹⁷ × ⁵ <0x36800000 0x00 Underflow → zero (RNE boundary)
1090
+ * 2⁻¹⁴ 0x38800000 0x04 Min normal
1091
+ * 2⁻¹⁷ ≤0x37000000 0x00 Underflow → zero (RNE boundary)
557
1092
  *
558
1093
  * References:
559
1094
  * https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
@@ -1050,565 +1585,156 @@ NK_INTERNAL nk_u64_t nk_rint_even_f64_to_u64_serial_(nk_f64_t x) {
1050
1585
  }
1051
1586
 
1052
1587
  NK_INTERNAL void nk_f32_to_i8_serial(nk_f32_t const *x, nk_i8_t *y) {
1053
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1054
- else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0f ? 127.0 : (*x < -128.0f ? -128.0 : (nk_f64_t)*x));
1055
- }
1056
-
1057
- NK_INTERNAL void nk_f32_to_u8_serial(nk_f32_t const *x, nk_u8_t *y) {
1058
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1059
- else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0f ? 255.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
1060
- }
1061
-
1062
- NK_INTERNAL void nk_f32_to_i16_serial(nk_f32_t const *x, nk_i16_t *y) {
1063
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1064
- else
1065
- *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0f ? 32767.0
1066
- : (*x < -32768.0f ? -32768.0 : (nk_f64_t)*x));
1067
- }
1068
-
1069
- NK_INTERNAL void nk_f32_to_u16_serial(nk_f32_t const *x, nk_u16_t *y) {
1070
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1071
- else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0f ? 65535.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
1072
- }
1073
-
1074
- NK_INTERNAL void nk_f64_to_i8_serial(nk_f64_t const *x, nk_i8_t *y) {
1075
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1076
- else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0 ? 127.0 : (*x < -128.0 ? -128.0 : *x));
1077
- }
1078
-
1079
- NK_INTERNAL void nk_f64_to_u8_serial(nk_f64_t const *x, nk_u8_t *y) {
1080
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1081
- else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0 ? 255.0 : (*x < 0 ? 0.0 : *x));
1082
- }
1083
-
1084
- NK_INTERNAL void nk_f64_to_i16_serial(nk_f64_t const *x, nk_i16_t *y) {
1085
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1086
- else *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0 ? 32767.0 : (*x < -32768.0 ? -32768.0 : *x));
1087
- }
1088
-
1089
- NK_INTERNAL void nk_f64_to_u16_serial(nk_f64_t const *x, nk_u16_t *y) {
1090
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1091
- else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0 ? 65535.0 : (*x < 0 ? 0.0 : *x));
1092
- }
1093
-
1094
- NK_INTERNAL void nk_f64_to_i32_serial(nk_f64_t const *x, nk_i32_t *y) {
1095
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1096
- else
1097
- *y = (nk_i32_t)nk_rint_even_f64_to_i64_serial_(*x > 2147483647.0 ? 2147483647.0
1098
- : (*x < -2147483648.0 ? -2147483648.0 : *x));
1099
- }
1100
-
1101
- NK_INTERNAL void nk_f64_to_u32_serial(nk_f64_t const *x, nk_u32_t *y) {
1102
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1103
- else *y = (nk_u32_t)nk_rint_even_f64_to_u64_serial_(*x > 4294967295.0 ? 4294967295.0 : (*x < 0 ? 0.0 : *x));
1104
- }
1105
-
1106
- NK_INTERNAL void nk_f64_to_i64_serial(nk_f64_t const *x, nk_i64_t *y) {
1107
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1108
- else
1109
- *y = nk_rint_even_f64_to_i64_serial_(*x > 9223372036854775807.0
1110
- ? 9223372036854775807.0
1111
- : (*x < -9223372036854775808.0 ? -9223372036854775808.0 : *x));
1112
- }
1113
-
1114
- NK_INTERNAL void nk_f64_to_u64_serial(nk_f64_t const *x, nk_u64_t *y) {
1115
- if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1116
- else
1117
- *y = nk_rint_even_f64_to_u64_serial_(*x > 18446744073709551615.0 ? 18446744073709551615.0
1118
- : (*x < 0 ? 0.0 : *x));
1119
- }
1120
-
1121
- NK_INTERNAL void nk_i64_to_i8_serial(nk_i64_t const *x, nk_i8_t *y) {
1122
- *y = (nk_i8_t)(*x > 127ll ? 127ll : (*x < -128ll ? -128ll : *x));
1123
- }
1124
-
1125
- NK_INTERNAL void nk_i64_to_u8_serial(nk_i64_t const *x, nk_u8_t *y) {
1126
- *y = (nk_u8_t)(*x > 255ll ? 255ll : (*x < 0ll ? 0ll : *x));
1127
- }
1128
-
1129
- NK_INTERNAL void nk_i64_to_i16_serial(nk_i64_t const *x, nk_i16_t *y) {
1130
- *y = (nk_i16_t)(*x > 32767ll ? 32767ll : (*x < -32768ll ? -32768ll : *x));
1131
- }
1132
-
1133
- NK_INTERNAL void nk_i64_to_u16_serial(nk_i64_t const *x, nk_u16_t *y) {
1134
- *y = (nk_u16_t)(*x > 65535ll ? 65535ll : (*x < 0ll ? 0ll : *x));
1135
- }
1136
-
1137
- NK_INTERNAL void nk_i64_to_i32_serial(nk_i64_t const *x, nk_i32_t *y) {
1138
- *y = (nk_i32_t)(*x > 2147483647ll ? 2147483647ll : (*x < -2147483648ll ? -2147483648ll : *x));
1139
- }
1140
-
1141
- NK_INTERNAL void nk_i64_to_u32_serial(nk_i64_t const *x, nk_u32_t *y) {
1142
- *y = (nk_u32_t)(*x > 4294967295ll ? 4294967295ll : (*x < 0ll ? 0ll : *x));
1143
- }
1144
-
1145
- NK_INTERNAL void nk_u64_to_i8_serial(nk_u64_t const *x, nk_i8_t *y) { *y = (nk_i8_t)(*x > 127ull ? 127ull : *x); }
1146
- NK_INTERNAL void nk_u64_to_u8_serial(nk_u64_t const *x, nk_u8_t *y) { *y = (nk_u8_t)(*x > 255ull ? 255ull : *x); }
1147
- NK_INTERNAL void nk_u64_to_i16_serial(nk_u64_t const *x, nk_i16_t *y) {
1148
- *y = (nk_i16_t)(*x > 32767ull ? 32767ull : *x);
1149
- }
1150
- NK_INTERNAL void nk_u64_to_u16_serial(nk_u64_t const *x, nk_u16_t *y) {
1151
- *y = (nk_u16_t)(*x > 65535ull ? 65535ull : *x);
1152
- }
1153
-
1154
- NK_INTERNAL void nk_u64_to_i32_serial(nk_u64_t const *x, nk_i32_t *y) {
1155
- *y = (nk_i32_t)(*x > 2147483647ull ? 2147483647ull : *x);
1156
- }
1157
-
1158
- NK_INTERNAL void nk_u64_to_u32_serial(nk_u64_t const *x, nk_u32_t *y) {
1159
- *y = (nk_u32_t)(*x > 4294967295ull ? 4294967295ull : *x);
1160
- }
1161
-
1162
- NK_PUBLIC void nk_f16_to_f64_(nk_f16_t const *src, nk_f64_t *dest) {
1163
- nk_f32_t f32;
1164
- nk_f16_to_f32_serial(src, &f32);
1165
- *dest = f32;
1166
- }
1167
- NK_PUBLIC void nk_bf16_to_f64_(nk_bf16_t const *src, nk_f64_t *dest) {
1168
- nk_f32_t f32;
1169
- nk_bf16_to_f32_serial(src, &f32);
1170
- *dest = f32;
1171
- }
1172
-
1173
- NK_INTERNAL void nk_u64_to_i64_serial(nk_u64_t const *x, nk_i64_t *y) {
1174
- *y = (nk_i64_t)(*x >= 9223372036854775807ull ? 9223372036854775807ll : *x);
1175
- }
1176
-
1177
- NK_INTERNAL void nk_i8_to_u64_serial(nk_i8_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1178
- NK_INTERNAL void nk_i16_to_u64_serial(nk_i16_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1179
- NK_INTERNAL void nk_i32_to_u64_serial(nk_i32_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1180
- NK_INTERNAL void nk_i64_to_u64_serial(nk_i64_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1181
-
1182
- NK_INTERNAL void nk_i64_to_f16_serial(nk_i64_t const *x, nk_f16_t *y) {
1183
- nk_f32_t f32 = (nk_f32_t)*x;
1184
- nk_f32_to_f16_serial(&f32, y);
1185
- }
1186
- NK_INTERNAL void nk_i64_to_bf16_serial(nk_i64_t const *x, nk_bf16_t *y) {
1187
- nk_f32_t f32 = (nk_f32_t)*x;
1188
- nk_f32_to_bf16_serial(&f32, y);
1189
- }
1190
- NK_INTERNAL void nk_u64_to_f16_serial(nk_u64_t const *x, nk_f16_t *y) {
1191
- nk_f32_t f32 = (nk_f32_t)*x;
1192
- nk_f32_to_f16_serial(&f32, y);
1193
- }
1194
- NK_INTERNAL void nk_u64_to_bf16_serial(nk_u64_t const *x, nk_bf16_t *y) {
1195
- nk_f32_t f32 = (nk_f32_t)*x;
1196
- nk_f32_to_bf16_serial(&f32, y);
1197
- }
1198
-
1199
- #pragma region - Type Punned Loads and Stores
1200
-
1201
- /** @brief Type-agnostic 256-bit full load. */
1202
- NK_INTERNAL void nk_load_b256_serial_(void const *src, nk_b256_vec_t *dst) {
1203
- nk_u64_t const *s = (nk_u64_t const *)src;
1204
- dst->u64s[0] = s[0], dst->u64s[1] = s[1], dst->u64s[2] = s[2], dst->u64s[3] = s[3];
1205
- }
1206
-
1207
- /** @brief Type-agnostic 128-bit full load. */
1208
- NK_INTERNAL void nk_load_b128_serial_(void const *src, nk_b128_vec_t *dst) {
1209
- nk_u64_t const *s = (nk_u64_t const *)src;
1210
- dst->u64s[0] = s[0], dst->u64s[1] = s[1];
1211
- }
1212
-
1213
- /** @brief Type-agnostic 64-bit full load. */
1214
- NK_INTERNAL void nk_load_b64_serial_(void const *src, nk_b64_vec_t *dst) { dst->u64 = *(nk_u64_t const *)src; }
1215
-
1216
- /** @brief Type-agnostic partial load for 32-bit elements (8 elements max) into 256-bit vector. */
1217
- NK_INTERNAL void nk_partial_load_b32x8_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1218
- dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1219
- nk_u32_t const *s = (nk_u32_t const *)src;
1220
- switch (n) {
1221
- default:
1222
- case 8: dst->u32s[7] = s[7]; // fallthrough
1223
- case 7: dst->u32s[6] = s[6]; // fallthrough
1224
- case 6: dst->u32s[5] = s[5]; // fallthrough
1225
- case 5: dst->u32s[4] = s[4]; // fallthrough
1226
- case 4: dst->u32s[3] = s[3]; // fallthrough
1227
- case 3: dst->u32s[2] = s[2]; // fallthrough
1228
- case 2: dst->u32s[1] = s[1]; // fallthrough
1229
- case 1: dst->u32s[0] = s[0]; // fallthrough
1230
- case 0: break;
1231
- }
1232
- }
1233
-
1234
- /** @brief Type-agnostic partial load for 32-bit elements (4 elements max) into 128-bit vector. */
1235
- NK_INTERNAL void nk_partial_load_b32x4_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1236
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1237
- nk_u32_t const *s = (nk_u32_t const *)src;
1238
- switch (n) {
1239
- default:
1240
- case 4: dst->u32s[3] = s[3]; // fallthrough
1241
- case 3: dst->u32s[2] = s[2]; // fallthrough
1242
- case 2: dst->u32s[1] = s[1]; // fallthrough
1243
- case 1: dst->u32s[0] = s[0]; // fallthrough
1244
- case 0: break;
1245
- }
1246
- }
1247
-
1248
- /** @brief Type-agnostic partial load for 8-bit elements (8 elements max) into 64-bit vector. */
1249
- NK_INTERNAL void nk_partial_load_b8x8_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
1250
- dst->u64 = 0;
1251
- nk_u8_t const *s = (nk_u8_t const *)src;
1252
- switch (n) {
1253
- default:
1254
- case 8: dst->u8s[7] = s[7]; // fallthrough
1255
- case 7: dst->u8s[6] = s[6]; // fallthrough
1256
- case 6: dst->u8s[5] = s[5]; // fallthrough
1257
- case 5: dst->u8s[4] = s[4]; // fallthrough
1258
- case 4: dst->u8s[3] = s[3]; // fallthrough
1259
- case 3: dst->u8s[2] = s[2]; // fallthrough
1260
- case 2: dst->u8s[1] = s[1]; // fallthrough
1261
- case 1: dst->u8s[0] = s[0]; // fallthrough
1262
- case 0: break;
1263
- }
1264
- }
1265
-
1266
- /** @brief Type-agnostic partial load for 8-bit elements (4 elements max) into 32-bit vector. */
1267
- NK_INTERNAL nk_b32_vec_t nk_partial_load_b8x4_serial_(void const *src, nk_size_t n) {
1268
- nk_b32_vec_t dst = {0};
1269
- nk_u8_t const *s = (nk_u8_t const *)src;
1270
- switch (n) {
1271
- default:
1272
- case 4: dst.u8s[3] = s[3]; // fallthrough
1273
- case 3: dst.u8s[2] = s[2]; // fallthrough
1274
- case 2: dst.u8s[1] = s[1]; // fallthrough
1275
- case 1: dst.u8s[0] = s[0]; // fallthrough
1276
- case 0: break;
1277
- }
1278
- return dst;
1279
- }
1280
-
1281
- /** @brief Partial store for 8-bit elements (up to 4) from nk_b32_vec_t. */
1282
- NK_INTERNAL void nk_partial_store_b8x4_serial_(nk_b32_vec_t const *src, void *dst, nk_size_t n) {
1283
- nk_u8_t *d = (nk_u8_t *)dst;
1284
- switch (n) {
1285
- default:
1286
- case 4: d[3] = src->u8s[3]; // fallthrough
1287
- case 3: d[2] = src->u8s[2]; // fallthrough
1288
- case 2: d[1] = src->u8s[1]; // fallthrough
1289
- case 1: d[0] = src->u8s[0]; // fallthrough
1290
- case 0: break;
1291
- }
1588
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1589
+ else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0f ? 127.0 : (*x < -128.0f ? -128.0 : (nk_f64_t)*x));
1292
1590
  }
1293
1591
 
1294
- /** @brief Type-agnostic partial load for 16-bit elements (8 elements max) into 128-bit vector. */
1295
- NK_INTERNAL void nk_partial_load_b16x8_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1296
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1297
- nk_u16_t const *s = (nk_u16_t const *)src;
1298
- switch (n) {
1299
- default:
1300
- case 8: dst->u16s[7] = s[7]; // fallthrough
1301
- case 7: dst->u16s[6] = s[6]; // fallthrough
1302
- case 6: dst->u16s[5] = s[5]; // fallthrough
1303
- case 5: dst->u16s[4] = s[4]; // fallthrough
1304
- case 4: dst->u16s[3] = s[3]; // fallthrough
1305
- case 3: dst->u16s[2] = s[2]; // fallthrough
1306
- case 2: dst->u16s[1] = s[1]; // fallthrough
1307
- case 1: dst->u16s[0] = s[0]; // fallthrough
1308
- case 0: break;
1309
- }
1592
+ NK_INTERNAL void nk_f32_to_u8_serial(nk_f32_t const *x, nk_u8_t *y) {
1593
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1594
+ else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0f ? 255.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
1310
1595
  }
1311
1596
 
1312
- /** @brief Type-agnostic partial load for 8-bit elements (16 elements max) into 128-bit vector. */
1313
- NK_INTERNAL void nk_partial_load_b8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1314
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1315
- nk_u8_t const *s = (nk_u8_t const *)src;
1316
- switch (n) {
1317
- default:
1318
- case 16: dst->u8s[15] = s[15]; // fallthrough
1319
- case 15: dst->u8s[14] = s[14]; // fallthrough
1320
- case 14: dst->u8s[13] = s[13]; // fallthrough
1321
- case 13: dst->u8s[12] = s[12]; // fallthrough
1322
- case 12: dst->u8s[11] = s[11]; // fallthrough
1323
- case 11: dst->u8s[10] = s[10]; // fallthrough
1324
- case 10: dst->u8s[9] = s[9]; // fallthrough
1325
- case 9: dst->u8s[8] = s[8]; // fallthrough
1326
- case 8: dst->u8s[7] = s[7]; // fallthrough
1327
- case 7: dst->u8s[6] = s[6]; // fallthrough
1328
- case 6: dst->u8s[5] = s[5]; // fallthrough
1329
- case 5: dst->u8s[4] = s[4]; // fallthrough
1330
- case 4: dst->u8s[3] = s[3]; // fallthrough
1331
- case 3: dst->u8s[2] = s[2]; // fallthrough
1332
- case 2: dst->u8s[1] = s[1]; // fallthrough
1333
- case 1: dst->u8s[0] = s[0]; // fallthrough
1334
- case 0: break;
1335
- }
1597
+ NK_INTERNAL void nk_f32_to_i16_serial(nk_f32_t const *x, nk_i16_t *y) {
1598
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1599
+ else
1600
+ *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0f ? 32767.0
1601
+ : (*x < -32768.0f ? -32768.0 : (nk_f64_t)*x));
1336
1602
  }
1337
1603
 
1338
- /** @brief Type-agnostic partial load for 16-bit elements (16 elements max) into 256-bit vector. */
1339
- NK_INTERNAL void nk_partial_load_b16x16_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1340
- dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1341
- nk_u16_t const *s = (nk_u16_t const *)src;
1342
- switch (n) {
1343
- default:
1344
- case 16: dst->u16s[15] = s[15]; // fallthrough
1345
- case 15: dst->u16s[14] = s[14]; // fallthrough
1346
- case 14: dst->u16s[13] = s[13]; // fallthrough
1347
- case 13: dst->u16s[12] = s[12]; // fallthrough
1348
- case 12: dst->u16s[11] = s[11]; // fallthrough
1349
- case 11: dst->u16s[10] = s[10]; // fallthrough
1350
- case 10: dst->u16s[9] = s[9]; // fallthrough
1351
- case 9: dst->u16s[8] = s[8]; // fallthrough
1352
- case 8: dst->u16s[7] = s[7]; // fallthrough
1353
- case 7: dst->u16s[6] = s[6]; // fallthrough
1354
- case 6: dst->u16s[5] = s[5]; // fallthrough
1355
- case 5: dst->u16s[4] = s[4]; // fallthrough
1356
- case 4: dst->u16s[3] = s[3]; // fallthrough
1357
- case 3: dst->u16s[2] = s[2]; // fallthrough
1358
- case 2: dst->u16s[1] = s[1]; // fallthrough
1359
- case 1: dst->u16s[0] = s[0]; // fallthrough
1360
- case 0: break;
1361
- }
1604
+ NK_INTERNAL void nk_f32_to_u16_serial(nk_f32_t const *x, nk_u16_t *y) {
1605
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1606
+ else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0f ? 65535.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
1362
1607
  }
1363
1608
 
1364
- /** @brief Partial load for 8-bit elements (32 max) into 256-bit vector (zeros in remaining slots). */
1365
- NK_INTERNAL void nk_partial_load_b8x32_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1366
- dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1367
- nk_u8_t const *s = (nk_u8_t const *)src;
1368
- switch (n) {
1369
- default:
1370
- case 32: dst->u8s[31] = s[31]; // fallthrough
1371
- case 31: dst->u8s[30] = s[30]; // fallthrough
1372
- case 30: dst->u8s[29] = s[29]; // fallthrough
1373
- case 29: dst->u8s[28] = s[28]; // fallthrough
1374
- case 28: dst->u8s[27] = s[27]; // fallthrough
1375
- case 27: dst->u8s[26] = s[26]; // fallthrough
1376
- case 26: dst->u8s[25] = s[25]; // fallthrough
1377
- case 25: dst->u8s[24] = s[24]; // fallthrough
1378
- case 24: dst->u8s[23] = s[23]; // fallthrough
1379
- case 23: dst->u8s[22] = s[22]; // fallthrough
1380
- case 22: dst->u8s[21] = s[21]; // fallthrough
1381
- case 21: dst->u8s[20] = s[20]; // fallthrough
1382
- case 20: dst->u8s[19] = s[19]; // fallthrough
1383
- case 19: dst->u8s[18] = s[18]; // fallthrough
1384
- case 18: dst->u8s[17] = s[17]; // fallthrough
1385
- case 17: dst->u8s[16] = s[16]; // fallthrough
1386
- case 16: dst->u8s[15] = s[15]; // fallthrough
1387
- case 15: dst->u8s[14] = s[14]; // fallthrough
1388
- case 14: dst->u8s[13] = s[13]; // fallthrough
1389
- case 13: dst->u8s[12] = s[12]; // fallthrough
1390
- case 12: dst->u8s[11] = s[11]; // fallthrough
1391
- case 11: dst->u8s[10] = s[10]; // fallthrough
1392
- case 10: dst->u8s[9] = s[9]; // fallthrough
1393
- case 9: dst->u8s[8] = s[8]; // fallthrough
1394
- case 8: dst->u8s[7] = s[7]; // fallthrough
1395
- case 7: dst->u8s[6] = s[6]; // fallthrough
1396
- case 6: dst->u8s[5] = s[5]; // fallthrough
1397
- case 5: dst->u8s[4] = s[4]; // fallthrough
1398
- case 4: dst->u8s[3] = s[3]; // fallthrough
1399
- case 3: dst->u8s[2] = s[2]; // fallthrough
1400
- case 2: dst->u8s[1] = s[1]; // fallthrough
1401
- case 1: dst->u8s[0] = s[0]; // fallthrough
1402
- case 0: break;
1403
- }
1609
+ NK_INTERNAL void nk_f64_to_i8_serial(nk_f64_t const *x, nk_i8_t *y) {
1610
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1611
+ else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0 ? 127.0 : (*x < -128.0 ? -128.0 : *x));
1404
1612
  }
1405
1613
 
1406
- /** @brief Type-agnostic partial store for 32-bit elements (8 elements max) from 256-bit vector. */
1407
- NK_INTERNAL void nk_partial_store_b32x8_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
1408
- nk_u32_t *d = (nk_u32_t *)dst;
1409
- switch (n) {
1410
- default:
1411
- case 8: d[7] = src->u32s[7]; // fallthrough
1412
- case 7: d[6] = src->u32s[6]; // fallthrough
1413
- case 6: d[5] = src->u32s[5]; // fallthrough
1414
- case 5: d[4] = src->u32s[4]; // fallthrough
1415
- case 4: d[3] = src->u32s[3]; // fallthrough
1416
- case 3: d[2] = src->u32s[2]; // fallthrough
1417
- case 2: d[1] = src->u32s[1]; // fallthrough
1418
- case 1: d[0] = src->u32s[0]; // fallthrough
1419
- case 0: break;
1420
- }
1614
+ NK_INTERNAL void nk_f64_to_u8_serial(nk_f64_t const *x, nk_u8_t *y) {
1615
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1616
+ else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0 ? 255.0 : (*x < 0 ? 0.0 : *x));
1421
1617
  }
1422
1618
 
1423
- /** @brief Type-agnostic partial store for 32-bit elements (4 elements max) from 128-bit vector. */
1424
- NK_INTERNAL void nk_partial_store_b32x4_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
1425
- nk_u32_t *d = (nk_u32_t *)dst;
1426
- switch (n) {
1427
- default:
1428
- case 4: d[3] = src->u32s[3]; // fallthrough
1429
- case 3: d[2] = src->u32s[2]; // fallthrough
1430
- case 2: d[1] = src->u32s[1]; // fallthrough
1431
- case 1: d[0] = src->u32s[0]; // fallthrough
1432
- case 0: break;
1433
- }
1619
+ NK_INTERNAL void nk_f64_to_i16_serial(nk_f64_t const *x, nk_i16_t *y) {
1620
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1621
+ else *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0 ? 32767.0 : (*x < -32768.0 ? -32768.0 : *x));
1434
1622
  }
1435
1623
 
1436
- /** @brief Type-agnostic partial store for 16-bit elements (8 elements max) from 128-bit vector. */
1437
- NK_INTERNAL void nk_partial_store_b16x8_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
1438
- nk_u16_t *d = (nk_u16_t *)dst;
1439
- switch (n) {
1440
- default:
1441
- case 8: d[7] = src->u16s[7]; // fallthrough
1442
- case 7: d[6] = src->u16s[6]; // fallthrough
1443
- case 6: d[5] = src->u16s[5]; // fallthrough
1444
- case 5: d[4] = src->u16s[4]; // fallthrough
1445
- case 4: d[3] = src->u16s[3]; // fallthrough
1446
- case 3: d[2] = src->u16s[2]; // fallthrough
1447
- case 2: d[1] = src->u16s[1]; // fallthrough
1448
- case 1: d[0] = src->u16s[0]; // fallthrough
1449
- case 0: break;
1450
- }
1624
+ NK_INTERNAL void nk_f64_to_u16_serial(nk_f64_t const *x, nk_u16_t *y) {
1625
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1626
+ else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0 ? 65535.0 : (*x < 0 ? 0.0 : *x));
1451
1627
  }
1452
1628
 
1453
- /** @brief Type-agnostic partial store for 16-bit elements (4 elements max) from 64-bit vector. */
1454
- NK_INTERNAL void nk_partial_store_b16x4_serial_(void *dst, nk_b64_vec_t const *src, nk_size_t n) {
1455
- nk_u16_t *d = (nk_u16_t *)dst;
1456
- switch (n) {
1457
- default:
1458
- case 4: d[3] = src->u16s[3]; // fallthrough
1459
- case 3: d[2] = src->u16s[2]; // fallthrough
1460
- case 2: d[1] = src->u16s[1]; // fallthrough
1461
- case 1: d[0] = src->u16s[0]; // fallthrough
1462
- case 0: break;
1463
- }
1629
+ NK_INTERNAL void nk_f64_to_i32_serial(nk_f64_t const *x, nk_i32_t *y) {
1630
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1631
+ else
1632
+ *y = (nk_i32_t)nk_rint_even_f64_to_i64_serial_(*x > 2147483647.0 ? 2147483647.0
1633
+ : (*x < -2147483648.0 ? -2147483648.0 : *x));
1464
1634
  }
1465
1635
 
1466
- /** @brief Type-agnostic partial store for 8-bit elements (8 elements max) from 64-bit vector. */
1467
- NK_INTERNAL void nk_partial_store_b8x8_serial_(nk_b64_vec_t const *src, void *dst, nk_size_t n) {
1468
- nk_u8_t *d = (nk_u8_t *)dst;
1469
- switch (n) {
1470
- default:
1471
- case 8: d[7] = src->u8s[7]; // fallthrough
1472
- case 7: d[6] = src->u8s[6]; // fallthrough
1473
- case 6: d[5] = src->u8s[5]; // fallthrough
1474
- case 5: d[4] = src->u8s[4]; // fallthrough
1475
- case 4: d[3] = src->u8s[3]; // fallthrough
1476
- case 3: d[2] = src->u8s[2]; // fallthrough
1477
- case 2: d[1] = src->u8s[1]; // fallthrough
1478
- case 1: d[0] = src->u8s[0]; // fallthrough
1479
- case 0: break;
1480
- }
1636
+ NK_INTERNAL void nk_f64_to_u32_serial(nk_f64_t const *x, nk_u32_t *y) {
1637
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1638
+ else *y = (nk_u32_t)nk_rint_even_f64_to_u64_serial_(*x > 4294967295.0 ? 4294967295.0 : (*x < 0 ? 0.0 : *x));
1481
1639
  }
1482
1640
 
1483
- /** @brief Type-agnostic partial load for 64-bit elements (4 elements max) into 256-bit vector. */
1484
- NK_INTERNAL void nk_partial_load_b64x4_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1485
- nk_u64_t const *s = (nk_u64_t const *)src;
1486
- dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1487
- switch (n) {
1488
- default:
1489
- case 4: dst->u64s[3] = s[3]; // fallthrough
1490
- case 3: dst->u64s[2] = s[2]; // fallthrough
1491
- case 2: dst->u64s[1] = s[1]; // fallthrough
1492
- case 1: dst->u64s[0] = s[0]; // fallthrough
1493
- case 0: break;
1494
- }
1641
+ NK_INTERNAL void nk_f64_to_i64_serial(nk_f64_t const *x, nk_i64_t *y) {
1642
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1643
+ else
1644
+ *y = nk_rint_even_f64_to_i64_serial_(*x > 9223372036854775807.0
1645
+ ? 9223372036854775807.0
1646
+ : (*x < -9223372036854775808.0 ? -9223372036854775808.0 : *x));
1495
1647
  }
1496
1648
 
1497
- /** @brief Type-agnostic partial store for 64-bit elements (4 elements max) from 256-bit vector. */
1498
- NK_INTERNAL void nk_partial_store_b64x4_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
1499
- nk_u64_t *d = (nk_u64_t *)dst;
1500
- switch (n) {
1501
- default:
1502
- case 4: d[3] = src->u64s[3]; // fallthrough
1503
- case 3: d[2] = src->u64s[2]; // fallthrough
1504
- case 2: d[1] = src->u64s[1]; // fallthrough
1505
- case 1: d[0] = src->u64s[0]; // fallthrough
1506
- case 0: break;
1507
- }
1649
+ NK_INTERNAL void nk_f64_to_u64_serial(nk_f64_t const *x, nk_u64_t *y) {
1650
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1651
+ else
1652
+ *y = nk_rint_even_f64_to_u64_serial_(*x > 18446744073709551615.0 ? 18446744073709551615.0
1653
+ : (*x < 0 ? 0.0 : *x));
1508
1654
  }
1509
1655
 
1510
- /** @brief Type-agnostic partial load for 32-bit elements (2 elements max) into 64-bit vector. */
1511
- NK_INTERNAL void nk_partial_load_b32x2_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
1512
- dst->u64 = 0;
1513
- nk_u32_t const *s = (nk_u32_t const *)src;
1514
- switch (n) {
1515
- default:
1516
- case 2: dst->u32s[1] = s[1]; // fallthrough
1517
- case 1: dst->u32s[0] = s[0]; // fallthrough
1518
- case 0: break;
1519
- }
1656
+ NK_INTERNAL void nk_i64_to_i8_serial(nk_i64_t const *x, nk_i8_t *y) {
1657
+ *y = (nk_i8_t)(*x > 127ll ? 127ll : (*x < -128ll ? -128ll : *x));
1658
+ }
1659
+
1660
+ NK_INTERNAL void nk_i64_to_u8_serial(nk_i64_t const *x, nk_u8_t *y) {
1661
+ *y = (nk_u8_t)(*x > 255ll ? 255ll : (*x < 0ll ? 0ll : *x));
1662
+ }
1663
+
1664
+ NK_INTERNAL void nk_i64_to_i16_serial(nk_i64_t const *x, nk_i16_t *y) {
1665
+ *y = (nk_i16_t)(*x > 32767ll ? 32767ll : (*x < -32768ll ? -32768ll : *x));
1520
1666
  }
1521
1667
 
1522
- /** @brief Type-agnostic partial load for 16-bit elements (4 elements max) into 64-bit vector. */
1523
- NK_INTERNAL void nk_partial_load_b16x4_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
1524
- dst->u64 = 0;
1525
- nk_u16_t const *s = (nk_u16_t const *)src;
1526
- switch (n) {
1527
- default:
1528
- case 4: dst->u16s[3] = s[3]; // fallthrough
1529
- case 3: dst->u16s[2] = s[2]; // fallthrough
1530
- case 2: dst->u16s[1] = s[1]; // fallthrough
1531
- case 1: dst->u16s[0] = s[0]; // fallthrough
1532
- case 0: break;
1533
- }
1668
+ NK_INTERNAL void nk_i64_to_u16_serial(nk_i64_t const *x, nk_u16_t *y) {
1669
+ *y = (nk_u16_t)(*x > 65535ll ? 65535ll : (*x < 0ll ? 0ll : *x));
1534
1670
  }
1535
1671
 
1536
- /** @brief Partial load for 4-bit nibbles (64 max = 32 bytes) into 256-bit vector (zeros in remaining slots). */
1537
- NK_INTERNAL void nk_partial_load_b4x64_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1538
- dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1539
- nk_u8_t const *s = (nk_u8_t const *)src;
1540
- nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
1541
- for (nk_size_t i = 0; i < n_bytes && i < 32; i++) dst->u8s[i] = s[i];
1672
+ NK_INTERNAL void nk_i64_to_i32_serial(nk_i64_t const *x, nk_i32_t *y) {
1673
+ *y = (nk_i32_t)(*x > 2147483647ll ? 2147483647ll : (*x < -2147483648ll ? -2147483648ll : *x));
1542
1674
  }
1543
1675
 
1544
- /** @brief Partial load for 4-bit nibbles (32 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
1545
- NK_INTERNAL void nk_partial_load_b4x32_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1546
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1547
- nk_u8_t const *s = (nk_u8_t const *)src;
1548
- nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
1549
- for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
1676
+ NK_INTERNAL void nk_i64_to_u32_serial(nk_i64_t const *x, nk_u32_t *y) {
1677
+ *y = (nk_u32_t)(*x > 4294967295ll ? 4294967295ll : (*x < 0ll ? 0ll : *x));
1550
1678
  }
1551
1679
 
1552
- /** @brief Partial load for 1-bit elements (128 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
1553
- NK_INTERNAL void nk_partial_load_b1x128_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n_bits) {
1554
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1555
- nk_u8_t const *s = (nk_u8_t const *)src;
1556
- nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8);
1557
- for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
1680
+ NK_INTERNAL void nk_u64_to_i8_serial(nk_u64_t const *x, nk_i8_t *y) { *y = (nk_i8_t)(*x > 127ull ? 127ull : *x); }
1681
+ NK_INTERNAL void nk_u64_to_u8_serial(nk_u64_t const *x, nk_u8_t *y) { *y = (nk_u8_t)(*x > 255ull ? 255ull : *x); }
1682
+ NK_INTERNAL void nk_u64_to_i16_serial(nk_u64_t const *x, nk_i16_t *y) {
1683
+ *y = (nk_i16_t)(*x > 32767ull ? 32767ull : *x);
1684
+ }
1685
+ NK_INTERNAL void nk_u64_to_u16_serial(nk_u64_t const *x, nk_u16_t *y) {
1686
+ *y = (nk_u16_t)(*x > 65535ull ? 65535ull : *x);
1558
1687
  }
1559
1688
 
1560
- /** @brief Partial load for 4-bit nibbles (16 max = 8 bytes) into 64-bit vector (zeros in remaining slots). */
1561
- NK_INTERNAL void nk_partial_load_b4x16_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
1562
- dst->u64 = 0;
1563
- nk_u8_t const *s = (nk_u8_t const *)src;
1564
- nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
1565
- for (nk_size_t i = 0; i < n_bytes && i < 8; i++) ((nk_u8_t *)&dst->u64)[i] = s[i];
1689
+ NK_INTERNAL void nk_u64_to_i32_serial(nk_u64_t const *x, nk_i32_t *y) {
1690
+ *y = (nk_i32_t)(*x > 2147483647ull ? 2147483647ull : *x);
1566
1691
  }
1567
1692
 
1568
- NK_INTERNAL void nk_partial_load_b64x2_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1569
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1570
- nk_u64_t const *s = (nk_u64_t const *)src;
1571
- switch (n) {
1572
- default:
1573
- case 2: dst->u64s[1] = s[1]; // fallthrough
1574
- case 1: dst->u64s[0] = s[0]; // fallthrough
1575
- case 0: break;
1576
- }
1693
+ NK_INTERNAL void nk_u64_to_u32_serial(nk_u64_t const *x, nk_u32_t *y) {
1694
+ *y = (nk_u32_t)(*x > 4294967295ull ? 4294967295ull : *x);
1577
1695
  }
1578
1696
 
1579
- /** @brief Type-agnostic partial store for 64-bit elements (2 elements max) from 128-bit vector. */
1580
- NK_INTERNAL void nk_partial_store_b64x2_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
1581
- nk_u64_t *d = (nk_u64_t *)dst;
1582
- switch (n) {
1583
- default:
1584
- case 2: d[1] = src->u64s[1]; // fallthrough
1585
- case 1: d[0] = src->u64s[0]; // fallthrough
1586
- case 0: break;
1587
- }
1697
+ NK_INTERNAL void nk_u64_to_i64_serial(nk_u64_t const *x, nk_i64_t *y) {
1698
+ *y = (nk_i64_t)(*x >= 9223372036854775807ull ? 9223372036854775807ll : *x);
1588
1699
  }
1589
1700
 
1590
- /** @brief Strided partial load for 32-bit elements (4 max) into 128-bit vector. */
1591
- NK_INTERNAL void nk_strided_load_b32x4_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
1592
- nk_size_t n) {
1593
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1594
- nk_u32_t const *s = (nk_u32_t const *)src;
1595
- for (nk_size_t i = 0; i < n && i < 4; ++i) dst->u32s[i] = s[i * stride_elements];
1701
+ NK_INTERNAL void nk_i8_to_u64_serial(nk_i8_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1702
+ NK_INTERNAL void nk_i16_to_u64_serial(nk_i16_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1703
+ NK_INTERNAL void nk_i32_to_u64_serial(nk_i32_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1704
+ NK_INTERNAL void nk_i64_to_u64_serial(nk_i64_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1705
+
1706
+ NK_INTERNAL void nk_i64_to_f16_serial(nk_i64_t const *x, nk_f16_t *y) {
1707
+ nk_f32_t f32 = (nk_f32_t)*x;
1708
+ nk_f32_to_f16_serial(&f32, y);
1709
+ }
1710
+ NK_INTERNAL void nk_i64_to_bf16_serial(nk_i64_t const *x, nk_bf16_t *y) {
1711
+ nk_f32_t f32 = (nk_f32_t)*x;
1712
+ nk_f32_to_bf16_serial(&f32, y);
1713
+ }
1714
+ NK_INTERNAL void nk_u64_to_f16_serial(nk_u64_t const *x, nk_f16_t *y) {
1715
+ nk_f32_t f32 = (nk_f32_t)*x;
1716
+ nk_f32_to_f16_serial(&f32, y);
1717
+ }
1718
+ NK_INTERNAL void nk_u64_to_bf16_serial(nk_u64_t const *x, nk_bf16_t *y) {
1719
+ nk_f32_t f32 = (nk_f32_t)*x;
1720
+ nk_f32_to_bf16_serial(&f32, y);
1596
1721
  }
1597
1722
 
1598
- /** @brief Strided partial load for 16-bit elements (8 max) into 128-bit vector. */
1599
- NK_INTERNAL void nk_strided_load_b16x8_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
1600
- nk_size_t n) {
1601
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1602
- nk_u16_t const *s = (nk_u16_t const *)src;
1603
- for (nk_size_t i = 0; i < n && i < 8; ++i) dst->u16s[i] = s[i * stride_elements];
1723
+ /** @brief Convert a pair of i4 (4-bit signed integer, -8 to 7) nibbles into signed integers. */
1724
+ NK_PUBLIC void nk_i4x2_to_i8x2_serial(nk_i4x2_t const *src, nk_i8_t *dest) {
1725
+ nk_u8_t byte = *(nk_u8_t const *)src;
1726
+ nk_u8_t high_nibble = byte >> 4;
1727
+ nk_u8_t low_nibble = byte & 0x0F;
1728
+ // Sign extend: 0-7 0-7, 8-15 -8 to -1
1729
+ dest[0] = (nk_i8_t)((high_nibble ^ 8) - 8);
1730
+ dest[1] = (nk_i8_t)((low_nibble ^ 8) - 8);
1604
1731
  }
1605
1732
 
1606
- /** @brief Strided partial load for 8-bit elements (16 max) into 128-bit vector. */
1607
- NK_INTERNAL void nk_strided_load_b8x16_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
1608
- nk_size_t n) {
1609
- dst->u64s[0] = 0, dst->u64s[1] = 0;
1610
- nk_u8_t const *s = (nk_u8_t const *)src;
1611
- for (nk_size_t i = 0; i < n && i < 16; ++i) dst->u8s[i] = s[i * stride_elements];
1733
+ /** @brief Convert a pair of u4 (4-bit unsigned integer, 0 to 15) nibbles into unsigned integers. */
1734
+ NK_PUBLIC void nk_u4x2_to_u8x2_serial(nk_u4x2_t const *src, nk_u8_t *dest) {
1735
+ nk_u8_t byte = *(nk_u8_t const *)src;
1736
+ dest[0] = byte >> 4;
1737
+ dest[1] = byte & 0x0F;
1612
1738
  }
1613
1739
 
1614
1740
  /**
@@ -1619,7 +1745,7 @@ NK_INTERNAL void nk_strided_load_b8x16_serial_(void const *src, nk_size_t stride
1619
1745
  * The caller fills the appropriate union member based on the target dtype,
1620
1746
  * then passes the union address as `void const *` to kernel functions.
1621
1747
  */
1622
- typedef union nk_scalar_buffer_t {
1748
+ typedef union NK_MAY_ALIAS_ nk_scalar_buffer_t {
1623
1749
  nk_u8_t bytes[16];
1624
1750
  nk_f64_t f64;
1625
1751
  nk_f32_t f32;
@@ -1639,115 +1765,78 @@ typedef union nk_scalar_buffer_t {
1639
1765
  nk_u8_t u8;
1640
1766
  } nk_scalar_buffer_t;
1641
1767
 
1768
+ /** @brief Reads a typed scalar from @p buf and writes the widened f64c into @p result.
1769
+ * Real types set `.imag = 0`. Safe when @p result aliases @p buf (in-place conversion).
1770
+ * @return 1 on success, 0 for unsupported types (sub-byte, unknown). */
1771
+ NK_INTERNAL int nk_scalar_buffer_to_f64c(nk_scalar_buffer_t const *buf, nk_dtype_t dtype, nk_f64c_t *result) {
1772
+ // Snapshot input so `result` may alias `buf` (e.g. in-place conversion within a union).
1773
+ nk_scalar_buffer_t local;
1774
+ local.f64c = buf->f64c;
1775
+ result->real = 0, result->imag = 0;
1776
+ switch (dtype) {
1777
+ case nk_f64_k: result->real = local.f64; break;
1778
+ case nk_f32_k: result->real = (nk_f64_t)local.f32; break;
1779
+ case nk_f16_k:
1780
+ nk_f16_to_f32_serial(&local.f16, &local.f32);
1781
+ result->real = (nk_f64_t)local.f32;
1782
+ break;
1783
+ case nk_bf16_k:
1784
+ nk_bf16_to_f32_serial(&local.bf16, &local.f32);
1785
+ result->real = (nk_f64_t)local.f32;
1786
+ break;
1787
+ case nk_f64c_k: result->real = local.f64c.real, result->imag = local.f64c.imag; break;
1788
+ case nk_f32c_k: result->real = (nk_f64_t)local.f32c.real, result->imag = (nk_f64_t)local.f32c.imag; break;
1789
+ case nk_f16c_k:
1790
+ nk_f16_to_f32_serial(&local.f16c.real, &local.f32);
1791
+ result->real = (nk_f64_t)local.f32;
1792
+ nk_f16_to_f32_serial(&local.f16c.imag, &local.f32);
1793
+ result->imag = (nk_f64_t)local.f32;
1794
+ break;
1795
+ case nk_bf16c_k:
1796
+ nk_bf16_to_f32_serial(&local.bf16c.real, &local.f32);
1797
+ result->real = (nk_f64_t)local.f32;
1798
+ nk_bf16_to_f32_serial(&local.bf16c.imag, &local.f32);
1799
+ result->imag = (nk_f64_t)local.f32;
1800
+ break;
1801
+ case nk_i64_k: result->real = (nk_f64_t)local.i64; break;
1802
+ case nk_u64_k: result->real = (nk_f64_t)local.u64; break;
1803
+ case nk_i32_k: result->real = (nk_f64_t)local.i32; break;
1804
+ case nk_u32_k: result->real = (nk_f64_t)local.u32; break;
1805
+ case nk_i16_k: result->real = (nk_f64_t)local.i16; break;
1806
+ case nk_u16_k: result->real = (nk_f64_t)local.u16; break;
1807
+ case nk_i8_k: result->real = (nk_f64_t)local.i8; break;
1808
+ case nk_u8_k: result->real = (nk_f64_t)local.u8; break;
1809
+ case nk_e4m3_k:
1810
+ nk_e4m3_to_f32_serial(&local.u8, &local.f32);
1811
+ result->real = (nk_f64_t)local.f32;
1812
+ break;
1813
+ case nk_e5m2_k:
1814
+ nk_e5m2_to_f32_serial(&local.u8, &local.f32);
1815
+ result->real = (nk_f64_t)local.f32;
1816
+ break;
1817
+ case nk_e2m3_k:
1818
+ nk_e2m3_to_f32_serial(&local.u8, &local.f32);
1819
+ result->real = (nk_f64_t)local.f32;
1820
+ break;
1821
+ case nk_e3m2_k:
1822
+ nk_e3m2_to_f32_serial(&local.u8, &local.f32);
1823
+ result->real = (nk_f64_t)local.f32;
1824
+ break;
1825
+ default: return 0;
1826
+ }
1827
+ return 1;
1828
+ }
1829
+
1642
1830
  /**
1643
1831
  * @brief Converts up to 8x values from `from_ptr` buffer into 8x puned buffer objects
1644
1832
  * into a complex 64-bit floating point representation.
1645
1833
  */
1646
- NK_INTERNAL void nk_scalar_buffers_fill_f64c_( //
1834
+ NK_INTERNAL void nk_scalar_buffers_to_f64c_( //
1647
1835
  void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
1648
1836
  nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) {
1649
1837
 
1650
- nk_f32_t temporary_f32;
1651
1838
  nk_size_t i;
1652
1839
  switch (from_dtype) {
1653
- case nk_f64_k: {
1654
- nk_f64_t const *p = (nk_f64_t const *)from_ptr;
1655
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1656
- } break;
1657
- case nk_f32_k: {
1658
- nk_f32_t const *p = (nk_f32_t const *)from_ptr;
1659
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1660
- } break;
1661
- case nk_f16_k: {
1662
- nk_f16_t const *p = (nk_f16_t const *)from_ptr;
1663
- for (i = 0; i < from_count; ++i)
1664
- nk_f16_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1665
- to_buffers[i].f64c.imag = 0;
1666
- } break;
1667
- case nk_bf16_k: {
1668
- nk_bf16_t const *p = (nk_bf16_t const *)from_ptr;
1669
- for (i = 0; i < from_count; ++i)
1670
- nk_bf16_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1671
- to_buffers[i].f64c.imag = 0;
1672
- } break;
1673
- case nk_e4m3_k: {
1674
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1675
- for (i = 0; i < from_count; ++i)
1676
- nk_e4m3_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1677
- to_buffers[i].f64c.imag = 0;
1678
- } break;
1679
- case nk_e5m2_k: {
1680
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1681
- for (i = 0; i < from_count; ++i)
1682
- nk_e5m2_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1683
- to_buffers[i].f64c.imag = 0;
1684
- } break;
1685
- case nk_e2m3_k: {
1686
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1687
- for (i = 0; i < from_count; ++i)
1688
- nk_e2m3_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1689
- to_buffers[i].f64c.imag = 0;
1690
- } break;
1691
- case nk_e3m2_k: {
1692
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1693
- for (i = 0; i < from_count; ++i)
1694
- nk_e3m2_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1695
- to_buffers[i].f64c.imag = 0;
1696
- } break;
1697
- case nk_i64_k: {
1698
- nk_i64_t const *p = (nk_i64_t const *)from_ptr;
1699
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = (nk_f64_t)p[i], to_buffers[i].f64c.imag = 0;
1700
- } break;
1701
- case nk_i32_k: {
1702
- nk_i32_t const *p = (nk_i32_t const *)from_ptr;
1703
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1704
- } break;
1705
- case nk_i16_k: {
1706
- nk_i16_t const *p = (nk_i16_t const *)from_ptr;
1707
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1708
- } break;
1709
- case nk_i8_k: {
1710
- nk_i8_t const *p = (nk_i8_t const *)from_ptr;
1711
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1712
- } break;
1713
- case nk_u64_k: {
1714
- nk_u64_t const *p = (nk_u64_t const *)from_ptr;
1715
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = (nk_f64_t)p[i], to_buffers[i].f64c.imag = 0;
1716
- } break;
1717
- case nk_u32_k: {
1718
- nk_u32_t const *p = (nk_u32_t const *)from_ptr;
1719
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1720
- } break;
1721
- case nk_u16_k: {
1722
- nk_u16_t const *p = (nk_u16_t const *)from_ptr;
1723
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1724
- } break;
1725
- case nk_u8_k: {
1726
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1727
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1728
- } break;
1729
- case nk_f64c_k: {
1730
- nk_f64c_t const *p = (nk_f64c_t const *)from_ptr;
1731
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c = p[i];
1732
- } break;
1733
- case nk_f32c_k: {
1734
- nk_f32c_t const *p = (nk_f32c_t const *)from_ptr;
1735
- for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i].real, to_buffers[i].f64c.imag = p[i].imag;
1736
- } break;
1737
- case nk_f16c_k: {
1738
- nk_f16c_t const *p = (nk_f16c_t const *)from_ptr;
1739
- for (i = 0; i < from_count; ++i) {
1740
- nk_f16_to_f32_serial(&p[i].real, &temporary_f32), to_buffers[i].f64c.real = temporary_f32;
1741
- nk_f16_to_f32_serial(&p[i].imag, &temporary_f32), to_buffers[i].f64c.imag = temporary_f32;
1742
- }
1743
- } break;
1744
- case nk_bf16c_k: {
1745
- nk_bf16c_t const *p = (nk_bf16c_t const *)from_ptr;
1746
- for (i = 0; i < from_count; ++i) {
1747
- nk_bf16_to_f32_serial(&p[i].real, &temporary_f32), to_buffers[i].f64c.real = temporary_f32;
1748
- nk_bf16_to_f32_serial(&p[i].imag, &temporary_f32), to_buffers[i].f64c.imag = temporary_f32;
1749
- }
1750
- } break;
1751
1840
  // Sub-byte: u1 - 8 bits from 1 byte, MSB-first
1752
1841
  case nk_u1_k: {
1753
1842
  nk_u8_t byte = *(nk_u8_t const *)from_ptr;
@@ -1755,130 +1844,117 @@ NK_INTERNAL void nk_scalar_buffers_fill_f64c_( //
1755
1844
  } break;
1756
1845
  // Sub-byte: i4 - 8 nibbles from 4 bytes, high nibble = even index, sign-extended
1757
1846
  case nk_i4_k: {
1758
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1847
+ nk_i4x2_t const *pairs = (nk_i4x2_t const *)from_ptr;
1848
+ nk_i8_t unpacked[2];
1759
1849
  for (i = 0; i < 4; ++i) {
1760
- nk_i8_t hi = (nk_i8_t)(p[i] >> 4), lo = (nk_i8_t)(p[i] & 0xF);
1761
- to_buffers[i * 2].f64c.real = (hi ^ 8) - 8, to_buffers[i * 2].f64c.imag = 0;
1762
- to_buffers[i * 2 + 1].f64c.real = (lo ^ 8) - 8, to_buffers[i * 2 + 1].f64c.imag = 0;
1850
+ nk_i4x2_to_i8x2_serial(&pairs[i], unpacked);
1851
+ to_buffers[i * 2].f64c.real = unpacked[0], to_buffers[i * 2].f64c.imag = 0;
1852
+ to_buffers[i * 2 + 1].f64c.real = unpacked[1], to_buffers[i * 2 + 1].f64c.imag = 0;
1763
1853
  }
1764
1854
  } break;
1765
1855
  // Sub-byte: u4 - 8 nibbles from 4 bytes, high nibble = even index
1766
1856
  case nk_u4_k: {
1767
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1857
+ nk_u4x2_t const *pairs = (nk_u4x2_t const *)from_ptr;
1858
+ nk_u8_t unpacked[2];
1768
1859
  for (i = 0; i < 4; ++i) {
1769
- to_buffers[i * 2].f64c.real = p[i] >> 4, to_buffers[i * 2].f64c.imag = 0;
1770
- to_buffers[i * 2 + 1].f64c.real = p[i] & 0xF, to_buffers[i * 2 + 1].f64c.imag = 0;
1860
+ nk_u4x2_to_u8x2_serial(&pairs[i], unpacked);
1861
+ to_buffers[i * 2].f64c.real = unpacked[0], to_buffers[i * 2].f64c.imag = 0;
1862
+ to_buffers[i * 2 + 1].f64c.real = unpacked[1], to_buffers[i * 2 + 1].f64c.imag = 0;
1771
1863
  }
1772
1864
  } break;
1773
- default:
1774
- for (i = 0; i < 8; ++i) to_buffers[i].f64c.real = 0, to_buffers[i].f64c.imag = 0;
1865
+ // All byte-or-larger types: stage through a separate buffer to avoid
1866
+ // variable-length memcpy and type-punned read on the same union
1867
+ // a pattern that triggers an ICE in MSVC's ARM64 optimizer (C1001).
1868
+ default: {
1869
+ nk_size_t stride = nk_dtype_bits(from_dtype) / NK_BITS_PER_BYTE;
1870
+ nk_scalar_buffer_t staged;
1871
+ for (i = 0; i < from_count; ++i) {
1872
+ staged.u64 = 0;
1873
+ nk_copy_bytes_(&staged, (char const *)from_ptr + i * stride, stride);
1874
+ nk_scalar_buffer_to_f64c(&staged, from_dtype, &to_buffers[i].f64c);
1875
+ }
1876
+ } break;
1877
+ }
1878
+ }
1879
+
1880
+ /** @brief Narrows an f64c @p value into the appropriate typed member of @p buf.
1881
+ * Real types use only `.real`; complex types use both components.
1882
+ * Safe when @p value aliases @p buf (in-place conversion).
1883
+ * @note Integer targets (i64, i32, ...) go through f64 rounding — values beyond 2^53 may lose precision.
1884
+ * @return 1 on success, 0 for unsupported types (sub-byte, unknown). */
1885
+ NK_INTERNAL int nk_scalar_buffer_from_f64c(nk_f64c_t const *value, nk_scalar_buffer_t *buf, nk_dtype_t dtype) {
1886
+ // Snapshot input so `value` may point into `buf` (e.g. in-place conversion within a union).
1887
+ nk_f64c_t local = *value;
1888
+ nk_f32_t temporary_f32;
1889
+ switch (dtype) {
1890
+ case nk_f64_k: buf->f64 = local.real; break;
1891
+ case nk_f32_k: buf->f32 = (nk_f32_t)local.real; break;
1892
+ case nk_f16_k:
1893
+ temporary_f32 = (nk_f32_t)local.real;
1894
+ nk_f32_to_f16_serial(&temporary_f32, &buf->f16);
1895
+ break;
1896
+ case nk_bf16_k:
1897
+ temporary_f32 = (nk_f32_t)local.real;
1898
+ nk_f32_to_bf16_serial(&temporary_f32, &buf->bf16);
1899
+ break;
1900
+ case nk_f64c_k:
1901
+ buf->f64c.real = local.real;
1902
+ buf->f64c.imag = local.imag;
1903
+ break;
1904
+ case nk_f32c_k:
1905
+ buf->f32c.real = (nk_f32_t)local.real;
1906
+ buf->f32c.imag = (nk_f32_t)local.imag;
1907
+ break;
1908
+ case nk_f16c_k:
1909
+ temporary_f32 = (nk_f32_t)local.real;
1910
+ nk_f32_to_f16_serial(&temporary_f32, &buf->f16c.real);
1911
+ temporary_f32 = (nk_f32_t)local.imag;
1912
+ nk_f32_to_f16_serial(&temporary_f32, &buf->f16c.imag);
1913
+ break;
1914
+ case nk_bf16c_k:
1915
+ temporary_f32 = (nk_f32_t)local.real;
1916
+ nk_f32_to_bf16_serial(&temporary_f32, &buf->bf16c.real);
1917
+ temporary_f32 = (nk_f32_t)local.imag;
1918
+ nk_f32_to_bf16_serial(&temporary_f32, &buf->bf16c.imag);
1775
1919
  break;
1920
+ case nk_i64_k: nk_f64_to_i64_serial(&local.real, &buf->i64); break;
1921
+ case nk_u64_k: nk_f64_to_u64_serial(&local.real, &buf->u64); break;
1922
+ case nk_i32_k: nk_f64_to_i32_serial(&local.real, &buf->i32); break;
1923
+ case nk_u32_k: nk_f64_to_u32_serial(&local.real, &buf->u32); break;
1924
+ case nk_i16_k: nk_f64_to_i16_serial(&local.real, &buf->i16); break;
1925
+ case nk_u16_k: nk_f64_to_u16_serial(&local.real, &buf->u16); break;
1926
+ case nk_i8_k: nk_f64_to_i8_serial(&local.real, &buf->i8); break;
1927
+ case nk_u8_k: nk_f64_to_u8_serial(&local.real, &buf->u8); break;
1928
+ case nk_e4m3_k:
1929
+ temporary_f32 = (nk_f32_t)local.real;
1930
+ nk_f32_to_e4m3_serial(&temporary_f32, &buf->u8);
1931
+ break;
1932
+ case nk_e5m2_k:
1933
+ temporary_f32 = (nk_f32_t)local.real;
1934
+ nk_f32_to_e5m2_serial(&temporary_f32, &buf->u8);
1935
+ break;
1936
+ case nk_e2m3_k:
1937
+ temporary_f32 = (nk_f32_t)local.real;
1938
+ nk_f32_to_e2m3_serial(&temporary_f32, &buf->u8);
1939
+ break;
1940
+ case nk_e3m2_k:
1941
+ temporary_f32 = (nk_f32_t)local.real;
1942
+ nk_f32_to_e3m2_serial(&temporary_f32, &buf->u8);
1943
+ break;
1944
+ default: return 0;
1776
1945
  }
1946
+ return 1;
1777
1947
  }
1778
1948
 
1779
1949
  /**
1780
1950
  * @brief Converts up to 8x values from `from_buffers` buffer into 8x typed scalars.
1781
1951
  */
1782
- NK_INTERNAL void nk_scalar_buffers_export_f64c_( //
1952
+ NK_INTERNAL void nk_scalar_buffers_from_f64c_( //
1783
1953
  nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
1784
1954
  void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) {
1785
1955
 
1786
- nk_f32_t temporary_f32;
1787
1956
  nk_size_t i;
1788
1957
  switch (to_dtype) {
1789
- case nk_f64_k: {
1790
- nk_f64_t *p = (nk_f64_t *)to_ptr;
1791
- for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].f64c.real;
1792
- } break;
1793
- case nk_f32_k: {
1794
- nk_f32_t *p = (nk_f32_t *)to_ptr;
1795
- for (i = 0; i < to_count; ++i) p[i] = (nk_f32_t)from_buffers[i].f64c.real;
1796
- } break;
1797
- case nk_f16_k: {
1798
- nk_f16_t *p = (nk_f16_t *)to_ptr;
1799
- for (i = 0; i < to_count; ++i)
1800
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_f16_serial(&temporary_f32, &p[i]);
1801
- } break;
1802
- case nk_bf16_k: {
1803
- nk_bf16_t *p = (nk_bf16_t *)to_ptr;
1804
- for (i = 0; i < to_count; ++i)
1805
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_bf16_serial(&temporary_f32, &p[i]);
1806
- } break;
1807
- case nk_e4m3_k: {
1808
- nk_u8_t *p = (nk_u8_t *)to_ptr;
1809
- for (i = 0; i < to_count; ++i)
1810
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e4m3_serial(&temporary_f32, &p[i]);
1811
- } break;
1812
- case nk_e5m2_k: {
1813
- nk_u8_t *p = (nk_u8_t *)to_ptr;
1814
- for (i = 0; i < to_count; ++i)
1815
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e5m2_serial(&temporary_f32, &p[i]);
1816
- } break;
1817
- case nk_e2m3_k: {
1818
- nk_u8_t *p = (nk_u8_t *)to_ptr;
1819
- for (i = 0; i < to_count; ++i)
1820
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e2m3_serial(&temporary_f32, &p[i]);
1821
- } break;
1822
- case nk_e3m2_k: {
1823
- nk_u8_t *p = (nk_u8_t *)to_ptr;
1824
- for (i = 0; i < to_count; ++i)
1825
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e3m2_serial(&temporary_f32, &p[i]);
1826
- } break;
1827
- case nk_i64_k: {
1828
- nk_i64_t *p = (nk_i64_t *)to_ptr;
1829
- for (i = 0; i < to_count; ++i) nk_f64_to_i64_serial(&from_buffers[i].f64c.real, &p[i]);
1830
- } break;
1831
- case nk_i32_k: {
1832
- nk_i32_t *p = (nk_i32_t *)to_ptr;
1833
- for (i = 0; i < to_count; ++i) nk_f64_to_i32_serial(&from_buffers[i].f64c.real, &p[i]);
1834
- } break;
1835
- case nk_i16_k: {
1836
- nk_i16_t *p = (nk_i16_t *)to_ptr;
1837
- for (i = 0; i < to_count; ++i) nk_f64_to_i16_serial(&from_buffers[i].f64c.real, &p[i]);
1838
- } break;
1839
- case nk_i8_k: {
1840
- nk_i8_t *p = (nk_i8_t *)to_ptr;
1841
- for (i = 0; i < to_count; ++i) nk_f64_to_i8_serial(&from_buffers[i].f64c.real, &p[i]);
1842
- } break;
1843
- case nk_u64_k: {
1844
- nk_u64_t *p = (nk_u64_t *)to_ptr;
1845
- for (i = 0; i < to_count; ++i) nk_f64_to_u64_serial(&from_buffers[i].f64c.real, &p[i]);
1846
- } break;
1847
- case nk_u32_k: {
1848
- nk_u32_t *p = (nk_u32_t *)to_ptr;
1849
- for (i = 0; i < to_count; ++i) nk_f64_to_u32_serial(&from_buffers[i].f64c.real, &p[i]);
1850
- } break;
1851
- case nk_u16_k: {
1852
- nk_u16_t *p = (nk_u16_t *)to_ptr;
1853
- for (i = 0; i < to_count; ++i) nk_f64_to_u16_serial(&from_buffers[i].f64c.real, &p[i]);
1854
- } break;
1855
- case nk_u8_k: {
1856
- nk_u8_t *p = (nk_u8_t *)to_ptr;
1857
- for (i = 0; i < to_count; ++i) nk_f64_to_u8_serial(&from_buffers[i].f64c.real, &p[i]);
1858
- } break;
1859
- case nk_f64c_k: {
1860
- nk_f64c_t *p = (nk_f64c_t *)to_ptr;
1861
- for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].f64c;
1862
- } break;
1863
- case nk_f32c_k: {
1864
- nk_f32c_t *p = (nk_f32c_t *)to_ptr;
1865
- for (i = 0; i < to_count; ++i)
1866
- p[i].real = (nk_f32_t)from_buffers[i].f64c.real, p[i].imag = (nk_f32_t)from_buffers[i].f64c.imag;
1867
- } break;
1868
- case nk_f16c_k: {
1869
- nk_f16c_t *p = (nk_f16c_t *)to_ptr;
1870
- for (i = 0; i < to_count; ++i) {
1871
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_f16_serial(&temporary_f32, &p[i].real);
1872
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.imag, nk_f32_to_f16_serial(&temporary_f32, &p[i].imag);
1873
- }
1874
- } break;
1875
- case nk_bf16c_k: {
1876
- nk_bf16c_t *p = (nk_bf16c_t *)to_ptr;
1877
- for (i = 0; i < to_count; ++i) {
1878
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_bf16_serial(&temporary_f32, &p[i].real);
1879
- temporary_f32 = (nk_f32_t)from_buffers[i].f64c.imag, nk_f32_to_bf16_serial(&temporary_f32, &p[i].imag);
1880
- }
1881
- } break;
1882
1958
  // Sub-byte: u1 - 8 bits to 1 byte, MSB-first, non-zero → 1
1883
1959
  case nk_u1_k: {
1884
1960
  nk_u8_t *p = (nk_u8_t *)to_ptr;
@@ -1890,32 +1966,38 @@ NK_INTERNAL void nk_scalar_buffers_export_f64c_( //
1890
1966
  case nk_i4_k: {
1891
1967
  nk_u8_t *p = (nk_u8_t *)to_ptr;
1892
1968
  for (i = 0; i < 4; ++i) {
1893
- nk_i64_t hi = (nk_i64_t)from_buffers[i * 2].f64c.real;
1894
- nk_i64_t lo = (nk_i64_t)from_buffers[i * 2 + 1].f64c.real;
1895
- hi = hi > 7 ? 7 : (hi < -8 ? -8 : hi);
1896
- lo = lo > 7 ? 7 : (lo < -8 ? -8 : lo);
1897
- p[i] = (nk_u8_t)(((hi & 0xF) << 4) | (lo & 0xF));
1969
+ nk_f64_t high = from_buffers[i * 2].f64c.real, low = from_buffers[i * 2 + 1].f64c.real;
1970
+ high = high > 7 ? 7 : (high < -8 ? -8 : high);
1971
+ low = low > 7 ? 7 : (low < -8 ? -8 : low);
1972
+ p[i] = (nk_u8_t)((((nk_i8_t)high & 0x0F) << 4) | ((nk_i8_t)low & 0x0F));
1898
1973
  }
1899
1974
  } break;
1900
1975
  // Sub-byte: u4 - 8 nibbles to 4 bytes, high nibble = even index
1901
1976
  case nk_u4_k: {
1902
1977
  nk_u8_t *p = (nk_u8_t *)to_ptr;
1903
1978
  for (i = 0; i < 4; ++i) {
1904
- nk_u64_t hi = (nk_u64_t)from_buffers[i * 2].f64c.real;
1905
- nk_u64_t lo = (nk_u64_t)from_buffers[i * 2 + 1].f64c.real;
1906
- hi = hi > 15 ? 15 : hi;
1907
- lo = lo > 15 ? 15 : lo;
1908
- p[i] = (nk_u8_t)((hi << 4) | lo);
1979
+ nk_f64_t high = from_buffers[i * 2].f64c.real, low = from_buffers[i * 2 + 1].f64c.real;
1980
+ high = high > 15 ? 15 : (high < 0 ? 0 : high);
1981
+ low = low > 15 ? 15 : (low < 0 ? 0 : low);
1982
+ p[i] = (nk_u8_t)(((nk_u8_t)high << 4) | (nk_u8_t)low);
1983
+ }
1984
+ } break;
1985
+ // All byte-or-larger types: convert, then store relevant bytes
1986
+ default: {
1987
+ nk_size_t stride = nk_dtype_bits(to_dtype) / NK_BITS_PER_BYTE;
1988
+ nk_scalar_buffer_t tmp;
1989
+ for (i = 0; i < to_count; ++i) {
1990
+ nk_scalar_buffer_from_f64c(&from_buffers[i].f64c, &tmp, to_dtype);
1991
+ nk_copy_bytes_((char *)to_ptr + i * stride, &tmp, stride);
1909
1992
  }
1910
1993
  } break;
1911
- default: break;
1912
1994
  }
1913
1995
  }
1914
1996
 
1915
1997
  /**
1916
1998
  * @brief Load 8 values from typed buffer into `buf[i].i64` (lossless widening for signed integers).
1917
1999
  */
1918
- NK_INTERNAL void nk_scalar_buffers_fill_i64_( //
2000
+ NK_INTERNAL void nk_scalar_buffers_to_i64_( //
1919
2001
  void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
1920
2002
  nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) { //
1921
2003
  nk_size_t i;
@@ -1938,11 +2020,12 @@ NK_INTERNAL void nk_scalar_buffers_fill_i64_( //
1938
2020
  } break;
1939
2021
  // Sub-byte: i4 - 4 bytes to 8 nibbles, sign-extend each nibble
1940
2022
  case nk_i4_k: {
1941
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
2023
+ nk_i4x2_t const *pairs = (nk_i4x2_t const *)from_ptr;
1942
2024
  for (i = 0; i < 4; ++i) {
1943
- nk_i8_t hi = (nk_i8_t)(p[i] >> 4), lo = (nk_i8_t)(p[i] & 0xF);
1944
- to_buffers[i * 2].i64 = (hi ^ 8) - 8;
1945
- to_buffers[i * 2 + 1].i64 = (lo ^ 8) - 8;
2025
+ nk_i8_t unpacked[2];
2026
+ nk_i4x2_to_i8x2_serial(&pairs[i], unpacked);
2027
+ to_buffers[i * 2].i64 = unpacked[0];
2028
+ to_buffers[i * 2 + 1].i64 = unpacked[1];
1946
2029
  }
1947
2030
  } break;
1948
2031
  case nk_u64_k: {
@@ -1974,8 +2057,9 @@ NK_INTERNAL void nk_scalar_buffers_fill_i64_( //
1974
2057
 
1975
2058
  /**
1976
2059
  * @brief Export 8 `buf[i].i64` values to typed buffer with saturation on downcast.
2060
+ * @note Only handles integer and sub-byte targets. Float/complex targets are silently skipped.
1977
2061
  */
1978
- NK_INTERNAL void nk_scalar_buffers_export_i64_( //
2062
+ NK_INTERNAL void nk_scalar_buffers_from_i64_( //
1979
2063
  nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
1980
2064
  void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) { //
1981
2065
  nk_size_t i;
@@ -2015,12 +2099,12 @@ NK_INTERNAL void nk_scalar_buffers_export_i64_( //
2015
2099
  } break;
2016
2100
  // Sub-byte: i4 - 8 nibbles to 4 bytes, clamp [-8,7]
2017
2101
  case nk_i4_k: {
2018
- nk_u8_t *p = (nk_u8_t *)to_ptr;
2102
+ nk_i4x2_t *p = (nk_i4x2_t *)to_ptr;
2019
2103
  for (i = 0; i < 4; ++i) {
2020
- nk_i64_t hi = from_buffers[i * 2].i64, lo = from_buffers[i * 2 + 1].i64;
2021
- hi = hi > 7 ? 7 : (hi < -8 ? -8 : hi);
2022
- lo = lo > 7 ? 7 : (lo < -8 ? -8 : lo);
2023
- p[i] = (nk_u8_t)(((hi & 0xF) << 4) | (lo & 0xF));
2104
+ nk_i64_t high = from_buffers[i * 2].i64, low = from_buffers[i * 2 + 1].i64;
2105
+ high = high > 7 ? 7 : (high < -8 ? -8 : high);
2106
+ low = low > 7 ? 7 : (low < -8 ? -8 : low);
2107
+ p[i] = (nk_u8_t)(((high & 0xF) << 4) | (low & 0xF));
2024
2108
  }
2025
2109
  } break;
2026
2110
  default: break;
@@ -2030,7 +2114,7 @@ NK_INTERNAL void nk_scalar_buffers_export_i64_( //
2030
2114
  /**
2031
2115
  * @brief Load 8 values from typed buffer into `buf[i].u64` (lossless widening for unsigned integers).
2032
2116
  */
2033
- NK_INTERNAL void nk_scalar_buffers_fill_u64_( //
2117
+ NK_INTERNAL void nk_scalar_buffers_to_u64_( //
2034
2118
  void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
2035
2119
  nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) { //
2036
2120
  nk_size_t i;
@@ -2053,10 +2137,12 @@ NK_INTERNAL void nk_scalar_buffers_fill_u64_( //
2053
2137
  } break;
2054
2138
  // Sub-byte: u4 - 4 bytes to 8 nibbles, zero-extend
2055
2139
  case nk_u4_k: {
2056
- nk_u8_t const *p = (nk_u8_t const *)from_ptr;
2140
+ nk_u4x2_t const *pairs = (nk_u4x2_t const *)from_ptr;
2057
2141
  for (i = 0; i < 4; ++i) {
2058
- to_buffers[i * 2].u64 = p[i] >> 4;
2059
- to_buffers[i * 2 + 1].u64 = p[i] & 0xF;
2142
+ nk_u8_t unpacked[2];
2143
+ nk_u4x2_to_u8x2_serial(&pairs[i], unpacked);
2144
+ to_buffers[i * 2].u64 = unpacked[0];
2145
+ to_buffers[i * 2 + 1].u64 = unpacked[1];
2060
2146
  }
2061
2147
  } break;
2062
2148
  // Sub-byte: u1 - 1 byte to 8 bits, MSB-first
@@ -2070,8 +2156,9 @@ NK_INTERNAL void nk_scalar_buffers_fill_u64_( //
2070
2156
 
2071
2157
  /**
2072
2158
  * @brief Export 8 `buf[i].u64` values to typed buffer with saturation on downcast.
2159
+ * @note Only handles integer and sub-byte targets. Float/complex targets are silently skipped.
2073
2160
  */
2074
- NK_INTERNAL void nk_scalar_buffers_export_u64_( //
2161
+ NK_INTERNAL void nk_scalar_buffers_from_u64_( //
2075
2162
  nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
2076
2163
  void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) { //
2077
2164
  nk_size_t i;
@@ -2111,12 +2198,12 @@ NK_INTERNAL void nk_scalar_buffers_export_u64_( //
2111
2198
  } break;
2112
2199
  // Sub-byte: u4 - 8 nibbles to 4 bytes, clamp [0,15]
2113
2200
  case nk_u4_k: {
2114
- nk_u8_t *p = (nk_u8_t *)to_ptr;
2201
+ nk_u4x2_t *p = (nk_u4x2_t *)to_ptr;
2115
2202
  for (i = 0; i < 4; ++i) {
2116
- nk_u64_t hi = from_buffers[i * 2].u64, lo = from_buffers[i * 2 + 1].u64;
2117
- hi = hi > 15 ? 15 : hi;
2118
- lo = lo > 15 ? 15 : lo;
2119
- p[i] = (nk_u8_t)((hi << 4) | lo);
2203
+ nk_u64_t high = from_buffers[i * 2].u64, low = from_buffers[i * 2 + 1].u64;
2204
+ high = high > 15 ? 15 : high;
2205
+ low = low > 15 ? 15 : low;
2206
+ p[i] = (nk_u8_t)((high << 4) | low);
2120
2207
  }
2121
2208
  } break;
2122
2209
  // Sub-byte: u1 - 8 bits to 1 byte, MSB-first, non-zero becomes 1
@@ -2130,9 +2217,24 @@ NK_INTERNAL void nk_scalar_buffers_export_u64_( //
2130
2217
  }
2131
2218
  }
2132
2219
 
2133
- #pragma endregion - Type Punned Loads and Stores
2220
+ /** @brief Widens a typed scalar from @p buf into @p result as f64 (discards imaginary part).
2221
+ * Safe when @p result aliases @p buf (in-place conversion). */
2222
+ NK_INTERNAL int nk_scalar_buffer_to_f64(nk_scalar_buffer_t const *buf, nk_dtype_t dtype, nk_f64_t *result) {
2223
+ nk_f64c_t temporary_f64c;
2224
+ int ok = nk_scalar_buffer_to_f64c(buf, dtype, &temporary_f64c);
2225
+ *result = temporary_f64c.real;
2226
+ return ok;
2227
+ }
2228
+
2229
+ /** @brief Narrows an f64 @p value into the appropriate typed member of @p buf.
2230
+ * Safe when @p value aliases @p buf (in-place: `buf->f64 = x; from_f64(&buf->f64, buf, dtype)`).
2231
+ * @note Integer targets go through f64 rounding — values beyond 2^53 may lose precision. */
2232
+ NK_INTERNAL int nk_scalar_buffer_from_f64(nk_f64_t const *value, nk_scalar_buffer_t *buf, nk_dtype_t dtype) {
2233
+ nk_f64c_t temporary_f64c = {*value, 0};
2234
+ return nk_scalar_buffer_from_f64c(&temporary_f64c, buf, dtype);
2235
+ }
2134
2236
 
2135
- #pragma region - Public API
2237
+ #pragma region Public API
2136
2238
 
2137
2239
  NK_PUBLIC void nk_cast_serial(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
2138
2240
  if (from_type == to_type) {
@@ -2162,12 +2264,12 @@ NK_PUBLIC void nk_cast_serial(void const *from, nk_dtype_t from_type, nk_size_t
2162
2264
  // Both unsigned: u64 hub
2163
2265
  if (from_family == nk_dtype_family_uint_k && to_family == nk_dtype_family_uint_k) {
2164
2266
  for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
2165
- nk_scalar_buffers_fill_u64_(src, from_type, NK_BITS_PER_BYTE, bufs);
2166
- nk_scalar_buffers_export_u64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2267
+ nk_scalar_buffers_to_u64_(src, from_type, NK_BITS_PER_BYTE, bufs);
2268
+ nk_scalar_buffers_from_u64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2167
2269
  }
2168
2270
  if (tail) {
2169
- nk_scalar_buffers_fill_u64_(src, from_type, tail, bufs);
2170
- nk_scalar_buffers_export_u64_(bufs, dst, to_type, tail);
2271
+ nk_scalar_buffers_to_u64_(src, from_type, tail, bufs);
2272
+ nk_scalar_buffers_from_u64_(bufs, dst, to_type, tail);
2171
2273
  }
2172
2274
  return;
2173
2275
  }
@@ -2176,24 +2278,24 @@ NK_PUBLIC void nk_cast_serial(void const *from, nk_dtype_t from_type, nk_size_t
2176
2278
  if ((from_family == nk_dtype_family_int_k || from_family == nk_dtype_family_uint_k) &&
2177
2279
  (to_family == nk_dtype_family_int_k || to_family == nk_dtype_family_uint_k)) {
2178
2280
  for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
2179
- nk_scalar_buffers_fill_i64_(src, from_type, NK_BITS_PER_BYTE, bufs);
2180
- nk_scalar_buffers_export_i64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2281
+ nk_scalar_buffers_to_i64_(src, from_type, NK_BITS_PER_BYTE, bufs);
2282
+ nk_scalar_buffers_from_i64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2181
2283
  }
2182
2284
  if (tail) {
2183
- nk_scalar_buffers_fill_i64_(src, from_type, tail, bufs);
2184
- nk_scalar_buffers_export_i64_(bufs, dst, to_type, tail);
2285
+ nk_scalar_buffers_to_i64_(src, from_type, tail, bufs);
2286
+ nk_scalar_buffers_from_i64_(bufs, dst, to_type, tail);
2185
2287
  }
2186
2288
  return;
2187
2289
  }
2188
2290
 
2189
2291
  // Everything else: f64c hub (floats, complex, cross-category)
2190
2292
  for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
2191
- nk_scalar_buffers_fill_f64c_(src, from_type, NK_BITS_PER_BYTE, bufs);
2192
- nk_scalar_buffers_export_f64c_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2293
+ nk_scalar_buffers_to_f64c_(src, from_type, NK_BITS_PER_BYTE, bufs);
2294
+ nk_scalar_buffers_from_f64c_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2193
2295
  }
2194
2296
  if (tail) {
2195
- nk_scalar_buffers_fill_f64c_(src, from_type, tail, bufs);
2196
- nk_scalar_buffers_export_f64c_(bufs, dst, to_type, tail);
2297
+ nk_scalar_buffers_to_f64c_(src, from_type, tail, bufs);
2298
+ nk_scalar_buffers_from_f64c_(bufs, dst, to_type, tail);
2197
2299
  }
2198
2300
  }
2199
2301
 
@@ -2225,35 +2327,7 @@ NK_PUBLIC void nk_e3m2_to_bf16(nk_e3m2_t const *src, nk_bf16_t *dest) {
2225
2327
  nk_f32_to_bf16_serial(&temp, dest);
2226
2328
  }
2227
2329
 
2228
- /**
2229
- * @brief Convert i4 (4-bit signed integer, -8 to 7) to i8.
2230
- *
2231
- * Nibbles are packed: low nibble in bits [0:3], high nibble in bits [4:7].
2232
- * Sign extension: XOR with 8 then subtract 8 converts unsigned nibble to signed.
2233
- */
2234
- NK_PUBLIC void nk_i4_to_i8_serial_(nk_i4x2_t const *src, nk_i8_t *dest, nk_size_t count) {
2235
- nk_u8_t const *bytes = (nk_u8_t const *)src;
2236
- for (nk_size_t i = 0; i < count; ++i) {
2237
- nk_u8_t byte = bytes[i / 2];
2238
- nk_u8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
2239
- dest[i] = (nk_i8_t)((nibble ^ 8) - 8); // Sign extend: 0-7 → 0-7, 8-15 → -8 to -1
2240
- }
2241
- }
2242
-
2243
- /**
2244
- * @brief Convert u4 (4-bit unsigned integer, 0 to 15) to u8.
2245
- *
2246
- * Nibbles are packed: low nibble in bits [0:3], high nibble in bits [4:7].
2247
- */
2248
- NK_PUBLIC void nk_u4_to_u8_serial_(nk_u4x2_t const *src, nk_u8_t *dest, nk_size_t count) {
2249
- nk_u8_t const *bytes = (nk_u8_t const *)src;
2250
- for (nk_size_t i = 0; i < count; ++i) {
2251
- nk_u8_t byte = bytes[i / 2];
2252
- dest[i] = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
2253
- }
2254
- }
2255
-
2256
- #pragma endregion - Public API
2330
+ #pragma endregion Public API
2257
2331
 
2258
2332
  #if defined(__cplusplus)
2259
2333
  } // extern "C"