numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,2262 @@
1
+ /**
2
+ * @brief SWAR-accelerated Type Conversions for SIMD-free CPUs.
3
+ * @file include/numkong/cast/serial.h
4
+ * @author Ash Vardanian
5
+ * @date January 2, 2026
6
+ */
7
+ #ifndef NK_CAST_SERIAL_H
8
+ #define NK_CAST_SERIAL_H
9
+
10
+ #include "numkong/types.h"
11
+
12
+ #if defined(__cplusplus)
13
+ extern "C" {
14
+ #endif
15
+
16
+ #pragma region - Type Punned Loads and Stores
17
+
18
+ /** @brief Type-agnostic 32-bit full load (scalar). */
19
+ NK_INTERNAL void nk_load_b32_serial_(void const *src, nk_b32_vec_t *dst) { dst->u32 = *(nk_u32_t const *)src; }
20
+
21
+ /** @brief Type-agnostic 32-bit full store (scalar). */
22
+ NK_INTERNAL void nk_store_b32_serial_(nk_b32_vec_t const *src, void *dst) { *(nk_u32_t *)dst = src->u32; }
23
+
24
+ /** @brief Type-agnostic 128-bit store (serial, word-by-word). */
25
+ NK_INTERNAL void nk_store_b128_serial_(nk_b128_vec_t const *src, void *dst) {
26
+ nk_u64_t *d = (nk_u64_t *)dst;
27
+ d[0] = src->u64s[0];
28
+ d[1] = src->u64s[1];
29
+ }
30
+
31
+ /** @brief Type-agnostic 256-bit store (serial, word-by-word). */
32
+ NK_INTERNAL void nk_store_b256_serial_(nk_b256_vec_t const *src, void *dst) {
33
+ nk_u64_t *d = (nk_u64_t *)dst;
34
+ d[0] = src->u64s[0];
35
+ d[1] = src->u64s[1];
36
+ d[2] = src->u64s[2];
37
+ d[3] = src->u64s[3];
38
+ }
39
+
40
+ #pragma endregion - Type Punned Loads and Stores
41
+
42
+ /**
43
+ * @brief Expands an `f16` (IEEE-754 16-bit) to a `float`.
44
+ *
45
+ * Handles all IEEE-754 edge cases:
46
+ *
47
+ * Input F16 Hex F32 Hex Description
48
+ * +0 0x0000 0x00000000 Positive zero
49
+ * -0 0x8000 0x80000000 Negative zero
50
+ * +inf 0x7C00 0x7F800000 Positive infinity
51
+ * -inf 0xFC00 0xFF800000 Negative infinity
52
+ * NaN 0x7E00 0x7FC00000 Quiet NaN (payload preserved)
53
+ * Min normal 0x0400 0x38800000 2⁻¹⁴
54
+ * Max normal 0x7BFF 0x477FE000 65504
55
+ * Min denorm 0x0001 0x33800000 2⁻²⁴
56
+ * Max denorm 0x03FF 0x387FC000 2⁻¹⁴ - 2⁻²⁴
57
+ *
58
+ * https://stackoverflow.com/a/60047308
59
+ * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
60
+ * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
61
+ */
62
+ NK_PUBLIC void nk_f16_to_f32_serial(nk_f16_t const *src, nk_f32_t *dest) {
63
+ #if NK_NATIVE_F16
64
+ *dest = (nk_f32_t)(*src);
65
+ #else
66
+ unsigned short x;
67
+ nk_copy_bytes_(&x, src, 2);
68
+
69
+ unsigned int sign = (x >> 15) & 1;
70
+ unsigned int exponent = (x >> 10) & 0x1F;
71
+ unsigned int mantissa = x & 0x03FF;
72
+
73
+ nk_fui32_t conv;
74
+
75
+ if (exponent == 0) {
76
+ if (mantissa == 0) {
77
+ // Zero (preserve sign)
78
+ conv.u = sign << 31;
79
+ }
80
+ else {
81
+ // Denormal: value = mantissa × 2⁻²⁴
82
+ // Use FPU normalization, then subtract 24 from exponent
83
+ nk_fui32_t temp;
84
+ temp.f = (float)mantissa;
85
+ conv.u = (sign << 31) | (temp.u - 0x0C000000);
86
+ }
87
+ }
88
+ else if (exponent == 31) {
89
+ // Infinity (mantissa=0) or NaN (mantissa!=0)
90
+ conv.u = (sign << 31) | 0x7F800000 | (mantissa << 13);
91
+ }
92
+ else {
93
+ // Normal: rebias exponent (127-15=112), shift mantissa
94
+ conv.u = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13);
95
+ }
96
+
97
+ *dest = conv.f;
98
+ #endif
99
+ }
100
+
101
+ /**
102
+ * @brief Compresses a `float` to an `f16` (IEEE-754 16-bit).
103
+ *
104
+ * Handles all IEEE-754 edge cases with round-to-nearest:
105
+ *
106
+ * Input F32 Hex F16 Hex Description
107
+ * +0 0x00000000 0x0000 Positive zero
108
+ * -0 0x80000000 0x8000 Negative zero
109
+ * +inf 0x7F800000 0x7C00 Positive infinity
110
+ * -inf 0xFF800000 0xFC00 Negative infinity
111
+ * NaN 0x7FC00000 0x7E00 Quiet NaN (payload truncated)
112
+ * 1.0 0x3F800000 0x3C00 Normal number
113
+ * 65504 0x477FE000 0x7BFF Max f16 normal
114
+ * 65520+ ≥0x477FF000 0x7C00 Overflow → infinity
115
+ * 2⁻¹⁴ 0x38800000 0x0400 Min f16 normal
116
+ * 2⁻²⁴ 0x33800000 0x0001 Min f16 denormal
117
+ * <2⁻²⁵ <0x33000000 0x0000 Underflow → zero
118
+ *
119
+ * https://stackoverflow.com/a/60047308
120
+ * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
121
+ * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
122
+ */
123
+ NK_PUBLIC void nk_f32_to_f16_serial(nk_f32_t const *src, nk_f16_t *dest) {
124
+ #if NK_NATIVE_F16
125
+ *dest = (nk_f16_t)(*src);
126
+ #else
127
+ nk_fui32_t conv;
128
+ conv.f = *src;
129
+
130
+ unsigned int sign = (conv.u >> 31) & 1;
131
+ unsigned int exponent = (conv.u >> 23) & 0xFF;
132
+ unsigned int mantissa = conv.u & 0x007FFFFF;
133
+
134
+ unsigned short result;
135
+
136
+ if (exponent == 0) {
137
+ // Zero or f32 denormal → f16 zero
138
+ result = (unsigned short)(sign << 15);
139
+ }
140
+ else if (exponent == 255) {
141
+ // Infinity or NaN
142
+ unsigned short payload = (unsigned short)(mantissa >> 13);
143
+ if (mantissa != 0 && payload == 0) payload = 1; // Preserve NaN-ness
144
+ result = (unsigned short)((sign << 15) | 0x7C00 | payload);
145
+ }
146
+ else if (exponent <= 102) {
147
+ // Below or at f16 denormal threshold
148
+ // exp=102 with mant=0 is exactly 2^-25 (tie point, rounds to 0 per round-to-even)
149
+ // exp=102 with mant>0 is above tie point (rounds to smallest denormal 0x0001)
150
+ if (exponent == 102 && mantissa > 0) result = (unsigned short)((sign << 15) | 0x0001);
151
+ else result = (unsigned short)(sign << 15);
152
+ }
153
+ else if (exponent < 113) {
154
+ // F16 denormal range (exp 103-112) with IEEE 754 round-to-nearest-even
155
+ unsigned int shift = 113 - exponent;
156
+ unsigned int shift_amount = shift + 13;
157
+ unsigned long long full_mant = 0x00800000ULL | mantissa;
158
+
159
+ // Extract result before rounding
160
+ unsigned int mant = (unsigned int)(full_mant >> shift_amount);
161
+
162
+ // IEEE 754 round-to-nearest-even: round up if round_bit is set AND
163
+ // (sticky_bits are nonzero OR result is odd)
164
+ unsigned int round_bit = (full_mant >> (shift_amount - 1)) & 1;
165
+ unsigned long long sticky_bits = full_mant & ((1ULL << (shift_amount - 1)) - 1);
166
+
167
+ if (round_bit && (sticky_bits || (mant & 1))) mant++;
168
+
169
+ result = (unsigned short)((sign << 15) | mant);
170
+ }
171
+ else if (exponent < 143) {
172
+ // Normal f16 range with IEEE 754 round-to-nearest-even
173
+ unsigned int f16_exp = exponent - 112;
174
+ unsigned int f16_mant = mantissa >> 13;
175
+
176
+ // IEEE 754 rounding: check round bit (bit 12) and sticky bits (bits 0-11)
177
+ unsigned int round_bit = (mantissa >> 12) & 1;
178
+ unsigned int sticky_bits = mantissa & 0xFFF;
179
+
180
+ if (round_bit && (sticky_bits || (f16_mant & 1))) {
181
+ f16_mant++;
182
+ if (f16_mant > 0x3FF) f16_mant = 0, f16_exp++;
183
+ }
184
+
185
+ if (f16_exp > 30) result = (unsigned short)((sign << 15) | 0x7C00);
186
+ else result = (unsigned short)((sign << 15) | (f16_exp << 10) | f16_mant);
187
+ }
188
+ else {
189
+ // Overflow → infinity
190
+ result = (unsigned short)((sign << 15) | 0x7C00);
191
+ }
192
+
193
+ nk_copy_bytes_(dest, &result, 2);
194
+ #endif
195
+ }
196
+
197
+ /**
198
+ * @brief For compilers that don't natively support the `__bf16` type,
199
+ * upcasts contents into a more conventional `float`.
200
+ *
201
+ * https://stackoverflow.com/questions/55253233/convert-fp32-to-bfloat16-in-c/55254307#55254307
202
+ * https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus
203
+ */
204
+ NK_PUBLIC void nk_bf16_to_f32_serial(nk_bf16_t const *src, nk_f32_t *dest) {
205
+ #if NK_NATIVE_BF16
206
+ *dest = (nk_f32_t)(*src);
207
+ #else
208
+ unsigned short x;
209
+ nk_copy_bytes_(&x, src, 2);
210
+ nk_fui32_t conv;
211
+ conv.u = x << 16; // Zero extends the mantissa
212
+ *dest = conv.f;
213
+ #endif
214
+ }
215
+
216
+ /**
217
+ * @brief Compresses a `float` to a `bf16` representation.
218
+ *
219
+ * https://stackoverflow.com/questions/55253233/convert-fp32-to-bfloat16-in-c/55254307#55254307
220
+ * https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus
221
+ */
222
+ NK_PUBLIC void nk_f32_to_bf16_serial(nk_f32_t const *src, nk_bf16_t *dest) {
223
+ #if NK_NATIVE_BF16
224
+ *dest = (nk_bf16_t)(*src);
225
+ #else
226
+ nk_fui32_t conv;
227
+ conv.f = *src;
228
+ // IEEE 754 round-to-nearest-even: add (0x7FFF + LSB)
229
+ unsigned int lsb = (conv.u >> 16) & 1;
230
+ conv.u += 0x7FFF + lsb;
231
+ conv.u >>= 16;
232
+ // Use an intermediate variable to ensure correct behavior on big-endian systems.
233
+ // Copying directly from `&conv.u` would copy the wrong bytes on big-endian,
234
+ // since the lower 16 bits are at offset 2, not offset 0.
235
+ unsigned short result = (unsigned short)conv.u;
236
+ nk_copy_bytes_(dest, &result, 2);
237
+ #endif
238
+ }
239
+
240
+ /**
241
+ * @brief Convert FP8 E4M3 to IEEE 754 single-precision float.
242
+ *
243
+ * E4M3 (FP8) format: 1 sign bit, 4 exponent bits (bias=7), 3 mantissa bits.
244
+ * Range: [-448, +448], no ∞, only two NaN encodings (0x7F, 0xFF).
245
+ * Subnormal values: (-1)ˢ × mantissa × 2⁻⁹ = mantissa / 512.
246
+ *
247
+ * Special value mappings (E4M3 → F32):
248
+ * Input E4M3 Hex F32 Hex Description
249
+ * +0 0x00 0x00000000 Positive zero
250
+ * -0 0x80 0x80000000 Negative zero
251
+ * +NaN 0x7F 0x7FC00000 Quiet NaN (exp=15, mant!=0)
252
+ * -NaN 0xFF 0xFFC00000 Quiet NaN (signed)
253
+ * +448 (max) 0x7E 0x43E00000 Max normal = 448
254
+ * -448 0xFE 0xC3E00000 Min normal = -448
255
+ * 1.0 0x38 0x3F800000 Normal (exp=7, mant=0)
256
+ * Min denorm 0x01 0x3B000000 1/512 = 2⁻⁹
257
+ * Max denorm 0x07 0x3C600000 7/512 = 7 × 2⁻⁹
258
+ *
259
+ * References:
260
+ * https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
261
+ * https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
262
+ * https://onnx.ai/onnx/technical/float8.html
263
+ */
264
+ NK_PUBLIC void nk_e4m3_to_f32_serial(nk_e4m3_t const *src, nk_f32_t *dest) {
265
+ nk_u8_t raw = *src;
266
+ nk_u32_t sign = (nk_u32_t)(raw & 0x80) << 24;
267
+ nk_u32_t exponent = (raw >> 3) & 0x0Fu;
268
+ nk_u32_t mantissa = raw & 0x07u;
269
+ nk_fui32_t conv;
270
+
271
+ if (exponent == 0) {
272
+ if (mantissa == 0) {
273
+ conv.u = sign;
274
+ *dest = conv.f;
275
+ return;
276
+ }
277
+ nk_f32_t value = (nk_f32_t)mantissa * (1.0f / 512.0f);
278
+ *dest = sign ? -value : value;
279
+ return;
280
+ }
281
+ // E4M3FN has no ∞. Only exp=15 && mant=7 is NaN.
282
+ // exp=15 && mant=0..6 are normal values (256, 288, 320, 352, 384, 416, 448).
283
+ if (exponent == 0x0Fu && mantissa == 7) {
284
+ conv.u = sign | 0x7FC00000u; // F32 quiet NaN
285
+ *dest = conv.f;
286
+ return;
287
+ }
288
+
289
+ nk_u32_t f32_exponent = (exponent + 120u) << 23;
290
+ nk_u32_t f32_mantissa = mantissa << 20;
291
+ conv.u = sign | f32_exponent | f32_mantissa;
292
+ *dest = conv.f;
293
+ }
294
+
295
+ /**
296
+ * @brief Convert IEEE 754 single-precision float to FP8 E4M3.
297
+ *
298
+ * E4M3 (FP8) format: 1 sign bit, 4 exponent bits (bias=7), 3 mantissa bits.
299
+ * Range: [-448, +448], no ∞, only two NaN encodings.
300
+ * Rounding: RNE (Round to Nearest Even) per IEEE 754 / OCP FP8 spec.
301
+ * Subnormal threshold: values with |x| < 2⁻⁶ use subnormal encoding.
302
+ *
303
+ * Special value mappings (F32 → E4M3):
304
+ * Input F32 Hex E4M3 Hex Description
305
+ * +0 0x00000000 0x00 Positive zero
306
+ * -0 0x80000000 0x80 Negative zero
307
+ * +inf 0x7F800000 0x7E Saturates to max (+448)
308
+ * -inf 0xFF800000 0xFE Saturates to min (-448)
309
+ * NaN 0x7FC00000 0x7F Quiet NaN
310
+ * 1.0 0x3F800000 0x38 Normal (exp=7, mant=0)
311
+ * 448+ >0x43E00000 0x7E Overflow → max
312
+ * 2⁻⁶ 0x3C800000 0x08 Min normal
313
+ * ≤2⁻¹⁰ ≤0x3A800000 0x00 Underflow → zero (RNE boundary, half of the 2⁻⁹ min denorm)
314
+ *
315
+ * References:
316
+ * https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
317
+ * https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
318
+ * https://onnx.ai/onnx/technical/float8.html
319
+ */
320
+ NK_PUBLIC void nk_f32_to_e4m3_serial(nk_f32_t const *src, nk_e4m3_t *dest) {
321
+ nk_f32_t x = *src;
322
+ nk_fui32_t conv;
323
+ conv.f = x;
324
+ nk_u32_t sign_bit = conv.u >> 31;
325
+ nk_u32_t abs_bits = conv.u & 0x7FFFFFFFu;
326
+ nk_u8_t sign = (nk_u8_t)(sign_bit << 7);
327
+
328
+ // NaN → E4M3FN NaN (0x7F or 0xFF)
329
+ if (abs_bits > 0x7F800000u) {
330
+ *dest = (nk_e4m3_t)(sign | 0x7Fu);
331
+ return;
332
+ }
333
+ // Infinity → saturate to max (0x7E or 0xFE), E4M3FN has no ∞
334
+ if (abs_bits == 0x7F800000u) {
335
+ *dest = (nk_e4m3_t)(sign | 0x7Eu);
336
+ return;
337
+ }
338
+
339
+ if (abs_bits == 0) {
340
+ *dest = (nk_e4m3_t)sign;
341
+ return;
342
+ }
343
+
344
+ nk_f32_t abs_x = sign_bit ? -x : x;
345
+
346
+ // Subnormal range: [0, 1/64). Use RNE rounding via scaled * 512.
347
+ // The RNE boundary between 0 and 1/512 is at 0.5/512, not 1/512.
348
+ if (abs_x < (1.0f / 64.0f)) {
349
+ nk_f32_t scaled = abs_x * 512.0f;
350
+ nk_i32_t mant = (nk_i32_t)scaled;
351
+ nk_f32_t frac = scaled - (nk_f32_t)mant;
352
+ if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
353
+ // If rounds to 8, promote to first normal (exp_field=1, mantissa=0)
354
+ if (mant > 7) {
355
+ *dest = (nk_e4m3_t)(sign | 0x08u);
356
+ return;
357
+ }
358
+ if (mant == 0) { *dest = (nk_e4m3_t)sign; }
359
+ else { *dest = (nk_e4m3_t)(sign | (nk_u8_t)mant); }
360
+ return;
361
+ }
362
+
363
+ nk_i32_t exp = (nk_i32_t)((abs_bits >> 23) & 0xFFu) - 127;
364
+ nk_u32_t mantissa = abs_bits & 0x7FFFFFu;
365
+ nk_u32_t significand = (1u << 23) | mantissa;
366
+ nk_i32_t shift = 23 - 3;
367
+ nk_u32_t remainder_mask = (1u << shift) - 1;
368
+ nk_u32_t remainder = significand & remainder_mask;
369
+ nk_u32_t halfway = 1u << (shift - 1);
370
+ nk_u32_t significand_rounded = significand >> shift;
371
+ if (remainder > halfway || (remainder == halfway && (significand_rounded & 1))) { ++significand_rounded; }
372
+ if (significand_rounded == (1u << (3 + 1))) {
373
+ significand_rounded >>= 1;
374
+ ++exp;
375
+ }
376
+ if (exp > 8) {
377
+ // Saturate to max value 448 = 0x7E (exp=15, mantissa=6). Note: 0x7F is NaN in e4m3FN.
378
+ *dest = (nk_e4m3_t)(sign | 0x7Eu);
379
+ return;
380
+ }
381
+ if (exp < -6) {
382
+ nk_f32_t scaled = abs_x * 512.0f;
383
+ nk_i32_t mant = (nk_i32_t)scaled;
384
+ nk_f32_t frac = scaled - (nk_f32_t)mant;
385
+ if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
386
+ // If rounds to 8, promote to first normal (exp_field=1, mantissa=0)
387
+ if (mant > 7) {
388
+ *dest = (nk_e4m3_t)(sign | 0x08u);
389
+ return;
390
+ }
391
+ if (mant == 0) { *dest = (nk_e4m3_t)sign; }
392
+ else { *dest = (nk_e4m3_t)(sign | (nk_u8_t)mant); }
393
+ return;
394
+ }
395
+
396
+ nk_u8_t exp_field = (nk_u8_t)(exp + 7);
397
+ nk_u8_t mant_field = (nk_u8_t)(significand_rounded & 0x07u);
398
+ // For exp_field=15, clamp mantissa to 6 to avoid NaN encoding (0x7F in e4m3FN)
399
+ if (exp_field == 15 && mant_field > 6) { mant_field = 6; }
400
+ *dest = (nk_e4m3_t)(sign | (exp_field << 3) | mant_field);
401
+ }
402
+
403
+ /**
404
+ * @brief Convert FP8 E4M3 to IEEE 754 half-precision float.
405
+ *
406
+ * E4M3 format: 1 sign bit, 4 exponent bits (bias=7), 3 mantissa bits.
407
+ * F16 format: 1 sign bit, 5 exponent bits (bias=15), 10 mantissa bits.
408
+ *
409
+ * Conversion notes:
410
+ * - Normal values: F16_exp = E4M3_exp + 8, mantissa shifted left by 7 bits
411
+ * - Subnormals: mant × 2⁻⁹ (where 2⁻⁹ = 0x1800 in F16)
412
+ * - NaN (0x7F): maps to F16 quiet NaN (0x7E00)
413
+ */
414
+ NK_INTERNAL void nk_e4m3_to_f16_serial(nk_e4m3_t const *src, nk_f16_t *dest) {
415
+ nk_u8_t raw = *src;
416
+ nk_u16_t sign = ((nk_u16_t)(raw & 0x80)) << 8;
417
+ nk_u16_t mag = raw & 0x7F;
418
+ nk_u16_t mant = raw & 0x07;
419
+ nk_u16_t exp = (raw >> 3) & 0x0F;
420
+ nk_fui16_t result;
421
+
422
+ if (mag == 0x7F) {
423
+ result.u = sign | 0x7E00; // NaN
424
+ }
425
+ else if (exp == 0) {
426
+ // Subnormal: mant × 2⁻⁹, where 2⁻⁹ = 0x1800 in F16
427
+ nk_fui16_t scale;
428
+ scale.u = 0x1800;
429
+ nk_fui16_t mant_f16;
430
+ mant_f16.f = (nk_f16_t)mant;
431
+ result.f = mant_f16.f * scale.f;
432
+ result.u |= sign;
433
+ }
434
+ else {
435
+ // Normal: F16 = sign | ((mag << 7) + 0x2000)
436
+ result.u = sign | ((mag << 7) + 0x2000);
437
+ }
438
+ *dest = result.f;
439
+ }
440
+
441
+ /**
442
+ * @brief Convert FP8 E5M2 to IEEE 754 single-precision float.
443
+ *
444
+ * E5M2 (FP8) format: 1 sign bit, 5 exponent bits (bias=15), 2 mantissa bits.
445
+ * Range: [-57344, +57344], supports infinity and NaN (IEEE 754 compatible).
446
+ * Subnormal values: (-1)ˢ × mantissa × 2⁻¹⁶ = mantissa / 65536.
447
+ *
448
+ * Special value mappings (E5M2 → F32):
449
+ * Input E5M2 Hex F32 Hex Description
450
+ * +0 0x00 0x00000000 Positive zero
451
+ * -0 0x80 0x80000000 Negative zero
452
+ * +inf 0x7C 0x7F800000 Positive infinity
453
+ * -inf 0xFC 0xFF800000 Negative infinity
454
+ * +NaN 0x7D-7F 0x7FC00000 Quiet NaN (exp=31, mant!=0)
455
+ * -NaN 0xFD-FF 0xFFC00000 Quiet NaN (signed)
456
+ * +57344 (max) 0x7B 0x47600000 Max normal
457
+ * 1.0 0x3C 0x3F800000 Normal (exp=15, mant=0)
458
+ * Min denorm 0x01 0x37800000 1/65536 = 2⁻¹⁶
459
+ * Max denorm 0x03 0x38000000 3/65536 = 3 × 2⁻¹⁶
460
+ *
461
+ * References:
462
+ * https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
463
+ * https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
464
+ * https://onnx.ai/onnx/technical/float8.html
465
+ */
466
+ NK_INTERNAL void nk_e5m2_to_f32_manual_(nk_e5m2_t const *src, nk_f32_t *dest) {
467
+ nk_u8_t raw = *src;
468
+ nk_u32_t sign = (nk_u32_t)(raw & 0x80) << 24;
469
+ nk_u32_t exponent = (raw >> 2) & 0x1Fu;
470
+ nk_u32_t mantissa = raw & 0x03u;
471
+ nk_fui32_t conv;
472
+
473
+ if (exponent == 0) {
474
+ if (mantissa == 0) {
475
+ conv.u = sign;
476
+ *dest = conv.f;
477
+ return;
478
+ }
479
+ nk_f32_t value = (nk_f32_t)mantissa * (1.0f / 65536.0f);
480
+ *dest = sign ? -value : value;
481
+ return;
482
+ }
483
+ if (exponent == 0x1Fu) {
484
+ if (mantissa == 0) { conv.u = sign | 0x7F800000u; }
485
+ else { conv.u = sign | 0x7FC00000u; }
486
+ *dest = conv.f;
487
+ return;
488
+ }
489
+
490
+ nk_u32_t f32_exponent = (exponent + 112u) << 23;
491
+ nk_u32_t f32_mantissa = mantissa << 21;
492
+ conv.u = sign | f32_exponent | f32_mantissa;
493
+ *dest = conv.f;
494
+ }
495
+
496
+ NK_PUBLIC void nk_e5m2_to_f32_serial(nk_e5m2_t const *src, nk_f32_t *dest) {
497
+ static nk_u32_t const lut[128] = {
498
+ 0x00000000, 0x37800000, 0x38000000, 0x38400000, // exp=0 sub
499
+ 0x38800000, 0x38A00000, 0x38C00000, 0x38E00000, // exp=1
500
+ 0x39000000, 0x39200000, 0x39400000, 0x39600000, // exp=2
501
+ 0x39800000, 0x39A00000, 0x39C00000, 0x39E00000, // exp=3
502
+ 0x3A000000, 0x3A200000, 0x3A400000, 0x3A600000, // exp=4
503
+ 0x3A800000, 0x3AA00000, 0x3AC00000, 0x3AE00000, // exp=5
504
+ 0x3B000000, 0x3B200000, 0x3B400000, 0x3B600000, // exp=6
505
+ 0x3B800000, 0x3BA00000, 0x3BC00000, 0x3BE00000, // exp=7
506
+ 0x3C000000, 0x3C200000, 0x3C400000, 0x3C600000, // exp=8
507
+ 0x3C800000, 0x3CA00000, 0x3CC00000, 0x3CE00000, // exp=9
508
+ 0x3D000000, 0x3D200000, 0x3D400000, 0x3D600000, // exp=10
509
+ 0x3D800000, 0x3DA00000, 0x3DC00000, 0x3DE00000, // exp=11
510
+ 0x3E000000, 0x3E200000, 0x3E400000, 0x3E600000, // exp=12
511
+ 0x3E800000, 0x3EA00000, 0x3EC00000, 0x3EE00000, // exp=13
512
+ 0x3F000000, 0x3F200000, 0x3F400000, 0x3F600000, // exp=14
513
+ 0x3F800000, 0x3FA00000, 0x3FC00000, 0x3FE00000, // exp=15
514
+ 0x40000000, 0x40200000, 0x40400000, 0x40600000, // exp=16
515
+ 0x40800000, 0x40A00000, 0x40C00000, 0x40E00000, // exp=17
516
+ 0x41000000, 0x41200000, 0x41400000, 0x41600000, // exp=18
517
+ 0x41800000, 0x41A00000, 0x41C00000, 0x41E00000, // exp=19
518
+ 0x42000000, 0x42200000, 0x42400000, 0x42600000, // exp=20
519
+ 0x42800000, 0x42A00000, 0x42C00000, 0x42E00000, // exp=21
520
+ 0x43000000, 0x43200000, 0x43400000, 0x43600000, // exp=22
521
+ 0x43800000, 0x43A00000, 0x43C00000, 0x43E00000, // exp=23
522
+ 0x44000000, 0x44200000, 0x44400000, 0x44600000, // exp=24
523
+ 0x44800000, 0x44A00000, 0x44C00000, 0x44E00000, // exp=25
524
+ 0x45000000, 0x45200000, 0x45400000, 0x45600000, // exp=26
525
+ 0x45800000, 0x45A00000, 0x45C00000, 0x45E00000, // exp=27
526
+ 0x46000000, 0x46200000, 0x46400000, 0x46600000, // exp=28
527
+ 0x46800000, 0x46A00000, 0x46C00000, 0x46E00000, // exp=29
528
+ 0x47000000, 0x47200000, 0x47400000, 0x47600000, // exp=30
529
+ 0x7F800000, 0x7FC00000, 0x7FC00000, 0x7FC00000, // inf, nan
530
+ };
531
+ nk_u8_t raw = *src;
532
+ nk_u32_t sign = (nk_u32_t)(raw & 0x80) << 24;
533
+ nk_fui32_t conv;
534
+ conv.u = sign | lut[raw & 0x7F];
535
+ *dest = conv.f;
536
+ }
537
+
538
+ /**
539
+ * @brief Convert IEEE 754 single-precision float to FP8 E5M2.
540
+ *
541
+ * E5M2 (FP8) format: 1 sign bit, 5 exponent bits (bias=15), 2 mantissa bits.
542
+ * Range: [-57344, +57344], supports infinity and NaN (IEEE 754 compatible).
543
+ * Rounding: RNE (Round to Nearest Even) per IEEE 754 / OCP FP8 spec.
544
+ * Subnormal threshold: values with |x| < 2⁻¹⁴ use subnormal encoding.
545
+ *
546
+ * Special value mappings (F32 → E5M2):
547
+ * Input F32 Hex E5M2 Hex Description
548
+ * +0 0x00000000 0x00 Positive zero
549
+ * -0 0x80000000 0x80 Negative zero
550
+ * +inf 0x7F800000 0x7C Positive infinity
551
+ * -inf 0xFF800000 0xFC Negative infinity
552
+ * NaN 0x7FC00000 0x7D Quiet NaN
553
+ * 1.0 0x3F800000 0x3C Normal (exp=15, mant=0)
554
+ * ≥61440 ≥0x47700000 0x7C Overflow → infinity (values in (57344, 61440) round down to 0x7B)
555
+ * 2⁻¹⁴ 0x38800000 0x04 Min normal
556
+ * ≤2⁻¹⁷ ≤0x37000000 0x00 Underflow → zero (RNE tie at 2⁻¹⁷ rounds to even)
557
+ *
558
+ * References:
559
+ * https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
560
+ * https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
561
+ * https://onnx.ai/onnx/technical/float8.html
562
+ */
563
+ NK_PUBLIC void nk_f32_to_e5m2_serial(nk_f32_t const *src, nk_e5m2_t *dest) {
564
+ nk_f32_t x = *src;
565
+ nk_fui32_t conv;
566
+ conv.f = x;
567
+ nk_u32_t sign_bit = conv.u >> 31;
568
+ nk_u32_t abs_bits = conv.u & 0x7FFFFFFFu;
569
+ nk_u8_t sign = (nk_u8_t)(sign_bit << 7);
570
+
571
+ if (abs_bits >= 0x7F800000u) {
572
+ nk_u8_t mant = (abs_bits > 0x7F800000u) ? 0x01u : 0x00u;
573
+ *dest = (nk_e5m2_t)(sign | 0x7Cu | mant);
574
+ return;
575
+ }
576
+
577
+ if (abs_bits == 0) {
578
+ *dest = (nk_e5m2_t)sign;
579
+ return;
580
+ }
581
+
582
+ nk_f32_t abs_x = sign_bit ? -x : x;
583
+
584
+ // Subnormal range: [0, 1/16384). Use RNE rounding via scaled * 65536.
585
+ // The RNE boundary between 0 and 1/65536 is at 0.5/65536, not 1/65536.
586
+ if (abs_x < (1.0f / 16384.0f)) {
587
+ nk_f32_t scaled = abs_x * 65536.0f;
588
+ nk_i32_t mant = (nk_i32_t)scaled;
589
+ nk_f32_t frac = scaled - (nk_f32_t)mant;
590
+ if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
591
+ // If rounds to 4, promote to first normal (exp_field=1, mantissa=0)
592
+ if (mant > 3) {
593
+ *dest = (nk_e5m2_t)(sign | 0x04u);
594
+ return;
595
+ }
596
+ if (mant == 0) { *dest = (nk_e5m2_t)sign; }
597
+ else { *dest = (nk_e5m2_t)(sign | (nk_u8_t)mant); }
598
+ return;
599
+ }
600
+
601
+ nk_i32_t exp = (nk_i32_t)((abs_bits >> 23) & 0xFFu) - 127;
602
+ nk_u32_t mantissa = abs_bits & 0x7FFFFFu;
603
+ nk_u32_t significand = (1u << 23) | mantissa;
604
+ nk_i32_t shift = 23 - 2;
605
+ nk_u32_t remainder_mask = (1u << shift) - 1;
606
+ nk_u32_t remainder = significand & remainder_mask;
607
+ nk_u32_t halfway = 1u << (shift - 1);
608
+ nk_u32_t significand_rounded = significand >> shift;
609
+ if (remainder > halfway || (remainder == halfway && (significand_rounded & 1))) { ++significand_rounded; }
610
+ if (significand_rounded == (1u << (2 + 1))) {
611
+ significand_rounded >>= 1;
612
+ ++exp;
613
+ }
614
+ if (exp > 15) {
615
+ *dest = (nk_e5m2_t)(sign | 0x7Cu);
616
+ return;
617
+ }
618
+ if (exp < -14) {
619
+ nk_f32_t scaled = abs_x * 65536.0f;
620
+ nk_i32_t mant = (nk_i32_t)scaled;
621
+ nk_f32_t frac = scaled - (nk_f32_t)mant;
622
+ if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
623
+ // If rounds to 4, promote to first normal (exp_field=1, mantissa=0)
624
+ if (mant > 3) {
625
+ *dest = (nk_e5m2_t)(sign | 0x04u);
626
+ return;
627
+ }
628
+ if (mant == 0) { *dest = (nk_e5m2_t)sign; }
629
+ else { *dest = (nk_e5m2_t)(sign | (nk_u8_t)mant); }
630
+ return;
631
+ }
632
+
633
+ nk_u8_t exp_field = (nk_u8_t)(exp + 15);
634
+ nk_u8_t mant_field = (nk_u8_t)(significand_rounded & 0x03u);
635
+ *dest = (nk_e5m2_t)(sign | (exp_field << 2) | mant_field);
636
+ }
637
+
638
+ /**
639
+ * @brief Convert FP8 E5M2 to IEEE 754 half-precision float.
640
+ *
641
+ * E5M2 format: 1 sign bit, 5 exponent bits (bias=15), 2 mantissa bits.
642
+ * F16 format: 1 sign bit, 5 exponent bits (bias=15), 10 mantissa bits.
643
+ *
644
+ * Since E5M2 and F16 share the same exponent bias (15), normal values
645
+ * convert by simply shifting the magnitude left by 8 bits.
646
+ *
647
+ * Conversion notes:
648
+ * - Normal values: F16 = sign | (mag << 8)
649
+ * - Subnormals: mant × 2⁻¹⁶ (where 2⁻¹⁶ = 0x0100 in F16)
650
+ * - Infinity (0x7C): maps to F16 infinity (0x7C00)
651
+ * - NaN (0x7D-0x7F): maps to F16 quiet NaN (0x7E00)
652
+ */
653
+ NK_INTERNAL void nk_e5m2_to_f16_manual_(nk_e5m2_t const *src, nk_f16_t *dest) {
654
+ nk_u8_t raw = *src;
655
+ nk_u16_t sign = ((nk_u16_t)(raw & 0x80)) << 8;
656
+ nk_u16_t mag = raw & 0x7F;
657
+ nk_u16_t mant = raw & 0x03;
658
+ nk_u16_t exp = (raw >> 2) & 0x1F;
659
+ nk_fui16_t result;
660
+
661
+ if (exp == 0) {
662
+ if (mant == 0) {
663
+ result.u = sign; // Zero
664
+ }
665
+ else {
666
+ // Subnormal: mant × 2⁻¹⁶, where 2⁻¹⁶ = 0x0100 in F16
667
+ nk_fui16_t scale;
668
+ scale.u = 0x0100;
669
+ nk_fui16_t mant_f16;
670
+ mant_f16.f = (nk_f16_t)mant;
671
+ result.f = mant_f16.f * scale.f;
672
+ result.u |= sign;
673
+ }
674
+ }
675
+ else if (mag == 0x7C) {
676
+ result.u = sign | 0x7C00; // Infinity
677
+ }
678
+ else if (mag > 0x7C) {
679
+ result.u = sign | 0x7E00; // NaN
680
+ }
681
+ else {
682
+ // Normal: E5M2 and F16 have same bias (15), just shift magnitude
683
+ result.u = sign | ((nk_u16_t)mag << 8);
684
+ }
685
+ *dest = result.f;
686
+ }
687
+
688
+ NK_INTERNAL void nk_e5m2_to_f16_serial(nk_e5m2_t const *src, nk_f16_t *dest) {
689
+ static nk_u16_t const lut[128] = {
690
+ 0x0000, 0x0100, 0x0200, 0x0300, // exp=0 sub
691
+ 0x0400, 0x0500, 0x0600, 0x0700, // exp=1
692
+ 0x0800, 0x0900, 0x0A00, 0x0B00, // exp=2
693
+ 0x0C00, 0x0D00, 0x0E00, 0x0F00, // exp=3
694
+ 0x1000, 0x1100, 0x1200, 0x1300, // exp=4
695
+ 0x1400, 0x1500, 0x1600, 0x1700, // exp=5
696
+ 0x1800, 0x1900, 0x1A00, 0x1B00, // exp=6
697
+ 0x1C00, 0x1D00, 0x1E00, 0x1F00, // exp=7
698
+ 0x2000, 0x2100, 0x2200, 0x2300, // exp=8
699
+ 0x2400, 0x2500, 0x2600, 0x2700, // exp=9
700
+ 0x2800, 0x2900, 0x2A00, 0x2B00, // exp=10
701
+ 0x2C00, 0x2D00, 0x2E00, 0x2F00, // exp=11
702
+ 0x3000, 0x3100, 0x3200, 0x3300, // exp=12
703
+ 0x3400, 0x3500, 0x3600, 0x3700, // exp=13
704
+ 0x3800, 0x3900, 0x3A00, 0x3B00, // exp=14
705
+ 0x3C00, 0x3D00, 0x3E00, 0x3F00, // exp=15
706
+ 0x4000, 0x4100, 0x4200, 0x4300, // exp=16
707
+ 0x4400, 0x4500, 0x4600, 0x4700, // exp=17
708
+ 0x4800, 0x4900, 0x4A00, 0x4B00, // exp=18
709
+ 0x4C00, 0x4D00, 0x4E00, 0x4F00, // exp=19
710
+ 0x5000, 0x5100, 0x5200, 0x5300, // exp=20
711
+ 0x5400, 0x5500, 0x5600, 0x5700, // exp=21
712
+ 0x5800, 0x5900, 0x5A00, 0x5B00, // exp=22
713
+ 0x5C00, 0x5D00, 0x5E00, 0x5F00, // exp=23
714
+ 0x6000, 0x6100, 0x6200, 0x6300, // exp=24
715
+ 0x6400, 0x6500, 0x6600, 0x6700, // exp=25
716
+ 0x6800, 0x6900, 0x6A00, 0x6B00, // exp=26
717
+ 0x6C00, 0x6D00, 0x6E00, 0x6F00, // exp=27
718
+ 0x7000, 0x7100, 0x7200, 0x7300, // exp=28
719
+ 0x7400, 0x7500, 0x7600, 0x7700, // exp=29
720
+ 0x7800, 0x7900, 0x7A00, 0x7B00, // exp=30
721
+ 0x7C00, 0x7E00, 0x7E00, 0x7E00, // inf, nan
722
+ };
723
+ nk_u8_t raw = *src;
724
+ nk_u16_t sign = ((nk_u16_t)(raw & 0x80)) << 8;
725
+ nk_fui16_t result;
726
+ result.u = sign | lut[raw & 0x7F];
727
+ *dest = result.f;
728
+ }
729
+
730
+ /**
731
+ * @brief Convert FP6 E2M3FN to IEEE 754 single-precision float.
732
+ *
733
+ * E2M3FN (FP6) format: 1 sign bit, 2 exponent bits (bias=1), 3 mantissa bits.
734
+ * Range: [-7.5, +7.5], no infinity or NaN (OCP Microscaling FN format).
735
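+ * The `_manual_` variant decodes arithmetically; the `_serial` variant uses a precomputed
+ * 32-entry lookup table indexed by the 5-bit magnitude and applies the sign bit separately.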
736
+ *
737
+ * References:
738
+ * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
739
+ * https://arxiv.org/abs/2401.14112 (FP6-LLM)
740
+ */
741
+ NK_INTERNAL void nk_e2m3_to_f32_manual_(nk_e2m3_t const *src, nk_f32_t *dest) {
742
+ nk_u8_t raw = *src;
743
+ nk_u32_t sign = (nk_u32_t)((raw >> 5) & 0x01u) << 31;
744
+ nk_u32_t exponent = (raw >> 3) & 0x03u;
745
+ nk_u32_t mantissa = raw & 0x07u;
746
+ nk_fui32_t conv;
747
+
748
+ // Handle zero
749
+ if (exponent == 0 && mantissa == 0) {
750
+ conv.u = sign;
751
+ *dest = conv.f;
752
+ return;
753
+ }
754
+
755
+ // Handle subnormal (exp=0, mant!=0)
756
+ if (exponent == 0) {
757
+ // Subnormal: value = 2^(1-bias) * (mantissa / 2^p) = 2^0 * (mantissa / 8) = mantissa / 8
758
+ nk_f32_t value = (nk_f32_t)mantissa * (1.0f / 8.0f);
759
+ *dest = sign ? -value : value;
760
+ return;
761
+ }
762
+
763
+ // Normal values: rebias from E2M3 (bias=1) to F32 (bias=127)
764
+ // E2M3 exp range: 1-3 (unbiased: 0-2)
765
+ // F32 needs: (e2m3_exp - 1) + 127 = e2m3_exp + 126
766
+ nk_u32_t f32_exponent = (exponent + 126u) << 23;
767
+ nk_u32_t f32_mantissa = mantissa << 20;
768
+ conv.u = sign | f32_exponent | f32_mantissa;
769
+ *dest = conv.f;
770
+ }
771
+
772
+ NK_PUBLIC void nk_e2m3_to_f32_serial(nk_e2m3_t const *src, nk_f32_t *dest) {
773
+ static nk_u32_t const lut[32] = {
774
+ 0x00000000, 0x3E000000, 0x3E800000, 0x3EC00000, 0x3F000000, 0x3F200000, 0x3F400000, 0x3F600000, // exp=0 sub
775
+ 0x3F800000, 0x3F900000, 0x3FA00000, 0x3FB00000, 0x3FC00000, 0x3FD00000, 0x3FE00000, 0x3FF00000, // exp=1
776
+ 0x40000000, 0x40100000, 0x40200000, 0x40300000, 0x40400000, 0x40500000, 0x40600000, 0x40700000, // exp=2
777
+ 0x40800000, 0x40900000, 0x40A00000, 0x40B00000, 0x40C00000, 0x40D00000, 0x40E00000, 0x40F00000, // exp=3
778
+ };
779
+ nk_u8_t raw = *src;
780
+ nk_u32_t sign = (nk_u32_t)((raw >> 5) & 0x01u) << 31;
781
+ nk_fui32_t conv;
782
+ conv.u = sign | lut[raw & 0x1F];
783
+ *dest = conv.f;
784
+ }
785
+
786
+ /**
787
+ * @brief Convert IEEE 754 single-precision float to FP6 E2M3FN.
788
+ *
789
+ * E2M3FN (FP6) format: 1 sign bit, 2 exponent bits (bias=1), 3 mantissa bits.
790
+ * Range: [-7.5, +7.5], no ∞ or NaN. Saturates to max on overflow.
791
+ * Rounding: RNE (Round to Nearest Even) per IEEE 754.
792
+ * Subnormal threshold: values with |x| < 1.0 use subnormal encoding (smallest normal is 2⁰ = 1.0).
793
+ */
794
+ NK_PUBLIC void nk_f32_to_e2m3_serial(nk_f32_t const *src, nk_e2m3_t *dest) {
795
+ nk_f32_t x = *src;
796
+ nk_fui32_t conv;
797
+ conv.f = x;
798
+ nk_u32_t sign_bit = conv.u >> 31;
799
+ nk_u32_t abs_bits = conv.u & 0x7FFFFFFFu;
800
+ nk_u8_t sign = (nk_u8_t)(sign_bit << 5);
801
+
802
+ // Zero
803
+ if (abs_bits == 0) {
804
+ *dest = (nk_e2m3_t)sign;
805
+ return;
806
+ }
807
+
808
+ nk_f32_t abs_x = sign_bit ? -x : x;
809
+
810
+ // Clamp to E2M3FN range [-7.5, 7.5]
811
+ // Max value: exp=3, mant=7 → (1 + 7/8) * 2^(3-1) = 1.875 * 4 = 7.5
812
+ if (abs_x >= 7.5f) {
813
+ *dest = (nk_e2m3_t)(sign | 0x1Fu); // Max: 0b011111
814
+ return;
815
+ }
816
+
817
+ // Subnormal range: [0, 1.0). exp=0, mant encodes value/0.125
818
+ if (abs_x < 1.0f) {
819
+ nk_f32_t scaled = abs_x * 8.0f; // Scale to mantissa range [0, 8)
820
+ nk_i32_t mant = (nk_i32_t)scaled;
821
+ nk_f32_t frac = scaled - (nk_f32_t)mant;
822
+ // RNE rounding
823
+ if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
824
+ // If rounds to 8, promote to first normal (exp=1, mant=0)
825
+ if (mant > 7) {
826
+ *dest = (nk_e2m3_t)(sign | 0x08u);
827
+ return;
828
+ }
829
+ *dest = (nk_e2m3_t)(sign | (nk_u8_t)mant);
830
+ return;
831
+ }
832
+
833
+ // Normal range: extract exponent and mantissa
834
+ nk_i32_t exp = (nk_i32_t)((abs_bits >> 23) & 0xFFu) - 127;
835
+ nk_u32_t mantissa = abs_bits & 0x7FFFFFu;
836
+ nk_u32_t significand = (1u << 23) | mantissa;
837
+
838
+ // Round mantissa from 23 to 3 bits
839
+ nk_i32_t shift = 23 - 3;
840
+ nk_u32_t remainder_mask = (1u << shift) - 1;
841
+ nk_u32_t remainder = significand & remainder_mask;
842
+ nk_u32_t halfway = 1u << (shift - 1);
843
+ nk_u32_t significand_rounded = significand >> shift;
844
+
845
+ // RNE rounding
846
+ if (remainder > halfway || (remainder == halfway && (significand_rounded & 1))) { ++significand_rounded; }
847
+
848
+ // Handle carry into exponent
849
+ if (significand_rounded == (1u << 4)) {
850
+ significand_rounded >>= 1;
851
+ ++exp;
852
+ }
853
+
854
+ // Rebias exponent: e2m3_exp = f32_exp + 1
855
+ nk_i32_t e2m3_exp = exp + 1;
856
+
857
+ // Clamp to valid range
858
+ if (e2m3_exp > 3) {
859
+ *dest = (nk_e2m3_t)(sign | 0x1Fu); // Max value
860
+ return;
861
+ }
862
+ if (e2m3_exp < 0) {
863
+ *dest = (nk_e2m3_t)sign; // Underflow to zero
864
+ return;
865
+ }
866
+
867
+ nk_u8_t exp_field = (nk_u8_t)e2m3_exp;
868
+ nk_u8_t mant_field = (nk_u8_t)(significand_rounded & 0x07u);
869
+ *dest = (nk_e2m3_t)(sign | (exp_field << 3) | mant_field);
870
+ }
871
+
872
+ /**
873
+ * @brief Convert FP6 E3M2FN to IEEE 754 single-precision float.
874
+ *
875
+ * E3M2FN (FP6) format: 1 sign bit, 3 exponent bits (bias=3), 2 mantissa bits.
876
+ * Range: [-28, +28], no infinity or NaN (OCP Microscaling FN format).
877
+ */
878
+ NK_INTERNAL void nk_e3m2_to_f32_manual_(nk_e3m2_t const *src, nk_f32_t *dest) {
879
+ nk_u8_t raw = *src;
880
+ nk_u32_t sign = (nk_u32_t)((raw >> 5) & 0x01u) << 31;
881
+ nk_u32_t exponent = (raw >> 2) & 0x07u;
882
+ nk_u32_t mantissa = raw & 0x03u;
883
+ nk_fui32_t conv;
884
+
885
+ // Handle zero
886
+ if (exponent == 0 && mantissa == 0) {
887
+ conv.u = sign;
888
+ *dest = conv.f;
889
+ return;
890
+ }
891
+
892
+ // Handle subnormal (exp=0, mant!=0)
893
+ if (exponent == 0) {
894
+ // Subnormal: value = 2^(-2) * (mantissa / 4)
895
+ nk_f32_t value = (nk_f32_t)mantissa * (1.0f / 16.0f); // 2^(-2) * (1/4) = 1/16
896
+ *dest = sign ? -value : value;
897
+ return;
898
+ }
899
+
900
+ // Normal values: rebias from E3M2 (bias=3) to F32 (bias=127)
901
+ // E3M2 exp range: 1-7 (unbiased: -2 to 4)
902
+ // F32 needs: (e3m2_exp - 3) + 127 = e3m2_exp + 124
903
+ nk_u32_t f32_exponent = (exponent + 124u) << 23;
904
+ nk_u32_t f32_mantissa = mantissa << 21;
905
+ conv.u = sign | f32_exponent | f32_mantissa;
906
+ *dest = conv.f;
907
+ }
908
+
909
+ NK_PUBLIC void nk_e3m2_to_f32_serial(nk_e3m2_t const *src, nk_f32_t *dest) {
910
+ static nk_u32_t const lut[32] = {
911
+ 0x00000000, 0x3D800000, 0x3E000000, 0x3E400000, // exp=0 sub
912
+ 0x3E800000, 0x3EA00000, 0x3EC00000, 0x3EE00000, // exp=1
913
+ 0x3F000000, 0x3F200000, 0x3F400000, 0x3F600000, // exp=2
914
+ 0x3F800000, 0x3FA00000, 0x3FC00000, 0x3FE00000, // exp=3
915
+ 0x40000000, 0x40200000, 0x40400000, 0x40600000, // exp=4
916
+ 0x40800000, 0x40A00000, 0x40C00000, 0x40E00000, // exp=5
917
+ 0x41000000, 0x41200000, 0x41400000, 0x41600000, // exp=6
918
+ 0x41800000, 0x41A00000, 0x41C00000, 0x41E00000, // exp=7
919
+ };
920
+ nk_u8_t raw = *src;
921
+ nk_u32_t sign = (nk_u32_t)((raw >> 5) & 0x01u) << 31;
922
+ nk_fui32_t conv;
923
+ conv.u = sign | lut[raw & 0x1F];
924
+ *dest = conv.f;
925
+ }
926
+
927
+ /**
928
+ * @brief Convert IEEE 754 single-precision float to FP6 E3M2FN.
929
+ *
930
+ * E3M2FN (FP6) format: 1 sign bit, 3 exponent bits (bias=3), 2 mantissa bits.
931
+ * Range: [-28, +28], no ∞ or NaN. Saturates to max on overflow.
932
+ * Rounding: RNE (Round to Nearest Even) per IEEE 754.
933
+ * Subnormal threshold: values with |x| < 0.25 use subnormal encoding.
934
+ */
935
+ NK_PUBLIC void nk_f32_to_e3m2_serial(nk_f32_t const *src, nk_e3m2_t *dest) {
936
+ nk_f32_t x = *src;
937
+ nk_fui32_t conv;
938
+ conv.f = x;
939
+ nk_u32_t sign_bit = conv.u >> 31;
940
+ nk_u32_t abs_bits = conv.u & 0x7FFFFFFFu;
941
+ nk_u8_t sign = (nk_u8_t)(sign_bit << 5);
942
+
943
+ // Zero
944
+ if (abs_bits == 0) {
945
+ *dest = (nk_e3m2_t)sign;
946
+ return;
947
+ }
948
+
949
+ nk_f32_t abs_x = sign_bit ? -x : x;
950
+
951
+ // Clamp to E3M2FN range [-28, 28]
952
+ // Max value: exp=7, mant=2 → (1 + 2/4) * 2^(7-3) = 1.5 * 16 = 24
953
+ // Actually max is exp=7, mant=3 → (1 + 3/4) * 2⁴ = 1.75 * 16 = 28
954
+ if (abs_x >= 28.0f) {
955
+ *dest = (nk_e3m2_t)(sign | 0x1Fu); // Max: 0b011111 (exp=7, mant=3)
956
+ return;
957
+ }
958
+
959
+ // Subnormal range: [0, 0.25). exp=0, mant encodes value/0.0625
960
+ if (abs_x < 0.25f) {
961
+ nk_f32_t scaled = abs_x * 16.0f; // Scale to mantissa range [0, 4)
962
+ nk_i32_t mant = (nk_i32_t)scaled;
963
+ nk_f32_t frac = scaled - (nk_f32_t)mant;
964
+ // RNE rounding
965
+ if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
966
+ // If rounds to 4, promote to first normal (exp=1, mant=0)
967
+ if (mant > 3) {
968
+ *dest = (nk_e3m2_t)(sign | 0x04u);
969
+ return;
970
+ }
971
+ *dest = (nk_e3m2_t)(sign | (nk_u8_t)mant);
972
+ return;
973
+ }
974
+
975
+ // Normal range: extract exponent and mantissa
976
+ nk_i32_t exp = (nk_i32_t)((abs_bits >> 23) & 0xFFu) - 127;
977
+ nk_u32_t mantissa = abs_bits & 0x7FFFFFu;
978
+ nk_u32_t significand = (1u << 23) | mantissa;
979
+
980
+ // Round mantissa from 23 to 2 bits
981
+ nk_i32_t shift = 23 - 2;
982
+ nk_u32_t remainder_mask = (1u << shift) - 1;
983
+ nk_u32_t remainder = significand & remainder_mask;
984
+ nk_u32_t halfway = 1u << (shift - 1);
985
+ nk_u32_t significand_rounded = significand >> shift;
986
+
987
+ // RNE rounding
988
+ if (remainder > halfway || (remainder == halfway && (significand_rounded & 1))) { ++significand_rounded; }
989
+
990
+ // Handle carry into exponent
991
+ if (significand_rounded == (1u << 3)) {
992
+ significand_rounded >>= 1;
993
+ ++exp;
994
+ }
995
+
996
+ // Rebias exponent: e3m2_exp = f32_exp + 3
997
+ nk_i32_t e3m2_exp = exp + 3;
998
+
999
+ // Clamp to valid range
1000
+ if (e3m2_exp > 7) {
1001
+ *dest = (nk_e3m2_t)(sign | 0x1Fu); // Max value
1002
+ return;
1003
+ }
1004
+ if (e3m2_exp < 0) {
1005
+ *dest = (nk_e3m2_t)sign; // Underflow to zero
1006
+ return;
1007
+ }
1008
+
1009
+ nk_u8_t exp_field = (nk_u8_t)e3m2_exp;
1010
+ nk_u8_t mant_field = (nk_u8_t)(significand_rounded & 0x03u);
1011
+ *dest = (nk_e3m2_t)(sign | (exp_field << 2) | mant_field);
1012
+ }
1013
+
1014
+ NK_INTERNAL void nk_f16_to_f64_serial(nk_f16_t const *x, nk_f64_t *y) {
1015
+ nk_f32_t f32;
1016
+ nk_f16_to_f32_serial(x, &f32);
1017
+ *y = (nk_f64_t)f32;
1018
+ }
1019
+ NK_INTERNAL void nk_f64_to_f16_serial(nk_f64_t const *x, nk_f16_t *y) {
1020
+ nk_f32_t f32 = (nk_f32_t)*x;
1021
+ nk_f32_to_f16_serial(&f32, y);
1022
+ }
1023
+ NK_INTERNAL void nk_bf16_to_f64_serial(nk_bf16_t const *x, nk_f64_t *y) {
1024
+ nk_f32_t f32;
1025
+ nk_bf16_to_f32_serial(x, &f32);
1026
+ *y = (nk_f64_t)f32;
1027
+ }
1028
+ NK_INTERNAL void nk_f64_to_bf16_serial(nk_f64_t const *x, nk_bf16_t *y) {
1029
+ nk_f32_t f32 = (nk_f32_t)*x;
1030
+ nk_f32_to_bf16_serial(&f32, y);
1031
+ }
1032
+
1033
+ /* Convert floating-point numbers to integers with the project-wide narrowing policy:
1034
+ * finite values are clamped and rounded to nearest, ties to even, infinities saturate,
1035
+ * and NaNs map to zero.
1036
+ */
1037
+ NK_INTERNAL nk_i64_t nk_rint_even_f64_to_i64_serial_(nk_f64_t x) {
1038
+ nk_i64_t integer = (nk_i64_t)x;
1039
+ nk_f64_t fraction = x - (nk_f64_t)integer;
1040
+ if (fraction > 0.5 || (fraction == 0.5 && (integer & 1))) ++integer;
1041
+ else if (fraction < -0.5 || (fraction == -0.5 && (integer & 1))) --integer;
1042
+ return integer;
1043
+ }
1044
+
1045
+ NK_INTERNAL nk_u64_t nk_rint_even_f64_to_u64_serial_(nk_f64_t x) {
1046
+ nk_u64_t integer = (nk_u64_t)x;
1047
+ nk_f64_t fraction = x - (nk_f64_t)integer;
1048
+ if (fraction > 0.5 || (fraction == 0.5 && (integer & 1))) ++integer;
1049
+ return integer;
1050
+ }
1051
+
1052
+ NK_INTERNAL void nk_f32_to_i8_serial(nk_f32_t const *x, nk_i8_t *y) {
1053
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1054
+ else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0f ? 127.0 : (*x < -128.0f ? -128.0 : (nk_f64_t)*x));
1055
+ }
1056
+
1057
+ NK_INTERNAL void nk_f32_to_u8_serial(nk_f32_t const *x, nk_u8_t *y) {
1058
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1059
+ else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0f ? 255.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
1060
+ }
1061
+
1062
+ NK_INTERNAL void nk_f32_to_i16_serial(nk_f32_t const *x, nk_i16_t *y) {
1063
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1064
+ else
1065
+ *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0f ? 32767.0
1066
+ : (*x < -32768.0f ? -32768.0 : (nk_f64_t)*x));
1067
+ }
1068
+
1069
+ NK_INTERNAL void nk_f32_to_u16_serial(nk_f32_t const *x, nk_u16_t *y) {
1070
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1071
+ else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0f ? 65535.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
1072
+ }
1073
+
1074
+ NK_INTERNAL void nk_f64_to_i8_serial(nk_f64_t const *x, nk_i8_t *y) {
1075
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1076
+ else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0 ? 127.0 : (*x < -128.0 ? -128.0 : *x));
1077
+ }
1078
+
1079
+ NK_INTERNAL void nk_f64_to_u8_serial(nk_f64_t const *x, nk_u8_t *y) {
1080
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1081
+ else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0 ? 255.0 : (*x < 0 ? 0.0 : *x));
1082
+ }
1083
+
1084
+ NK_INTERNAL void nk_f64_to_i16_serial(nk_f64_t const *x, nk_i16_t *y) {
1085
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1086
+ else *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0 ? 32767.0 : (*x < -32768.0 ? -32768.0 : *x));
1087
+ }
1088
+
1089
+ NK_INTERNAL void nk_f64_to_u16_serial(nk_f64_t const *x, nk_u16_t *y) {
1090
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1091
+ else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0 ? 65535.0 : (*x < 0 ? 0.0 : *x));
1092
+ }
1093
+
1094
+ NK_INTERNAL void nk_f64_to_i32_serial(nk_f64_t const *x, nk_i32_t *y) {
1095
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1096
+ else
1097
+ *y = (nk_i32_t)nk_rint_even_f64_to_i64_serial_(*x > 2147483647.0 ? 2147483647.0
1098
+ : (*x < -2147483648.0 ? -2147483648.0 : *x));
1099
+ }
1100
+
1101
+ NK_INTERNAL void nk_f64_to_u32_serial(nk_f64_t const *x, nk_u32_t *y) {
1102
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1103
+ else *y = (nk_u32_t)nk_rint_even_f64_to_u64_serial_(*x > 4294967295.0 ? 4294967295.0 : (*x < 0 ? 0.0 : *x));
1104
+ }
1105
+
1106
+ NK_INTERNAL void nk_f64_to_i64_serial(nk_f64_t const *x, nk_i64_t *y) {
1107
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1108
+ else
1109
+ *y = nk_rint_even_f64_to_i64_serial_(*x > 9223372036854775807.0
1110
+ ? 9223372036854775807.0
1111
+ : (*x < -9223372036854775808.0 ? -9223372036854775808.0 : *x));
1112
+ }
1113
+
1114
+ NK_INTERNAL void nk_f64_to_u64_serial(nk_f64_t const *x, nk_u64_t *y) {
1115
+ if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
1116
+ else
1117
+ *y = nk_rint_even_f64_to_u64_serial_(*x > 18446744073709551615.0 ? 18446744073709551615.0
1118
+ : (*x < 0 ? 0.0 : *x));
1119
+ }
1120
+
1121
+ NK_INTERNAL void nk_i64_to_i8_serial(nk_i64_t const *x, nk_i8_t *y) {
1122
+ *y = (nk_i8_t)(*x > 127ll ? 127ll : (*x < -128ll ? -128ll : *x));
1123
+ }
1124
+
1125
+ NK_INTERNAL void nk_i64_to_u8_serial(nk_i64_t const *x, nk_u8_t *y) {
1126
+ *y = (nk_u8_t)(*x > 255ll ? 255ll : (*x < 0ll ? 0ll : *x));
1127
+ }
1128
+
1129
+ NK_INTERNAL void nk_i64_to_i16_serial(nk_i64_t const *x, nk_i16_t *y) {
1130
+ *y = (nk_i16_t)(*x > 32767ll ? 32767ll : (*x < -32768ll ? -32768ll : *x));
1131
+ }
1132
+
1133
+ NK_INTERNAL void nk_i64_to_u16_serial(nk_i64_t const *x, nk_u16_t *y) {
1134
+ *y = (nk_u16_t)(*x > 65535ll ? 65535ll : (*x < 0ll ? 0ll : *x));
1135
+ }
1136
+
1137
+ NK_INTERNAL void nk_i64_to_i32_serial(nk_i64_t const *x, nk_i32_t *y) {
1138
+ *y = (nk_i32_t)(*x > 2147483647ll ? 2147483647ll : (*x < -2147483648ll ? -2147483648ll : *x));
1139
+ }
1140
+
1141
+ NK_INTERNAL void nk_i64_to_u32_serial(nk_i64_t const *x, nk_u32_t *y) {
1142
+ *y = (nk_u32_t)(*x > 4294967295ll ? 4294967295ll : (*x < 0ll ? 0ll : *x));
1143
+ }
1144
+
1145
+ NK_INTERNAL void nk_u64_to_i8_serial(nk_u64_t const *x, nk_i8_t *y) { *y = (nk_i8_t)(*x > 127ull ? 127ull : *x); }
1146
+ NK_INTERNAL void nk_u64_to_u8_serial(nk_u64_t const *x, nk_u8_t *y) { *y = (nk_u8_t)(*x > 255ull ? 255ull : *x); }
1147
+ NK_INTERNAL void nk_u64_to_i16_serial(nk_u64_t const *x, nk_i16_t *y) {
1148
+ *y = (nk_i16_t)(*x > 32767ull ? 32767ull : *x);
1149
+ }
1150
+ NK_INTERNAL void nk_u64_to_u16_serial(nk_u64_t const *x, nk_u16_t *y) {
1151
+ *y = (nk_u16_t)(*x > 65535ull ? 65535ull : *x);
1152
+ }
1153
+
1154
+ NK_INTERNAL void nk_u64_to_i32_serial(nk_u64_t const *x, nk_i32_t *y) {
1155
+ *y = (nk_i32_t)(*x > 2147483647ull ? 2147483647ull : *x);
1156
+ }
1157
+
1158
+ NK_INTERNAL void nk_u64_to_u32_serial(nk_u64_t const *x, nk_u32_t *y) {
1159
+ *y = (nk_u32_t)(*x > 4294967295ull ? 4294967295ull : *x);
1160
+ }
1161
+
1162
+ NK_PUBLIC void nk_f16_to_f64_(nk_f16_t const *src, nk_f64_t *dest) {
1163
+ nk_f32_t f32;
1164
+ nk_f16_to_f32_serial(src, &f32);
1165
+ *dest = f32;
1166
+ }
1167
+ NK_PUBLIC void nk_bf16_to_f64_(nk_bf16_t const *src, nk_f64_t *dest) {
1168
+ nk_f32_t f32;
1169
+ nk_bf16_to_f32_serial(src, &f32);
1170
+ *dest = f32;
1171
+ }
1172
+
1173
+ NK_INTERNAL void nk_u64_to_i64_serial(nk_u64_t const *x, nk_i64_t *y) {
1174
+ *y = (nk_i64_t)(*x >= 9223372036854775807ull ? 9223372036854775807ll : *x);
1175
+ }
1176
+
1177
+ NK_INTERNAL void nk_i8_to_u64_serial(nk_i8_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1178
+ NK_INTERNAL void nk_i16_to_u64_serial(nk_i16_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1179
+ NK_INTERNAL void nk_i32_to_u64_serial(nk_i32_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1180
+ NK_INTERNAL void nk_i64_to_u64_serial(nk_i64_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
1181
+
1182
+ NK_INTERNAL void nk_i64_to_f16_serial(nk_i64_t const *x, nk_f16_t *y) {
1183
+ nk_f32_t f32 = (nk_f32_t)*x;
1184
+ nk_f32_to_f16_serial(&f32, y);
1185
+ }
1186
+ NK_INTERNAL void nk_i64_to_bf16_serial(nk_i64_t const *x, nk_bf16_t *y) {
1187
+ nk_f32_t f32 = (nk_f32_t)*x;
1188
+ nk_f32_to_bf16_serial(&f32, y);
1189
+ }
1190
+ NK_INTERNAL void nk_u64_to_f16_serial(nk_u64_t const *x, nk_f16_t *y) {
1191
+ nk_f32_t f32 = (nk_f32_t)*x;
1192
+ nk_f32_to_f16_serial(&f32, y);
1193
+ }
1194
+ NK_INTERNAL void nk_u64_to_bf16_serial(nk_u64_t const *x, nk_bf16_t *y) {
1195
+ nk_f32_t f32 = (nk_f32_t)*x;
1196
+ nk_f32_to_bf16_serial(&f32, y);
1197
+ }
1198
+
1199
+ #pragma region - Type Punned Loads and Stores
1200
+
1201
+ /** @brief Type-agnostic 256-bit full load. */
1202
+ NK_INTERNAL void nk_load_b256_serial_(void const *src, nk_b256_vec_t *dst) {
1203
+ nk_u64_t const *s = (nk_u64_t const *)src;
1204
+ dst->u64s[0] = s[0], dst->u64s[1] = s[1], dst->u64s[2] = s[2], dst->u64s[3] = s[3];
1205
+ }
1206
+
1207
+ /** @brief Type-agnostic 128-bit full load. */
1208
+ NK_INTERNAL void nk_load_b128_serial_(void const *src, nk_b128_vec_t *dst) {
1209
+ nk_u64_t const *s = (nk_u64_t const *)src;
1210
+ dst->u64s[0] = s[0], dst->u64s[1] = s[1];
1211
+ }
1212
+
1213
+ /** @brief Type-agnostic 64-bit full load. */
1214
+ NK_INTERNAL void nk_load_b64_serial_(void const *src, nk_b64_vec_t *dst) { dst->u64 = *(nk_u64_t const *)src; }
1215
+
1216
+ /** @brief Type-agnostic partial load for 32-bit elements (8 elements max) into 256-bit vector. */
1217
+ NK_INTERNAL void nk_partial_load_b32x8_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1218
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1219
+ nk_u32_t const *s = (nk_u32_t const *)src;
1220
+ switch (n) {
1221
+ default:
1222
+ case 8: dst->u32s[7] = s[7]; // fallthrough
1223
+ case 7: dst->u32s[6] = s[6]; // fallthrough
1224
+ case 6: dst->u32s[5] = s[5]; // fallthrough
1225
+ case 5: dst->u32s[4] = s[4]; // fallthrough
1226
+ case 4: dst->u32s[3] = s[3]; // fallthrough
1227
+ case 3: dst->u32s[2] = s[2]; // fallthrough
1228
+ case 2: dst->u32s[1] = s[1]; // fallthrough
1229
+ case 1: dst->u32s[0] = s[0]; // fallthrough
1230
+ case 0: break;
1231
+ }
1232
+ }
1233
+
1234
+ /** @brief Type-agnostic partial load for 32-bit elements (4 elements max) into 128-bit vector. */
1235
+ NK_INTERNAL void nk_partial_load_b32x4_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1236
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1237
+ nk_u32_t const *s = (nk_u32_t const *)src;
1238
+ switch (n) {
1239
+ default:
1240
+ case 4: dst->u32s[3] = s[3]; // fallthrough
1241
+ case 3: dst->u32s[2] = s[2]; // fallthrough
1242
+ case 2: dst->u32s[1] = s[1]; // fallthrough
1243
+ case 1: dst->u32s[0] = s[0]; // fallthrough
1244
+ case 0: break;
1245
+ }
1246
+ }
1247
+
1248
+ /** @brief Type-agnostic partial load for 8-bit elements (8 elements max) into 64-bit vector. */
1249
+ NK_INTERNAL void nk_partial_load_b8x8_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
1250
+ dst->u64 = 0;
1251
+ nk_u8_t const *s = (nk_u8_t const *)src;
1252
+ switch (n) {
1253
+ default:
1254
+ case 8: dst->u8s[7] = s[7]; // fallthrough
1255
+ case 7: dst->u8s[6] = s[6]; // fallthrough
1256
+ case 6: dst->u8s[5] = s[5]; // fallthrough
1257
+ case 5: dst->u8s[4] = s[4]; // fallthrough
1258
+ case 4: dst->u8s[3] = s[3]; // fallthrough
1259
+ case 3: dst->u8s[2] = s[2]; // fallthrough
1260
+ case 2: dst->u8s[1] = s[1]; // fallthrough
1261
+ case 1: dst->u8s[0] = s[0]; // fallthrough
1262
+ case 0: break;
1263
+ }
1264
+ }
1265
+
1266
+ /** @brief Type-agnostic partial load for 8-bit elements (4 elements max) into 32-bit vector. */
1267
+ NK_INTERNAL nk_b32_vec_t nk_partial_load_b8x4_serial_(void const *src, nk_size_t n) {
1268
+ nk_b32_vec_t dst = {0};
1269
+ nk_u8_t const *s = (nk_u8_t const *)src;
1270
+ switch (n) {
1271
+ default:
1272
+ case 4: dst.u8s[3] = s[3]; // fallthrough
1273
+ case 3: dst.u8s[2] = s[2]; // fallthrough
1274
+ case 2: dst.u8s[1] = s[1]; // fallthrough
1275
+ case 1: dst.u8s[0] = s[0]; // fallthrough
1276
+ case 0: break;
1277
+ }
1278
+ return dst;
1279
+ }
1280
+
1281
+ /** @brief Partial store for 8-bit elements (up to 4) from nk_b32_vec_t. */
1282
+ NK_INTERNAL void nk_partial_store_b8x4_serial_(nk_b32_vec_t const *src, void *dst, nk_size_t n) {
1283
+ nk_u8_t *d = (nk_u8_t *)dst;
1284
+ switch (n) {
1285
+ default:
1286
+ case 4: d[3] = src->u8s[3]; // fallthrough
1287
+ case 3: d[2] = src->u8s[2]; // fallthrough
1288
+ case 2: d[1] = src->u8s[1]; // fallthrough
1289
+ case 1: d[0] = src->u8s[0]; // fallthrough
1290
+ case 0: break;
1291
+ }
1292
+ }
1293
+
1294
+ /** @brief Type-agnostic partial load for 16-bit elements (8 elements max) into 128-bit vector. */
1295
+ NK_INTERNAL void nk_partial_load_b16x8_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1296
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1297
+ nk_u16_t const *s = (nk_u16_t const *)src;
1298
+ switch (n) {
1299
+ default:
1300
+ case 8: dst->u16s[7] = s[7]; // fallthrough
1301
+ case 7: dst->u16s[6] = s[6]; // fallthrough
1302
+ case 6: dst->u16s[5] = s[5]; // fallthrough
1303
+ case 5: dst->u16s[4] = s[4]; // fallthrough
1304
+ case 4: dst->u16s[3] = s[3]; // fallthrough
1305
+ case 3: dst->u16s[2] = s[2]; // fallthrough
1306
+ case 2: dst->u16s[1] = s[1]; // fallthrough
1307
+ case 1: dst->u16s[0] = s[0]; // fallthrough
1308
+ case 0: break;
1309
+ }
1310
+ }
1311
+
1312
+ /** @brief Type-agnostic partial load for 8-bit elements (16 elements max) into 128-bit vector. */
1313
+ NK_INTERNAL void nk_partial_load_b8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1314
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1315
+ nk_u8_t const *s = (nk_u8_t const *)src;
1316
+ switch (n) {
1317
+ default:
1318
+ case 16: dst->u8s[15] = s[15]; // fallthrough
1319
+ case 15: dst->u8s[14] = s[14]; // fallthrough
1320
+ case 14: dst->u8s[13] = s[13]; // fallthrough
1321
+ case 13: dst->u8s[12] = s[12]; // fallthrough
1322
+ case 12: dst->u8s[11] = s[11]; // fallthrough
1323
+ case 11: dst->u8s[10] = s[10]; // fallthrough
1324
+ case 10: dst->u8s[9] = s[9]; // fallthrough
1325
+ case 9: dst->u8s[8] = s[8]; // fallthrough
1326
+ case 8: dst->u8s[7] = s[7]; // fallthrough
1327
+ case 7: dst->u8s[6] = s[6]; // fallthrough
1328
+ case 6: dst->u8s[5] = s[5]; // fallthrough
1329
+ case 5: dst->u8s[4] = s[4]; // fallthrough
1330
+ case 4: dst->u8s[3] = s[3]; // fallthrough
1331
+ case 3: dst->u8s[2] = s[2]; // fallthrough
1332
+ case 2: dst->u8s[1] = s[1]; // fallthrough
1333
+ case 1: dst->u8s[0] = s[0]; // fallthrough
1334
+ case 0: break;
1335
+ }
1336
+ }
1337
+
1338
+ /** @brief Type-agnostic partial load for 16-bit elements (16 elements max) into 256-bit vector. */
1339
+ NK_INTERNAL void nk_partial_load_b16x16_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1340
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1341
+ nk_u16_t const *s = (nk_u16_t const *)src;
1342
+ switch (n) {
1343
+ default:
1344
+ case 16: dst->u16s[15] = s[15]; // fallthrough
1345
+ case 15: dst->u16s[14] = s[14]; // fallthrough
1346
+ case 14: dst->u16s[13] = s[13]; // fallthrough
1347
+ case 13: dst->u16s[12] = s[12]; // fallthrough
1348
+ case 12: dst->u16s[11] = s[11]; // fallthrough
1349
+ case 11: dst->u16s[10] = s[10]; // fallthrough
1350
+ case 10: dst->u16s[9] = s[9]; // fallthrough
1351
+ case 9: dst->u16s[8] = s[8]; // fallthrough
1352
+ case 8: dst->u16s[7] = s[7]; // fallthrough
1353
+ case 7: dst->u16s[6] = s[6]; // fallthrough
1354
+ case 6: dst->u16s[5] = s[5]; // fallthrough
1355
+ case 5: dst->u16s[4] = s[4]; // fallthrough
1356
+ case 4: dst->u16s[3] = s[3]; // fallthrough
1357
+ case 3: dst->u16s[2] = s[2]; // fallthrough
1358
+ case 2: dst->u16s[1] = s[1]; // fallthrough
1359
+ case 1: dst->u16s[0] = s[0]; // fallthrough
1360
+ case 0: break;
1361
+ }
1362
+ }
1363
+
1364
+ /** @brief Partial load for 8-bit elements (32 max) into 256-bit vector (zeros in remaining slots). */
1365
+ NK_INTERNAL void nk_partial_load_b8x32_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1366
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1367
+ nk_u8_t const *s = (nk_u8_t const *)src;
1368
+ switch (n) {
1369
+ default:
1370
+ case 32: dst->u8s[31] = s[31]; // fallthrough
1371
+ case 31: dst->u8s[30] = s[30]; // fallthrough
1372
+ case 30: dst->u8s[29] = s[29]; // fallthrough
1373
+ case 29: dst->u8s[28] = s[28]; // fallthrough
1374
+ case 28: dst->u8s[27] = s[27]; // fallthrough
1375
+ case 27: dst->u8s[26] = s[26]; // fallthrough
1376
+ case 26: dst->u8s[25] = s[25]; // fallthrough
1377
+ case 25: dst->u8s[24] = s[24]; // fallthrough
1378
+ case 24: dst->u8s[23] = s[23]; // fallthrough
1379
+ case 23: dst->u8s[22] = s[22]; // fallthrough
1380
+ case 22: dst->u8s[21] = s[21]; // fallthrough
1381
+ case 21: dst->u8s[20] = s[20]; // fallthrough
1382
+ case 20: dst->u8s[19] = s[19]; // fallthrough
1383
+ case 19: dst->u8s[18] = s[18]; // fallthrough
1384
+ case 18: dst->u8s[17] = s[17]; // fallthrough
1385
+ case 17: dst->u8s[16] = s[16]; // fallthrough
1386
+ case 16: dst->u8s[15] = s[15]; // fallthrough
1387
+ case 15: dst->u8s[14] = s[14]; // fallthrough
1388
+ case 14: dst->u8s[13] = s[13]; // fallthrough
1389
+ case 13: dst->u8s[12] = s[12]; // fallthrough
1390
+ case 12: dst->u8s[11] = s[11]; // fallthrough
1391
+ case 11: dst->u8s[10] = s[10]; // fallthrough
1392
+ case 10: dst->u8s[9] = s[9]; // fallthrough
1393
+ case 9: dst->u8s[8] = s[8]; // fallthrough
1394
+ case 8: dst->u8s[7] = s[7]; // fallthrough
1395
+ case 7: dst->u8s[6] = s[6]; // fallthrough
1396
+ case 6: dst->u8s[5] = s[5]; // fallthrough
1397
+ case 5: dst->u8s[4] = s[4]; // fallthrough
1398
+ case 4: dst->u8s[3] = s[3]; // fallthrough
1399
+ case 3: dst->u8s[2] = s[2]; // fallthrough
1400
+ case 2: dst->u8s[1] = s[1]; // fallthrough
1401
+ case 1: dst->u8s[0] = s[0]; // fallthrough
1402
+ case 0: break;
1403
+ }
1404
+ }
1405
+
1406
+ /** @brief Type-agnostic partial store for 32-bit elements (8 elements max) from 256-bit vector. */
1407
+ NK_INTERNAL void nk_partial_store_b32x8_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
1408
+ nk_u32_t *d = (nk_u32_t *)dst;
1409
+ switch (n) {
1410
+ default:
1411
+ case 8: d[7] = src->u32s[7]; // fallthrough
1412
+ case 7: d[6] = src->u32s[6]; // fallthrough
1413
+ case 6: d[5] = src->u32s[5]; // fallthrough
1414
+ case 5: d[4] = src->u32s[4]; // fallthrough
1415
+ case 4: d[3] = src->u32s[3]; // fallthrough
1416
+ case 3: d[2] = src->u32s[2]; // fallthrough
1417
+ case 2: d[1] = src->u32s[1]; // fallthrough
1418
+ case 1: d[0] = src->u32s[0]; // fallthrough
1419
+ case 0: break;
1420
+ }
1421
+ }
1422
+
1423
+ /** @brief Type-agnostic partial store for 32-bit elements (4 elements max) from 128-bit vector. */
1424
+ NK_INTERNAL void nk_partial_store_b32x4_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
1425
+ nk_u32_t *d = (nk_u32_t *)dst;
1426
+ switch (n) {
1427
+ default:
1428
+ case 4: d[3] = src->u32s[3]; // fallthrough
1429
+ case 3: d[2] = src->u32s[2]; // fallthrough
1430
+ case 2: d[1] = src->u32s[1]; // fallthrough
1431
+ case 1: d[0] = src->u32s[0]; // fallthrough
1432
+ case 0: break;
1433
+ }
1434
+ }
1435
+
1436
+ /** @brief Type-agnostic partial store for 16-bit elements (8 elements max) from 128-bit vector. */
1437
+ NK_INTERNAL void nk_partial_store_b16x8_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
1438
+ nk_u16_t *d = (nk_u16_t *)dst;
1439
+ switch (n) {
1440
+ default:
1441
+ case 8: d[7] = src->u16s[7]; // fallthrough
1442
+ case 7: d[6] = src->u16s[6]; // fallthrough
1443
+ case 6: d[5] = src->u16s[5]; // fallthrough
1444
+ case 5: d[4] = src->u16s[4]; // fallthrough
1445
+ case 4: d[3] = src->u16s[3]; // fallthrough
1446
+ case 3: d[2] = src->u16s[2]; // fallthrough
1447
+ case 2: d[1] = src->u16s[1]; // fallthrough
1448
+ case 1: d[0] = src->u16s[0]; // fallthrough
1449
+ case 0: break;
1450
+ }
1451
+ }
1452
+
1453
+ /** @brief Type-agnostic partial store for 16-bit elements (4 elements max) from 64-bit vector. */
1454
+ NK_INTERNAL void nk_partial_store_b16x4_serial_(void *dst, nk_b64_vec_t const *src, nk_size_t n) {
1455
+ nk_u16_t *d = (nk_u16_t *)dst;
1456
+ switch (n) {
1457
+ default:
1458
+ case 4: d[3] = src->u16s[3]; // fallthrough
1459
+ case 3: d[2] = src->u16s[2]; // fallthrough
1460
+ case 2: d[1] = src->u16s[1]; // fallthrough
1461
+ case 1: d[0] = src->u16s[0]; // fallthrough
1462
+ case 0: break;
1463
+ }
1464
+ }
1465
+
1466
+ /** @brief Type-agnostic partial store for 8-bit elements (8 elements max) from 64-bit vector. */
1467
+ NK_INTERNAL void nk_partial_store_b8x8_serial_(nk_b64_vec_t const *src, void *dst, nk_size_t n) {
1468
+ nk_u8_t *d = (nk_u8_t *)dst;
1469
+ switch (n) {
1470
+ default:
1471
+ case 8: d[7] = src->u8s[7]; // fallthrough
1472
+ case 7: d[6] = src->u8s[6]; // fallthrough
1473
+ case 6: d[5] = src->u8s[5]; // fallthrough
1474
+ case 5: d[4] = src->u8s[4]; // fallthrough
1475
+ case 4: d[3] = src->u8s[3]; // fallthrough
1476
+ case 3: d[2] = src->u8s[2]; // fallthrough
1477
+ case 2: d[1] = src->u8s[1]; // fallthrough
1478
+ case 1: d[0] = src->u8s[0]; // fallthrough
1479
+ case 0: break;
1480
+ }
1481
+ }
1482
+
1483
+ /** @brief Type-agnostic partial load for 64-bit elements (4 elements max) into 256-bit vector. */
1484
+ NK_INTERNAL void nk_partial_load_b64x4_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1485
+ nk_u64_t const *s = (nk_u64_t const *)src;
1486
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1487
+ switch (n) {
1488
+ default:
1489
+ case 4: dst->u64s[3] = s[3]; // fallthrough
1490
+ case 3: dst->u64s[2] = s[2]; // fallthrough
1491
+ case 2: dst->u64s[1] = s[1]; // fallthrough
1492
+ case 1: dst->u64s[0] = s[0]; // fallthrough
1493
+ case 0: break;
1494
+ }
1495
+ }
1496
+
1497
+ /** @brief Type-agnostic partial store for 64-bit elements (4 elements max) from 256-bit vector. */
1498
+ NK_INTERNAL void nk_partial_store_b64x4_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
1499
+ nk_u64_t *d = (nk_u64_t *)dst;
1500
+ switch (n) {
1501
+ default:
1502
+ case 4: d[3] = src->u64s[3]; // fallthrough
1503
+ case 3: d[2] = src->u64s[2]; // fallthrough
1504
+ case 2: d[1] = src->u64s[1]; // fallthrough
1505
+ case 1: d[0] = src->u64s[0]; // fallthrough
1506
+ case 0: break;
1507
+ }
1508
+ }
1509
+
1510
+ /** @brief Type-agnostic partial load for 32-bit elements (2 elements max) into 64-bit vector. */
1511
+ NK_INTERNAL void nk_partial_load_b32x2_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
1512
+ dst->u64 = 0;
1513
+ nk_u32_t const *s = (nk_u32_t const *)src;
1514
+ switch (n) {
1515
+ default:
1516
+ case 2: dst->u32s[1] = s[1]; // fallthrough
1517
+ case 1: dst->u32s[0] = s[0]; // fallthrough
1518
+ case 0: break;
1519
+ }
1520
+ }
1521
+
1522
+ /** @brief Type-agnostic partial load for 16-bit elements (4 elements max) into 64-bit vector. */
1523
+ NK_INTERNAL void nk_partial_load_b16x4_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
1524
+ dst->u64 = 0;
1525
+ nk_u16_t const *s = (nk_u16_t const *)src;
1526
+ switch (n) {
1527
+ default:
1528
+ case 4: dst->u16s[3] = s[3]; // fallthrough
1529
+ case 3: dst->u16s[2] = s[2]; // fallthrough
1530
+ case 2: dst->u16s[1] = s[1]; // fallthrough
1531
+ case 1: dst->u16s[0] = s[0]; // fallthrough
1532
+ case 0: break;
1533
+ }
1534
+ }
1535
+
1536
+ /** @brief Partial load for 4-bit nibbles (64 max = 32 bytes) into 256-bit vector (zeros in remaining slots). */
1537
+ NK_INTERNAL void nk_partial_load_b4x64_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
1538
+ dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
1539
+ nk_u8_t const *s = (nk_u8_t const *)src;
1540
+ nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
1541
+ for (nk_size_t i = 0; i < n_bytes && i < 32; i++) dst->u8s[i] = s[i];
1542
+ }
1543
+
1544
+ /** @brief Partial load for 4-bit nibbles (32 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
1545
+ NK_INTERNAL void nk_partial_load_b4x32_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1546
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1547
+ nk_u8_t const *s = (nk_u8_t const *)src;
1548
+ nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
1549
+ for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
1550
+ }
1551
+
1552
+ /** @brief Partial load for 1-bit elements (128 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
1553
+ NK_INTERNAL void nk_partial_load_b1x128_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n_bits) {
1554
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1555
+ nk_u8_t const *s = (nk_u8_t const *)src;
1556
+ nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8);
1557
+ for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
1558
+ }
1559
+
1560
+ /** @brief Partial load for 4-bit nibbles (16 max = 8 bytes) into 64-bit vector (zeros in remaining slots). */
1561
+ NK_INTERNAL void nk_partial_load_b4x16_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
1562
+ dst->u64 = 0;
1563
+ nk_u8_t const *s = (nk_u8_t const *)src;
1564
+ nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
1565
+ for (nk_size_t i = 0; i < n_bytes && i < 8; i++) ((nk_u8_t *)&dst->u64)[i] = s[i];
1566
+ }
1567
+
1568
+ NK_INTERNAL void nk_partial_load_b64x2_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
1569
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1570
+ nk_u64_t const *s = (nk_u64_t const *)src;
1571
+ switch (n) {
1572
+ default:
1573
+ case 2: dst->u64s[1] = s[1]; // fallthrough
1574
+ case 1: dst->u64s[0] = s[0]; // fallthrough
1575
+ case 0: break;
1576
+ }
1577
+ }
1578
+
1579
+ /** @brief Type-agnostic partial store for 64-bit elements (2 elements max) from 128-bit vector. */
1580
+ NK_INTERNAL void nk_partial_store_b64x2_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
1581
+ nk_u64_t *d = (nk_u64_t *)dst;
1582
+ switch (n) {
1583
+ default:
1584
+ case 2: d[1] = src->u64s[1]; // fallthrough
1585
+ case 1: d[0] = src->u64s[0]; // fallthrough
1586
+ case 0: break;
1587
+ }
1588
+ }
1589
+
1590
+ /** @brief Strided partial load for 32-bit elements (4 max) into 128-bit vector. */
1591
+ NK_INTERNAL void nk_strided_load_b32x4_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
1592
+ nk_size_t n) {
1593
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1594
+ nk_u32_t const *s = (nk_u32_t const *)src;
1595
+ for (nk_size_t i = 0; i < n && i < 4; ++i) dst->u32s[i] = s[i * stride_elements];
1596
+ }
1597
+
1598
+ /** @brief Strided partial load for 16-bit elements (8 max) into 128-bit vector. */
1599
+ NK_INTERNAL void nk_strided_load_b16x8_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
1600
+ nk_size_t n) {
1601
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1602
+ nk_u16_t const *s = (nk_u16_t const *)src;
1603
+ for (nk_size_t i = 0; i < n && i < 8; ++i) dst->u16s[i] = s[i * stride_elements];
1604
+ }
1605
+
1606
+ /** @brief Strided partial load for 8-bit elements (16 max) into 128-bit vector. */
1607
+ NK_INTERNAL void nk_strided_load_b8x16_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
1608
+ nk_size_t n) {
1609
+ dst->u64s[0] = 0, dst->u64s[1] = 0;
1610
+ nk_u8_t const *s = (nk_u8_t const *)src;
1611
+ for (nk_size_t i = 0; i < n && i < 16; ++i) dst->u8s[i] = s[i * stride_elements];
1612
+ }
1613
+
1614
+ /**
1615
+ * @brief Union for type-punned scalar values at language binding boundaries.
1616
+ *
1617
+ * Used to bridge different type systems (Python, JavaScript, etc.) where
1618
+ * scalars arrive as f64 but need to be passed to kernels as typed pointers.
1619
+ * The caller fills the appropriate union member based on the target dtype,
1620
+ * then passes the union address as `void const *` to kernel functions.
1621
+ */
1622
+ typedef union nk_scalar_buffer_t {
1623
+ nk_u8_t bytes[16];
1624
+ nk_f64_t f64;
1625
+ nk_f32_t f32;
1626
+ nk_f16_t f16;
1627
+ nk_bf16_t bf16;
1628
+ nk_f64c_t f64c;
1629
+ nk_f32c_t f32c;
1630
+ nk_f16c_t f16c;
1631
+ nk_bf16c_t bf16c;
1632
+ nk_i64_t i64;
1633
+ nk_u64_t u64;
1634
+ nk_i32_t i32;
1635
+ nk_u32_t u32;
1636
+ nk_i16_t i16;
1637
+ nk_u16_t u16;
1638
+ nk_i8_t i8;
1639
+ nk_u8_t u8;
1640
+ } nk_scalar_buffer_t;
1641
+
1642
+ /**
1643
+ * @brief Converts up to 8x values from `from_ptr` buffer into 8x puned buffer objects
1644
+ * into a complex 64-bit floating point representation.
1645
+ */
1646
+ NK_INTERNAL void nk_scalar_buffers_fill_f64c_( //
1647
+ void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
1648
+ nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) {
1649
+
1650
+ nk_f32_t temporary_f32;
1651
+ nk_size_t i;
1652
+ switch (from_dtype) {
1653
+ case nk_f64_k: {
1654
+ nk_f64_t const *p = (nk_f64_t const *)from_ptr;
1655
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1656
+ } break;
1657
+ case nk_f32_k: {
1658
+ nk_f32_t const *p = (nk_f32_t const *)from_ptr;
1659
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1660
+ } break;
1661
+ case nk_f16_k: {
1662
+ nk_f16_t const *p = (nk_f16_t const *)from_ptr;
1663
+ for (i = 0; i < from_count; ++i)
1664
+ nk_f16_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1665
+ to_buffers[i].f64c.imag = 0;
1666
+ } break;
1667
+ case nk_bf16_k: {
1668
+ nk_bf16_t const *p = (nk_bf16_t const *)from_ptr;
1669
+ for (i = 0; i < from_count; ++i)
1670
+ nk_bf16_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1671
+ to_buffers[i].f64c.imag = 0;
1672
+ } break;
1673
+ case nk_e4m3_k: {
1674
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1675
+ for (i = 0; i < from_count; ++i)
1676
+ nk_e4m3_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1677
+ to_buffers[i].f64c.imag = 0;
1678
+ } break;
1679
+ case nk_e5m2_k: {
1680
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1681
+ for (i = 0; i < from_count; ++i)
1682
+ nk_e5m2_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1683
+ to_buffers[i].f64c.imag = 0;
1684
+ } break;
1685
+ case nk_e2m3_k: {
1686
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1687
+ for (i = 0; i < from_count; ++i)
1688
+ nk_e2m3_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1689
+ to_buffers[i].f64c.imag = 0;
1690
+ } break;
1691
+ case nk_e3m2_k: {
1692
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1693
+ for (i = 0; i < from_count; ++i)
1694
+ nk_e3m2_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
1695
+ to_buffers[i].f64c.imag = 0;
1696
+ } break;
1697
+ case nk_i64_k: {
1698
+ nk_i64_t const *p = (nk_i64_t const *)from_ptr;
1699
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = (nk_f64_t)p[i], to_buffers[i].f64c.imag = 0;
1700
+ } break;
1701
+ case nk_i32_k: {
1702
+ nk_i32_t const *p = (nk_i32_t const *)from_ptr;
1703
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1704
+ } break;
1705
+ case nk_i16_k: {
1706
+ nk_i16_t const *p = (nk_i16_t const *)from_ptr;
1707
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1708
+ } break;
1709
+ case nk_i8_k: {
1710
+ nk_i8_t const *p = (nk_i8_t const *)from_ptr;
1711
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1712
+ } break;
1713
+ case nk_u64_k: {
1714
+ nk_u64_t const *p = (nk_u64_t const *)from_ptr;
1715
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = (nk_f64_t)p[i], to_buffers[i].f64c.imag = 0;
1716
+ } break;
1717
+ case nk_u32_k: {
1718
+ nk_u32_t const *p = (nk_u32_t const *)from_ptr;
1719
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1720
+ } break;
1721
+ case nk_u16_k: {
1722
+ nk_u16_t const *p = (nk_u16_t const *)from_ptr;
1723
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1724
+ } break;
1725
+ case nk_u8_k: {
1726
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1727
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
1728
+ } break;
1729
+ case nk_f64c_k: {
1730
+ nk_f64c_t const *p = (nk_f64c_t const *)from_ptr;
1731
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c = p[i];
1732
+ } break;
1733
+ case nk_f32c_k: {
1734
+ nk_f32c_t const *p = (nk_f32c_t const *)from_ptr;
1735
+ for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i].real, to_buffers[i].f64c.imag = p[i].imag;
1736
+ } break;
1737
+ case nk_f16c_k: {
1738
+ nk_f16c_t const *p = (nk_f16c_t const *)from_ptr;
1739
+ for (i = 0; i < from_count; ++i) {
1740
+ nk_f16_to_f32_serial(&p[i].real, &temporary_f32), to_buffers[i].f64c.real = temporary_f32;
1741
+ nk_f16_to_f32_serial(&p[i].imag, &temporary_f32), to_buffers[i].f64c.imag = temporary_f32;
1742
+ }
1743
+ } break;
1744
+ case nk_bf16c_k: {
1745
+ nk_bf16c_t const *p = (nk_bf16c_t const *)from_ptr;
1746
+ for (i = 0; i < from_count; ++i) {
1747
+ nk_bf16_to_f32_serial(&p[i].real, &temporary_f32), to_buffers[i].f64c.real = temporary_f32;
1748
+ nk_bf16_to_f32_serial(&p[i].imag, &temporary_f32), to_buffers[i].f64c.imag = temporary_f32;
1749
+ }
1750
+ } break;
1751
+ // Sub-byte: u1 - 8 bits from 1 byte, MSB-first
1752
+ case nk_u1_k: {
1753
+ nk_u8_t byte = *(nk_u8_t const *)from_ptr;
1754
+ for (i = 0; i < 8; ++i) to_buffers[i].f64c.real = (byte >> (7 - i)) & 1, to_buffers[i].f64c.imag = 0;
1755
+ } break;
1756
+ // Sub-byte: i4 - 8 nibbles from 4 bytes, high nibble = even index, sign-extended
1757
+ case nk_i4_k: {
1758
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1759
+ for (i = 0; i < 4; ++i) {
1760
+ nk_i8_t hi = (nk_i8_t)(p[i] >> 4), lo = (nk_i8_t)(p[i] & 0xF);
1761
+ to_buffers[i * 2].f64c.real = (hi ^ 8) - 8, to_buffers[i * 2].f64c.imag = 0;
1762
+ to_buffers[i * 2 + 1].f64c.real = (lo ^ 8) - 8, to_buffers[i * 2 + 1].f64c.imag = 0;
1763
+ }
1764
+ } break;
1765
+ // Sub-byte: u4 - 8 nibbles from 4 bytes, high nibble = even index
1766
+ case nk_u4_k: {
1767
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1768
+ for (i = 0; i < 4; ++i) {
1769
+ to_buffers[i * 2].f64c.real = p[i] >> 4, to_buffers[i * 2].f64c.imag = 0;
1770
+ to_buffers[i * 2 + 1].f64c.real = p[i] & 0xF, to_buffers[i * 2 + 1].f64c.imag = 0;
1771
+ }
1772
+ } break;
1773
+ default:
1774
+ for (i = 0; i < 8; ++i) to_buffers[i].f64c.real = 0, to_buffers[i].f64c.imag = 0;
1775
+ break;
1776
+ }
1777
+ }
1778
+
1779
+ /**
1780
+ * @brief Converts up to 8x values from `from_buffers` buffer into 8x typed scalars.
1781
+ */
1782
+ NK_INTERNAL void nk_scalar_buffers_export_f64c_( //
1783
+ nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
1784
+ void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) {
1785
+
1786
+ nk_f32_t temporary_f32;
1787
+ nk_size_t i;
1788
+ switch (to_dtype) {
1789
+ case nk_f64_k: {
1790
+ nk_f64_t *p = (nk_f64_t *)to_ptr;
1791
+ for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].f64c.real;
1792
+ } break;
1793
+ case nk_f32_k: {
1794
+ nk_f32_t *p = (nk_f32_t *)to_ptr;
1795
+ for (i = 0; i < to_count; ++i) p[i] = (nk_f32_t)from_buffers[i].f64c.real;
1796
+ } break;
1797
+ case nk_f16_k: {
1798
+ nk_f16_t *p = (nk_f16_t *)to_ptr;
1799
+ for (i = 0; i < to_count; ++i)
1800
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_f16_serial(&temporary_f32, &p[i]);
1801
+ } break;
1802
+ case nk_bf16_k: {
1803
+ nk_bf16_t *p = (nk_bf16_t *)to_ptr;
1804
+ for (i = 0; i < to_count; ++i)
1805
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_bf16_serial(&temporary_f32, &p[i]);
1806
+ } break;
1807
+ case nk_e4m3_k: {
1808
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
1809
+ for (i = 0; i < to_count; ++i)
1810
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e4m3_serial(&temporary_f32, &p[i]);
1811
+ } break;
1812
+ case nk_e5m2_k: {
1813
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
1814
+ for (i = 0; i < to_count; ++i)
1815
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e5m2_serial(&temporary_f32, &p[i]);
1816
+ } break;
1817
+ case nk_e2m3_k: {
1818
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
1819
+ for (i = 0; i < to_count; ++i)
1820
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e2m3_serial(&temporary_f32, &p[i]);
1821
+ } break;
1822
+ case nk_e3m2_k: {
1823
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
1824
+ for (i = 0; i < to_count; ++i)
1825
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e3m2_serial(&temporary_f32, &p[i]);
1826
+ } break;
1827
+ case nk_i64_k: {
1828
+ nk_i64_t *p = (nk_i64_t *)to_ptr;
1829
+ for (i = 0; i < to_count; ++i) nk_f64_to_i64_serial(&from_buffers[i].f64c.real, &p[i]);
1830
+ } break;
1831
+ case nk_i32_k: {
1832
+ nk_i32_t *p = (nk_i32_t *)to_ptr;
1833
+ for (i = 0; i < to_count; ++i) nk_f64_to_i32_serial(&from_buffers[i].f64c.real, &p[i]);
1834
+ } break;
1835
+ case nk_i16_k: {
1836
+ nk_i16_t *p = (nk_i16_t *)to_ptr;
1837
+ for (i = 0; i < to_count; ++i) nk_f64_to_i16_serial(&from_buffers[i].f64c.real, &p[i]);
1838
+ } break;
1839
+ case nk_i8_k: {
1840
+ nk_i8_t *p = (nk_i8_t *)to_ptr;
1841
+ for (i = 0; i < to_count; ++i) nk_f64_to_i8_serial(&from_buffers[i].f64c.real, &p[i]);
1842
+ } break;
1843
+ case nk_u64_k: {
1844
+ nk_u64_t *p = (nk_u64_t *)to_ptr;
1845
+ for (i = 0; i < to_count; ++i) nk_f64_to_u64_serial(&from_buffers[i].f64c.real, &p[i]);
1846
+ } break;
1847
+ case nk_u32_k: {
1848
+ nk_u32_t *p = (nk_u32_t *)to_ptr;
1849
+ for (i = 0; i < to_count; ++i) nk_f64_to_u32_serial(&from_buffers[i].f64c.real, &p[i]);
1850
+ } break;
1851
+ case nk_u16_k: {
1852
+ nk_u16_t *p = (nk_u16_t *)to_ptr;
1853
+ for (i = 0; i < to_count; ++i) nk_f64_to_u16_serial(&from_buffers[i].f64c.real, &p[i]);
1854
+ } break;
1855
+ case nk_u8_k: {
1856
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
1857
+ for (i = 0; i < to_count; ++i) nk_f64_to_u8_serial(&from_buffers[i].f64c.real, &p[i]);
1858
+ } break;
1859
+ case nk_f64c_k: {
1860
+ nk_f64c_t *p = (nk_f64c_t *)to_ptr;
1861
+ for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].f64c;
1862
+ } break;
1863
+ case nk_f32c_k: {
1864
+ nk_f32c_t *p = (nk_f32c_t *)to_ptr;
1865
+ for (i = 0; i < to_count; ++i)
1866
+ p[i].real = (nk_f32_t)from_buffers[i].f64c.real, p[i].imag = (nk_f32_t)from_buffers[i].f64c.imag;
1867
+ } break;
1868
+ case nk_f16c_k: {
1869
+ nk_f16c_t *p = (nk_f16c_t *)to_ptr;
1870
+ for (i = 0; i < to_count; ++i) {
1871
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_f16_serial(&temporary_f32, &p[i].real);
1872
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.imag, nk_f32_to_f16_serial(&temporary_f32, &p[i].imag);
1873
+ }
1874
+ } break;
1875
+ case nk_bf16c_k: {
1876
+ nk_bf16c_t *p = (nk_bf16c_t *)to_ptr;
1877
+ for (i = 0; i < to_count; ++i) {
1878
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_bf16_serial(&temporary_f32, &p[i].real);
1879
+ temporary_f32 = (nk_f32_t)from_buffers[i].f64c.imag, nk_f32_to_bf16_serial(&temporary_f32, &p[i].imag);
1880
+ }
1881
+ } break;
1882
+ // Sub-byte: u1 - 8 bits to 1 byte, MSB-first, non-zero → 1
1883
+ case nk_u1_k: {
1884
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
1885
+ nk_u8_t byte = 0;
1886
+ for (i = 0; i < 8; ++i) byte |= (from_buffers[i].f64c.real != 0) << (7 - i);
1887
+ *p = byte;
1888
+ } break;
1889
+ // Sub-byte: i4 - 8 nibbles to 4 bytes, high nibble = even index
1890
+ case nk_i4_k: {
1891
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
1892
+ for (i = 0; i < 4; ++i) {
1893
+ nk_i64_t hi = (nk_i64_t)from_buffers[i * 2].f64c.real;
1894
+ nk_i64_t lo = (nk_i64_t)from_buffers[i * 2 + 1].f64c.real;
1895
+ hi = hi > 7 ? 7 : (hi < -8 ? -8 : hi);
1896
+ lo = lo > 7 ? 7 : (lo < -8 ? -8 : lo);
1897
+ p[i] = (nk_u8_t)(((hi & 0xF) << 4) | (lo & 0xF));
1898
+ }
1899
+ } break;
1900
+ // Sub-byte: u4 - 8 nibbles to 4 bytes, high nibble = even index
1901
+ case nk_u4_k: {
1902
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
1903
+ for (i = 0; i < 4; ++i) {
1904
+ nk_u64_t hi = (nk_u64_t)from_buffers[i * 2].f64c.real;
1905
+ nk_u64_t lo = (nk_u64_t)from_buffers[i * 2 + 1].f64c.real;
1906
+ hi = hi > 15 ? 15 : hi;
1907
+ lo = lo > 15 ? 15 : lo;
1908
+ p[i] = (nk_u8_t)((hi << 4) | lo);
1909
+ }
1910
+ } break;
1911
+ default: break;
1912
+ }
1913
+ }
1914
+
1915
+ /**
1916
+ * @brief Load 8 values from typed buffer into `buf[i].i64` (lossless widening for signed integers).
1917
+ */
1918
+ NK_INTERNAL void nk_scalar_buffers_fill_i64_( //
1919
+ void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
1920
+ nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) { //
1921
+ nk_size_t i;
1922
+ switch (from_dtype) {
1923
+ case nk_i64_k: {
1924
+ nk_i64_t const *p = (nk_i64_t const *)from_ptr;
1925
+ for (i = 0; i < from_count; ++i) to_buffers[i].i64 = p[i];
1926
+ } break;
1927
+ case nk_i32_k: {
1928
+ nk_i32_t const *p = (nk_i32_t const *)from_ptr;
1929
+ for (i = 0; i < from_count; ++i) to_buffers[i].i64 = p[i];
1930
+ } break;
1931
+ case nk_i16_k: {
1932
+ nk_i16_t const *p = (nk_i16_t const *)from_ptr;
1933
+ for (i = 0; i < from_count; ++i) to_buffers[i].i64 = p[i];
1934
+ } break;
1935
+ case nk_i8_k: {
1936
+ nk_i8_t const *p = (nk_i8_t const *)from_ptr;
1937
+ for (i = 0; i < from_count; ++i) to_buffers[i].i64 = p[i];
1938
+ } break;
1939
+ // Sub-byte: i4 - 4 bytes to 8 nibbles, sign-extend each nibble
1940
+ case nk_i4_k: {
1941
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1942
+ for (i = 0; i < 4; ++i) {
1943
+ nk_i8_t hi = (nk_i8_t)(p[i] >> 4), lo = (nk_i8_t)(p[i] & 0xF);
1944
+ to_buffers[i * 2].i64 = (hi ^ 8) - 8;
1945
+ to_buffers[i * 2 + 1].i64 = (lo ^ 8) - 8;
1946
+ }
1947
+ } break;
1948
+ case nk_u64_k: {
1949
+ nk_u64_t const *p = (nk_u64_t const *)from_ptr;
1950
+ for (i = 0; i < from_count; ++i) to_buffers[i].i64 = (nk_i64_t)p[i];
1951
+ } break;
1952
+ case nk_u32_k: {
1953
+ nk_u32_t const *p = (nk_u32_t const *)from_ptr;
1954
+ for (i = 0; i < from_count; ++i) to_buffers[i].i64 = (nk_i64_t)p[i];
1955
+ } break;
1956
+ case nk_u16_k: {
1957
+ nk_u16_t const *p = (nk_u16_t const *)from_ptr;
1958
+ for (i = 0; i < from_count; ++i) to_buffers[i].i64 = (nk_i64_t)p[i];
1959
+ } break;
1960
+ case nk_u8_k: {
1961
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1962
+ for (i = 0; i < from_count; ++i) to_buffers[i].i64 = (nk_i64_t)p[i];
1963
+ } break;
1964
+ case nk_u4_k: {
1965
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
1966
+ for (i = 0; i < 4; ++i) {
1967
+ to_buffers[i * 2].i64 = (nk_i64_t)(p[i] >> 4);
1968
+ to_buffers[i * 2 + 1].i64 = (nk_i64_t)(p[i] & 0xF);
1969
+ }
1970
+ } break;
1971
+ default: break;
1972
+ }
1973
+ }
1974
+
1975
+ /**
1976
+ * @brief Export 8 `buf[i].i64` values to typed buffer with saturation on downcast.
1977
+ */
1978
+ NK_INTERNAL void nk_scalar_buffers_export_i64_( //
1979
+ nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
1980
+ void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) { //
1981
+ nk_size_t i;
1982
+ switch (to_dtype) {
1983
+ case nk_i64_k: {
1984
+ nk_i64_t *p = (nk_i64_t *)to_ptr;
1985
+ for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].i64;
1986
+ } break;
1987
+ case nk_i32_k: {
1988
+ nk_i32_t *p = (nk_i32_t *)to_ptr;
1989
+ for (i = 0; i < to_count; ++i) nk_i64_to_i32_serial(&from_buffers[i].i64, &p[i]);
1990
+ } break;
1991
+ case nk_i16_k: {
1992
+ nk_i16_t *p = (nk_i16_t *)to_ptr;
1993
+ for (i = 0; i < to_count; ++i) nk_i64_to_i16_serial(&from_buffers[i].i64, &p[i]);
1994
+ } break;
1995
+ case nk_i8_k: {
1996
+ nk_i8_t *p = (nk_i8_t *)to_ptr;
1997
+ for (i = 0; i < to_count; ++i) nk_i64_to_i8_serial(&from_buffers[i].i64, &p[i]);
1998
+ } break;
1999
+ // Unsigned targets: clamp negatives to 0
2000
+ case nk_u64_k: {
2001
+ nk_u64_t *p = (nk_u64_t *)to_ptr;
2002
+ for (i = 0; i < to_count; ++i) nk_i64_to_u64_serial(&from_buffers[i].i64, &p[i]);
2003
+ } break;
2004
+ case nk_u32_k: {
2005
+ nk_u32_t *p = (nk_u32_t *)to_ptr;
2006
+ for (i = 0; i < to_count; ++i) nk_i64_to_u32_serial(&from_buffers[i].i64, &p[i]);
2007
+ } break;
2008
+ case nk_u16_k: {
2009
+ nk_u16_t *p = (nk_u16_t *)to_ptr;
2010
+ for (i = 0; i < to_count; ++i) nk_i64_to_u16_serial(&from_buffers[i].i64, &p[i]);
2011
+ } break;
2012
+ case nk_u8_k: {
2013
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
2014
+ for (i = 0; i < to_count; ++i) nk_i64_to_u8_serial(&from_buffers[i].i64, &p[i]);
2015
+ } break;
2016
+ // Sub-byte: i4 - 8 nibbles to 4 bytes, clamp [-8,7]
2017
+ case nk_i4_k: {
2018
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
2019
+ for (i = 0; i < 4; ++i) {
2020
+ nk_i64_t hi = from_buffers[i * 2].i64, lo = from_buffers[i * 2 + 1].i64;
2021
+ hi = hi > 7 ? 7 : (hi < -8 ? -8 : hi);
2022
+ lo = lo > 7 ? 7 : (lo < -8 ? -8 : lo);
2023
+ p[i] = (nk_u8_t)(((hi & 0xF) << 4) | (lo & 0xF));
2024
+ }
2025
+ } break;
2026
+ default: break;
2027
+ }
2028
+ }
2029
+
2030
+ /**
2031
+ * @brief Load 8 values from typed buffer into `buf[i].u64` (lossless widening for unsigned integers).
2032
+ */
2033
+ NK_INTERNAL void nk_scalar_buffers_fill_u64_( //
2034
+ void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
2035
+ nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) { //
2036
+ nk_size_t i;
2037
+ switch (from_dtype) {
2038
+ case nk_u64_k: {
2039
+ nk_u64_t const *p = (nk_u64_t const *)from_ptr;
2040
+ for (i = 0; i < from_count; ++i) to_buffers[i].u64 = p[i];
2041
+ } break;
2042
+ case nk_u32_k: {
2043
+ nk_u32_t const *p = (nk_u32_t const *)from_ptr;
2044
+ for (i = 0; i < from_count; ++i) to_buffers[i].u64 = p[i];
2045
+ } break;
2046
+ case nk_u16_k: {
2047
+ nk_u16_t const *p = (nk_u16_t const *)from_ptr;
2048
+ for (i = 0; i < from_count; ++i) to_buffers[i].u64 = p[i];
2049
+ } break;
2050
+ case nk_u8_k: {
2051
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
2052
+ for (i = 0; i < from_count; ++i) to_buffers[i].u64 = p[i];
2053
+ } break;
2054
+ // Sub-byte: u4 - 4 bytes to 8 nibbles, zero-extend
2055
+ case nk_u4_k: {
2056
+ nk_u8_t const *p = (nk_u8_t const *)from_ptr;
2057
+ for (i = 0; i < 4; ++i) {
2058
+ to_buffers[i * 2].u64 = p[i] >> 4;
2059
+ to_buffers[i * 2 + 1].u64 = p[i] & 0xF;
2060
+ }
2061
+ } break;
2062
+ // Sub-byte: u1 - 1 byte to 8 bits, MSB-first
2063
+ case nk_u1_k: {
2064
+ nk_u8_t byte = *(nk_u8_t const *)from_ptr;
2065
+ for (i = 0; i < 8; ++i) to_buffers[i].u64 = (byte >> (7 - i)) & 1;
2066
+ } break;
2067
+ default: break;
2068
+ }
2069
+ }
2070
+
2071
+ /**
2072
+ * @brief Export 8 `buf[i].u64` values to typed buffer with saturation on downcast.
2073
+ */
2074
+ NK_INTERNAL void nk_scalar_buffers_export_u64_( //
2075
+ nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
2076
+ void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) { //
2077
+ nk_size_t i;
2078
+ switch (to_dtype) {
2079
+ case nk_u64_k: {
2080
+ nk_u64_t *p = (nk_u64_t *)to_ptr;
2081
+ for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].u64;
2082
+ } break;
2083
+ case nk_u32_k: {
2084
+ nk_u32_t *p = (nk_u32_t *)to_ptr;
2085
+ for (i = 0; i < to_count; ++i) nk_u64_to_u32_serial(&from_buffers[i].u64, &p[i]);
2086
+ } break;
2087
+ case nk_u16_k: {
2088
+ nk_u16_t *p = (nk_u16_t *)to_ptr;
2089
+ for (i = 0; i < to_count; ++i) nk_u64_to_u16_serial(&from_buffers[i].u64, &p[i]);
2090
+ } break;
2091
+ case nk_u8_k: {
2092
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
2093
+ for (i = 0; i < to_count; ++i) nk_u64_to_u8_serial(&from_buffers[i].u64, &p[i]);
2094
+ } break;
2095
+ // Signed targets: clamp to i64_max
2096
+ case nk_i64_k: {
2097
+ nk_i64_t *p = (nk_i64_t *)to_ptr;
2098
+ for (i = 0; i < to_count; ++i) nk_u64_to_i64_serial(&from_buffers[i].u64, &p[i]);
2099
+ } break;
2100
+ case nk_i32_k: {
2101
+ nk_i32_t *p = (nk_i32_t *)to_ptr;
2102
+ for (i = 0; i < to_count; ++i) nk_u64_to_i32_serial(&from_buffers[i].u64, &p[i]);
2103
+ } break;
2104
+ case nk_i16_k: {
2105
+ nk_i16_t *p = (nk_i16_t *)to_ptr;
2106
+ for (i = 0; i < to_count; ++i) nk_u64_to_i16_serial(&from_buffers[i].u64, &p[i]);
2107
+ } break;
2108
+ case nk_i8_k: {
2109
+ nk_i8_t *p = (nk_i8_t *)to_ptr;
2110
+ for (i = 0; i < to_count; ++i) nk_u64_to_i8_serial(&from_buffers[i].u64, &p[i]);
2111
+ } break;
2112
+ // Sub-byte: u4 - 8 nibbles to 4 bytes, clamp [0,15]
2113
+ case nk_u4_k: {
2114
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
2115
+ for (i = 0; i < 4; ++i) {
2116
+ nk_u64_t hi = from_buffers[i * 2].u64, lo = from_buffers[i * 2 + 1].u64;
2117
+ hi = hi > 15 ? 15 : hi;
2118
+ lo = lo > 15 ? 15 : lo;
2119
+ p[i] = (nk_u8_t)((hi << 4) | lo);
2120
+ }
2121
+ } break;
2122
+ // Sub-byte: u1 - 8 bits to 1 byte, MSB-first, non-zero becomes 1
2123
+ case nk_u1_k: {
2124
+ nk_u8_t *p = (nk_u8_t *)to_ptr;
2125
+ nk_u8_t byte = 0;
2126
+ for (i = 0; i < 8; ++i) byte |= (from_buffers[i].u64 != 0) << (7 - i);
2127
+ *p = byte;
2128
+ } break;
2129
+ default: break;
2130
+ }
2131
+ }
2132
+
2133
+ #pragma endregion - Type Punned Loads and Stores
2134
+
2135
+ #pragma region - Public API
2136
+
2137
+ NK_PUBLIC void nk_cast_serial(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
2138
+ if (from_type == to_type) {
2139
+ nk_size_t size_bits = nk_dtype_bits(from_type);
2140
+ nk_size_t size_bytes = nk_size_divide_round_up_(n * size_bits, NK_BITS_PER_BYTE);
2141
+ if (size_bytes > 0) nk_copy_bytes_(to, from, size_bytes);
2142
+ return;
2143
+ }
2144
+
2145
+ nk_size_t from_bits = nk_dtype_bits(from_type);
2146
+ nk_size_t to_bits = nk_dtype_bits(to_type);
2147
+ if (from_bits == 0 || to_bits == 0) return;
2148
+
2149
+ // Byte steps per batch of NK_BITS_PER_BYTE elements
2150
+ nk_size_t from_step = from_bits;
2151
+ nk_size_t to_step = to_bits;
2152
+
2153
+ nk_u8_t const *src = (nk_u8_t const *)from;
2154
+ nk_u8_t *dst = (nk_u8_t *)to;
2155
+ nk_dtype_family_t from_family = nk_dtype_family(from_type);
2156
+ nk_dtype_family_t to_family = nk_dtype_family(to_type);
2157
+
2158
+ nk_size_t batches = n / NK_BITS_PER_BYTE;
2159
+ nk_size_t tail = n % NK_BITS_PER_BYTE;
2160
+ nk_scalar_buffer_t bufs[NK_BITS_PER_BYTE];
2161
+
2162
+ // Both unsigned: u64 hub
2163
+ if (from_family == nk_dtype_family_uint_k && to_family == nk_dtype_family_uint_k) {
2164
+ for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
2165
+ nk_scalar_buffers_fill_u64_(src, from_type, NK_BITS_PER_BYTE, bufs);
2166
+ nk_scalar_buffers_export_u64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2167
+ }
2168
+ if (tail) {
2169
+ nk_scalar_buffers_fill_u64_(src, from_type, tail, bufs);
2170
+ nk_scalar_buffers_export_u64_(bufs, dst, to_type, tail);
2171
+ }
2172
+ return;
2173
+ }
2174
+
2175
+ // Both integers, at least one signed: i64 hub
2176
+ if ((from_family == nk_dtype_family_int_k || from_family == nk_dtype_family_uint_k) &&
2177
+ (to_family == nk_dtype_family_int_k || to_family == nk_dtype_family_uint_k)) {
2178
+ for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
2179
+ nk_scalar_buffers_fill_i64_(src, from_type, NK_BITS_PER_BYTE, bufs);
2180
+ nk_scalar_buffers_export_i64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2181
+ }
2182
+ if (tail) {
2183
+ nk_scalar_buffers_fill_i64_(src, from_type, tail, bufs);
2184
+ nk_scalar_buffers_export_i64_(bufs, dst, to_type, tail);
2185
+ }
2186
+ return;
2187
+ }
2188
+
2189
+ // Everything else: f64c hub (floats, complex, cross-category)
2190
+ for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
2191
+ nk_scalar_buffers_fill_f64c_(src, from_type, NK_BITS_PER_BYTE, bufs);
2192
+ nk_scalar_buffers_export_f64c_(bufs, dst, to_type, NK_BITS_PER_BYTE);
2193
+ }
2194
+ if (tail) {
2195
+ nk_scalar_buffers_fill_f64c_(src, from_type, tail, bufs);
2196
+ nk_scalar_buffers_export_f64c_(bufs, dst, to_type, tail);
2197
+ }
2198
+ }
2199
+
2200
+ /** @brief Convert E4M3 to BF16 via F32 intermediate. */
2201
+ NK_PUBLIC void nk_e4m3_to_bf16(nk_e4m3_t const *src, nk_bf16_t *dest) {
2202
+ nk_f32_t temp;
2203
+ nk_e4m3_to_f32_serial(src, &temp);
2204
+ nk_f32_to_bf16_serial(&temp, dest);
2205
+ }
2206
+
2207
+ /** @brief Convert E5M2 to BF16 via F32 intermediate. */
2208
+ NK_PUBLIC void nk_e5m2_to_bf16(nk_e5m2_t const *src, nk_bf16_t *dest) {
2209
+ nk_f32_t temp;
2210
+ nk_e5m2_to_f32_serial(src, &temp);
2211
+ nk_f32_to_bf16_serial(&temp, dest);
2212
+ }
2213
+
2214
+ /** @brief Convert E2M3 to BF16 via F32 intermediate. */
2215
+ NK_PUBLIC void nk_e2m3_to_bf16(nk_e2m3_t const *src, nk_bf16_t *dest) {
2216
+ nk_f32_t temp;
2217
+ nk_e2m3_to_f32_serial(src, &temp);
2218
+ nk_f32_to_bf16_serial(&temp, dest);
2219
+ }
2220
+
2221
+ /** @brief Convert E3M2 to BF16 via F32 intermediate. */
2222
+ NK_PUBLIC void nk_e3m2_to_bf16(nk_e3m2_t const *src, nk_bf16_t *dest) {
2223
+ nk_f32_t temp;
2224
+ nk_e3m2_to_f32_serial(src, &temp);
2225
+ nk_f32_to_bf16_serial(&temp, dest);
2226
+ }
2227
+
2228
+ /**
2229
+ * @brief Convert i4 (4-bit signed integer, -8 to 7) to i8.
2230
+ *
2231
+ * Nibbles are packed: low nibble in bits [0:3], high nibble in bits [4:7].
2232
+ * Sign extension: XOR with 8 then subtract 8 converts unsigned nibble to signed.
2233
+ */
2234
+ NK_PUBLIC void nk_i4_to_i8_serial_(nk_i4x2_t const *src, nk_i8_t *dest, nk_size_t count) {
2235
+ nk_u8_t const *bytes = (nk_u8_t const *)src;
2236
+ for (nk_size_t i = 0; i < count; ++i) {
2237
+ nk_u8_t byte = bytes[i / 2];
2238
+ nk_u8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
2239
+ dest[i] = (nk_i8_t)((nibble ^ 8) - 8); // Sign extend: 0-7 → 0-7, 8-15 → -8 to -1
2240
+ }
2241
+ }
2242
+
2243
+ /**
2244
+ * @brief Convert u4 (4-bit unsigned integer, 0 to 15) to u8.
2245
+ *
2246
+ * Nibbles are packed: low nibble in bits [0:3], high nibble in bits [4:7].
2247
+ */
2248
+ NK_PUBLIC void nk_u4_to_u8_serial_(nk_u4x2_t const *src, nk_u8_t *dest, nk_size_t count) {
2249
+ nk_u8_t const *bytes = (nk_u8_t const *)src;
2250
+ for (nk_size_t i = 0; i < count; ++i) {
2251
+ nk_u8_t byte = bytes[i / 2];
2252
+ dest[i] = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
2253
+ }
2254
+ }
2255
+
2256
+ #pragma endregion - Public API
2257
+
2258
+ #if defined(__cplusplus)
2259
+ } // extern "C"
2260
+ #endif
2261
+
2262
+ #endif // NK_CAST_SERIAL_H