numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,975 @@
1
+ /**
2
+ * @brief SIMD-accelerated Type Conversions for Haswell.
3
+ * @file include/numkong/cast/haswell.h
4
+ * @author Ash Vardanian
5
+ * @date January 2, 2026
6
+ *
7
+ * @section haswell_cast_instructions Key F16C/AVX2 Conversion Instructions
8
+ *
9
+ * Intrinsic Instruction Latency Throughput Ports
10
+ * _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy 1/cy p01
11
+ * _mm256_cvtps_ph VCVTPS2PH (XMM, YMM, I8) 4cy 1/cy p01+p5
12
+ * _mm256_cvtepi16_epi32 VPMOVSXWD (YMM, XMM) 3cy 1/cy p5
13
+ * _mm256_slli_epi32 VPSLLD (YMM, YMM, I8) 1cy 0.5/cy p01
14
+ * _mm256_blendv_ps VBLENDVPS (YMM, YMM, YMM, YMM) 2cy 1/cy p015
15
+ *
16
+ * F16C provides hardware F16<->F32 conversion. BF16 lacks hardware support and is emulated via
17
+ * bit manipulation (shift upper 16 bits). FP8 formats (E4M3/E5M2) use lookup tables for subnormal
18
+ * handling combined with arithmetic for normal values. All conversions hub through F32.
19
+ */
20
+ #ifndef NK_CAST_HASWELL_H
21
+ #define NK_CAST_HASWELL_H
22
+
23
+ #if NK_TARGET_X86_
24
+ #if NK_TARGET_HASWELL
25
+
26
+ #include "numkong/types.h"
27
+ #include "numkong/cast/serial.h" // `nk_partial_load_b16x16_serial_`
28
+
29
+ #if defined(__cplusplus)
30
+ extern "C" {
31
+ #endif
32
+
33
+ #if defined(__clang__)
34
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
35
+ #elif defined(__GNUC__)
36
+ #pragma GCC push_options
37
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
38
+ #endif
39
+
40
+ NK_PUBLIC void nk_f32_to_f16_haswell(nk_f32_t const *from, nk_f16_t *to) {
41
+ *to = _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(*from), _MM_FROUND_TO_NEAREST_INT));
42
+ }
43
+
44
+ NK_PUBLIC void nk_f16_to_f32_haswell(nk_f16_t const *from, nk_f32_t *to) {
45
+ *to = _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(*from)));
46
+ }
47
+
48
+ #pragma region - Type Punned Loads and Stores
49
+
50
+ /** @brief Type-agnostic 256-bit full load (Haswell AVX2). */
51
+ NK_INTERNAL void nk_load_b256_haswell_(void const *src, nk_b256_vec_t *dst) {
52
+ dst->ymm = _mm256_loadu_si256((const __m256i *)src);
53
+ }
54
+
55
+ /** @brief Type-agnostic 256-bit full store (Haswell AVX2). */
56
+ NK_INTERNAL void nk_store_b256_haswell_(nk_b256_vec_t const *src, void *dst) {
57
+ _mm256_storeu_si256((__m256i *)dst, src->ymm);
58
+ }
59
+
60
+ /** @brief Type-agnostic 128-bit full load (Haswell AVX2). */
61
+ NK_INTERNAL void nk_load_b128_haswell_(void const *src, nk_b128_vec_t *dst) {
62
+ dst->xmm = _mm_loadu_si128((const __m128i *)src);
63
+ }
64
+
65
+ /** @brief Type-agnostic 128-bit full store (SSE2). */
66
+ NK_INTERNAL void nk_store_b128_haswell_(nk_b128_vec_t const *src, void *dst) {
67
+ _mm_storeu_si128((__m128i *)dst, src->xmm);
68
+ }
69
+
70
+ /** @brief Type-agnostic 128-bit partial load with AVX maskload. */
71
+ NK_INTERNAL void nk_partial_load_b32x4_haswell_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
72
+ __m128i idx_i32x4 = _mm_setr_epi32(0, 1, 2, 3);
73
+ __m128i limit_i32x4 = _mm_set1_epi32((int)n);
74
+ __m128i mask_i32x4 = _mm_cmpgt_epi32(limit_i32x4, idx_i32x4);
75
+ dst->xmm = _mm_castps_si128(_mm_maskload_ps((float const *)src, mask_i32x4));
76
+ }
77
+
78
+ /** @brief Type-agnostic 128-bit partial store with AVX maskstore. */
79
+ NK_INTERNAL void nk_partial_store_b32x4_haswell_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
80
+ __m128i idx_i32x4 = _mm_setr_epi32(0, 1, 2, 3);
81
+ __m128i limit_i32x4 = _mm_set1_epi32((int)n);
82
+ __m128i mask_i32x4 = _mm_cmpgt_epi32(limit_i32x4, idx_i32x4);
83
+ _mm_maskstore_ps((float *)dst, mask_i32x4, _mm_castsi128_ps(src->xmm));
84
+ }
85
+
86
+ /** @brief Type-agnostic 256-bit partial load with AVX2 maskload. */
87
+ NK_INTERNAL void nk_partial_load_b64x4_haswell_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
88
+ __m256i idx_i64x4 = _mm256_setr_epi64x(0, 1, 2, 3);
89
+ __m256i limit_i64x4 = _mm256_set1_epi64x((long long)n);
90
+ __m256i mask_i64x4 = _mm256_cmpgt_epi64(limit_i64x4, idx_i64x4);
91
+ dst->ymm = _mm256_castpd_si256(_mm256_maskload_pd((double const *)src, mask_i64x4));
92
+ }
93
+
94
+ /** @brief Type-agnostic 256-bit partial store with AVX2 maskstore. */
95
+ NK_INTERNAL void nk_partial_store_b64x4_haswell_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
96
+ __m256i idx_i64x4 = _mm256_setr_epi64x(0, 1, 2, 3);
97
+ __m256i limit_i64x4 = _mm256_set1_epi64x((long long)n);
98
+ __m256i mask_i64x4 = _mm256_cmpgt_epi64(limit_i64x4, idx_i64x4);
99
+ _mm256_maskstore_pd((double *)dst, mask_i64x4, _mm256_castsi256_pd(src->ymm));
100
+ }
101
+
102
+ #pragma endregion - Type Punned Loads and Stores
103
+
104
+ #pragma region - Vectorized Conversions
105
+
106
+ /** @brief Convert 8x bf16 → 8x f32 by shifting left 16 bits (AVX2). */
107
+ NK_INTERNAL __m256 nk_bf16x8_to_f32x8_haswell_(__m128i bf16_i16x8) {
108
+ return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(bf16_i16x8), 16));
109
+ }
110
+
111
+ /** @brief Convert 8x f32 → 8x bf16 by truncating with RNE rounding (AVX2). */
112
+ NK_INTERNAL __m128i nk_f32x8_to_bf16x8_haswell_(__m256 f32x8) {
113
+ __m256i bits_i32x8 = _mm256_castps_si256(f32x8);
114
+ // RNE rounding: add (0x7FFF + lsb) where lsb is bit 16
115
+ __m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 16), _mm256_set1_epi32(1));
116
+ __m256i rounded_i32x8 = _mm256_add_epi32(bits_i32x8, _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), lsb_i32x8));
117
+ __m256i bf16_i32x8 = _mm256_srli_epi32(rounded_i32x8, 16);
118
+ // Pack 8x i32 to 8x i16
119
+ __m128i lo_i32x4 = _mm256_castsi256_si128(bf16_i32x8);
120
+ __m128i hi_i32x4 = _mm256_extracti128_si256(bf16_i32x8, 1);
121
+ return _mm_packus_epi32(lo_i32x4, hi_i32x4);
122
+ }
123
+
124
+ /** @brief Integer upcasts to f32x8 (AVX2). */
125
+ NK_INTERNAL __m256 nk_i8x8_to_f32x8_haswell_(__m128i i8x8) { return _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(i8x8)); }
126
+ NK_INTERNAL __m256 nk_u8x8_to_f32x8_haswell_(__m128i u8x8) { return _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(u8x8)); }
127
+ NK_INTERNAL __m256 nk_i16x8_to_f32x8_haswell_(__m128i i16x8) {
128
+ return _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(i16x8));
129
+ }
130
+ NK_INTERNAL __m256 nk_u16x8_to_f32x8_haswell_(__m128i u16x8) {
131
+ return _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(u16x8));
132
+ }
133
+ NK_INTERNAL __m256 nk_i32x8_to_f32x8_haswell_(__m256i i32x8) { return _mm256_cvtepi32_ps(i32x8); }
134
+ NK_INTERNAL __m256 nk_u32x8_to_f32x8_haswell_(__m256i u32x8) {
135
+ __m256i lo_i32x8 = _mm256_and_si256(u32x8, _mm256_set1_epi32(0xFFFF));
136
+ __m256i hi_i32x8 = _mm256_srli_epi32(u32x8, 16);
137
+ return _mm256_add_ps(_mm256_cvtepi32_ps(lo_i32x8),
138
+ _mm256_mul_ps(_mm256_cvtepi32_ps(hi_i32x8), _mm256_set1_ps(65536.0f)));
139
+ }
140
+
141
+ /** @brief Saturating f32x8 downcasts to integers (AVX2). */
142
+ NK_INTERNAL __m256i nk_f32x8_to_i32x8_haswell_(__m256 f32x8) { return _mm256_cvtps_epi32(f32x8); }
143
+ NK_INTERNAL __m256i nk_f32x8_to_u32x8_haswell_(__m256 f32x8) {
144
+ __m256 clamped_f32x8 = _mm256_max_ps(_mm256_min_ps(f32x8, _mm256_set1_ps((float)NK_U32_MAX)), _mm256_setzero_ps());
145
+ __m256 threshold_f32x8 = _mm256_set1_ps(2147483648.0f);
146
+ __m256i mask_i32x8 = _mm256_castps_si256(_mm256_cmp_ps(clamped_f32x8, threshold_f32x8, _CMP_GE_OQ));
147
+ __m256 adjusted_f32x8 = _mm256_sub_ps(clamped_f32x8,
148
+ _mm256_and_ps(_mm256_castsi256_ps(mask_i32x8), threshold_f32x8));
149
+ return _mm256_add_epi32(_mm256_cvtps_epi32(adjusted_f32x8),
150
+ _mm256_and_si256(mask_i32x8, _mm256_set1_epi32((int)0x80000000)));
151
+ }
152
+ NK_INTERNAL __m128i nk_f32x8_to_i16x8_haswell_(__m256 f32x8) {
153
+ __m256 clamped_f32x8 = _mm256_min_ps(_mm256_max_ps(f32x8, _mm256_set1_ps(-32768.0f)), _mm256_set1_ps(32767.0f));
154
+ __m256i i32x8 = _mm256_cvtps_epi32(clamped_f32x8);
155
+ return _mm_packs_epi32(_mm256_castsi256_si128(i32x8), _mm256_extracti128_si256(i32x8, 1));
156
+ }
157
+ NK_INTERNAL __m128i nk_f32x8_to_u16x8_haswell_(__m256 f32x8) {
158
+ __m256 clamped_f32x8 = _mm256_min_ps(_mm256_max_ps(f32x8, _mm256_setzero_ps()), _mm256_set1_ps(65535.0f));
159
+ __m256i i32x8 = _mm256_cvtps_epi32(clamped_f32x8);
160
+ return _mm_packus_epi32(_mm256_castsi256_si128(i32x8), _mm256_extracti128_si256(i32x8, 1));
161
+ }
162
+ NK_INTERNAL __m128i nk_f32x8_to_i8x8_haswell_(__m256 f32x8) {
163
+ __m256 clamped_f32x8 = _mm256_min_ps(_mm256_max_ps(f32x8, _mm256_set1_ps(-128.0f)), _mm256_set1_ps(127.0f));
164
+ __m256i i32x8 = _mm256_cvtps_epi32(clamped_f32x8);
165
+ __m128i i16x8 = _mm_packs_epi32(_mm256_castsi256_si128(i32x8), _mm256_extracti128_si256(i32x8, 1));
166
+ return _mm_packs_epi16(i16x8, _mm_setzero_si128());
167
+ }
168
+ NK_INTERNAL __m128i nk_f32x8_to_u8x8_haswell_(__m256 f32x8) {
169
+ __m256 clamped_f32x8 = _mm256_min_ps(_mm256_max_ps(f32x8, _mm256_setzero_ps()), _mm256_set1_ps(255.0f));
170
+ __m256i i32x8 = _mm256_cvtps_epi32(clamped_f32x8);
171
+ __m128i u16x8 = _mm_packus_epi32(_mm256_castsi256_si128(i32x8), _mm256_extracti128_si256(i32x8, 1));
172
+ return _mm_packus_epi16(u16x8, _mm_setzero_si128());
173
+ }
174
+
175
+ /** @brief Convert 16x e4m3 → 16x bf16 via arithmetic + small LUT for subnormals (AVX2).
176
+ * E4M3 format: S EEEE MMM (bias=7). BF16: S EEEEEEEE MMMMMMM (bias=127).
177
+ * Normal values: BF16 = sign | ((lower7 << 4) + 0x3C00).
178
+ * Subnormals (8 values): looked up via vpshufb from an 8-entry LUT.
179
+ * Handles all corner cases: zero, subnormals, normals, and NaN. */
180
+ NK_INTERNAL __m256i nk_e4m3x16_to_bf16x16_haswell_(__m128i e4m3x16) {
181
+ __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3x16);
182
+ __m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
183
+ __m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
184
+
185
+ // Normal path: BF16 = ((lower7 << 4) + 0x3C00) | (sign << 8)
186
+ __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 4), _mm256_set1_epi16(0x3C00));
187
+ sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
188
+ __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
189
+
190
+ // Subnormal LUT via shuffle_epi8 (8 entries: mantissa 0-7 → BF16)
191
+ // E4M3 subnormal BF16 values: 0x0000, 0x3B00, 0x3B80, 0x3BC0, 0x3C00, 0x3C20, 0x3C40, 0x3C60
192
+ // Split into low bytes and high bytes for reconstruction
193
+ __m256i const lo_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
194
+ 0x60, 0x40, 0x20, 0x00, (char)0xC0, (char)0x80, 0x00, 0x00, //
195
+ 0x60, 0x40, 0x20, 0x00, (char)0xC0, (char)0x80, 0x00, 0x00)); //
196
+ __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
197
+ 0x3C, 0x3C, 0x3C, 0x3C, 0x3B, 0x3B, 0x3B, 0x00, //
198
+ 0x3C, 0x3C, 0x3C, 0x3C, 0x3B, 0x3B, 0x3B, 0x00)); //
199
+
200
+ // Extract mantissa (bits 0-2) as byte indices for shuffle
201
+ __m256i byte_idx_i8x32 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi8(0x07));
202
+ __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
203
+ __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
204
+
205
+ // Combine low and high bytes into 16-bit values
206
+ __m256i subnorm_abs_i16x16 = _mm256_or_si256( //
207
+ _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
208
+ _mm256_slli_epi16(hi_bytes_i8x32, 8)); //
209
+ __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
210
+
211
+ // Blend: if exponent == 0, use subnormal result; else use normal result
212
+ __m256i exp_bits_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x78));
213
+ __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
214
+ __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
215
+
216
+ // Handle NaN: E4M3 index 127 (0x7F) → BF16 NaN (0x7FC0)
217
+ __m256i is_nan_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7F));
218
+ __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7FC0));
219
+ return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
220
+ }
221
+
222
+ /** @brief Convert 16x e5m2 → 16x bf16 via arithmetic + small LUT for subnormals (AVX2).
223
+ * E5M2 format: S EEEEE MM (bias=15). BF16: S EEEEEEEE MMMMMMM (bias=127).
224
+ * Normal values: BF16 = sign | ((lower7 << 5) + 0x3800).
225
+ * Subnormals (4 values): looked up via vpshufb from a 4-entry LUT.
226
+ * Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
227
+ NK_INTERNAL __m256i nk_e5m2x16_to_bf16x16_haswell_(__m128i e5m2x16) {
228
+ __m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2x16);
229
+ __m256i sign_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16((short)0x80));
230
+ __m256i lower7_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F));
231
+
232
+ // Normal path: BF16 = ((lower7 << 5) + 0x3800) | (sign << 8)
233
+ __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 5), _mm256_set1_epi16(0x3800));
234
+ sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
235
+ __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
236
+
237
+ // Subnormal LUT via shuffle_epi8 (4 entries: mantissa 0-3 → BF16)
238
+ // E5M2 subnormal BF16 values: 0x0000, 0x3780, 0x3800, 0x3840
239
+ __m256i const lo_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
240
+ 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, (char)0x80, 0x00, //
241
+ 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, (char)0x80, 0x00)); //
242
+ __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
243
+ 0x00, 0x00, 0x00, 0x00, 0x38, 0x38, 0x37, 0x00, //
244
+ 0x00, 0x00, 0x00, 0x00, 0x38, 0x38, 0x37, 0x00)); //
245
+
246
+ // Extract mantissa (bits 0-1) as byte indices for shuffle
247
+ __m256i byte_idx_i8x32 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi8(0x03));
248
+ __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
249
+ __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
250
+
251
+ // Combine low and high bytes into 16-bit values
252
+ __m256i subnorm_abs_i16x16 = _mm256_or_si256( //
253
+ _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
254
+ _mm256_slli_epi16(hi_bytes_i8x32, 8)); //
255
+ __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
256
+
257
+ // Blend: if exponent == 0, use subnormal result; else use normal result
258
+ __m256i exp_bits_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7C));
259
+ __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
260
+ __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
261
+
262
+ // Handle Inf (0x7C) and NaN (0x7D-0x7F)
263
+ __m256i is_inf_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7C));
264
+ __m256i is_nan_i16x16 = _mm256_cmpgt_epi16(lower7_i16x16, _mm256_set1_epi16(0x7C));
265
+ __m256i inf_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7F80));
266
+ __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7FC0));
267
+ result_i16x16 = _mm256_blendv_epi8(result_i16x16, inf_i16x16, is_inf_i16x16);
268
+ return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
269
+ }
270
+
271
+ /** @brief Convert 16x e4m3 → 16x f16 via arithmetic + small LUT for subnormals (AVX2).
272
+ * E4M3 format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
273
+ * Normal values: F16 = sign | ((lower7 << 7) + 0x2000).
274
+ * Subnormals (8 values): looked up via vpshufb from an 8-entry LUT.
275
+ * Handles all corner cases: zero, subnormals, normals, and NaN. */
276
+ NK_INTERNAL __m256i nk_e4m3x16_to_f16x16_haswell_(__m128i e4m3x16) {
277
+ __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3x16);
278
+ __m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
279
+ __m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
280
+
281
+ // Normal path: F16 = ((lower7 << 7) + 0x2000) | (sign << 8)
282
+ __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 7), _mm256_set1_epi16(0x2000));
283
+ sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
284
+ __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
285
+
286
+ // Subnormal LUT via shuffle_epi8 (8 entries: mantissa 0-7 → F16)
287
+ // E4M3 subnormal F16 values: 0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300
288
+ // All low bytes are 0x00, high bytes: 0x00, 0x18, 0x1C, 0x1E, 0x20, 0x21, 0x22, 0x23
289
+ // _mm_set_epi8 order: b15..u1 (unused), b7=idx7, b6=idx6, ..., b0=idx0
290
+ __m256i const lo_lut_i8x32 = _mm256_setzero_si256();
291
+ __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
292
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
293
+ 0x23, 0x22, 0x21, 0x20, 0x1E, 0x1C, 0x18, 0x00)); //
294
+
295
+ // Extract mantissa (bits 0-2) as byte indices for shuffle
296
+ __m256i byte_idx_i8x32 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi8(0x07));
297
+ __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
298
+ __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
299
+
300
+ // Combine low and high bytes into 16-bit values
301
+ __m256i subnorm_abs_i16x16 = _mm256_or_si256( //
302
+ _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
303
+ _mm256_slli_epi16(hi_bytes_i8x32, 8)); //
304
+ __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
305
+
306
+ // Blend: if exponent == 0, use subnormal result; else use normal result
307
+ __m256i exp_bits_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x78));
308
+ __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
309
+ __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
310
+
311
+ // Handle NaN: E4M3 index 127 (0x7F) → F16 NaN (0x7E00)
312
+ __m256i is_nan_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7F));
313
+ __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7E00));
314
+ return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
315
+ }
316
+
317
+ /** @brief Convert 16x e5m2 → 16x f16 via simple bit shift (AVX2).
318
+ * E5M2 format: S EEEEE MM (bias=15). F16: S EEEEE MMMMMMMMMM (bias=15).
319
+ * Same exponent bias means F16 = (lower7 << 8) | (sign << 15).
320
+ * Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
321
+ NK_INTERNAL __m256i nk_e5m2x16_to_f16x16_haswell_(__m128i e5m2x16) {
322
+ __m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2x16);
323
+ __m256i sign_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16((short)0x80));
324
+ __m256i lower7_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F));
325
+
326
+ // F16 = (lower7 << 8) | (sign << 15)
327
+ // Works for all cases: subnormals, normals, infinity, and NaN
328
+ __m256i result_i16x16 = _mm256_slli_epi16(lower7_i16x16, 8);
329
+ sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
330
+ return _mm256_or_si256(result_i16x16, sign_i16x16);
331
+ }
332
+
333
+ /** @brief Convert 8x e4m3 → 8x f32 via bit manipulation (AVX2).
334
+ * E4M3 format: S EEEE MMM (bias=7). F32: sign<<31, (exp+120)<<23, mant<<20.
335
+ * Subnormals (exp=0): value = mantissa × 2⁽¹⁻⁷⁾ × 2⁻³ = mantissa ÷ 512. */
336
+ NK_INTERNAL __m256 nk_e4m3x8_to_f32x8_haswell_(__m128i e4m3_i8x8) {
337
+ __m256i e4m3_i32x8 = _mm256_cvtepu8_epi32(e4m3_i8x8);
338
+
339
+ // Extract fields
340
+ __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e4m3_i32x8, 3), _mm256_set1_epi32(0x0F));
341
+ __m256i mant_i32x8 = _mm256_and_si256(e4m3_i32x8, _mm256_set1_epi32(0x07));
342
+
343
+ // Build F32 sign bit
344
+ __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e4m3_i32x8, 7), 31);
345
+
346
+ // Normal path: sign | ((exp+120)<<23) | (mant<<20)
347
+ __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(120)), 23);
348
+ __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 20);
349
+ __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
350
+
351
+ // Subnormal path: value = mantissa / 512.0f, then apply sign
352
+ __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 512.0f));
353
+ __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
354
+
355
+ // Blend: if exp==0, use subnormal result; otherwise use normal bits
356
+ __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
357
+ __m256 result = _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8,
358
+ _mm256_castsi256_ps(exp_zero_mask));
359
+
360
+ // NaN path: E4M3FN has NaN only when exp=15 AND mant=7 (0x7F or 0xFF)
361
+ __m256i is_nan_mask = _mm256_and_si256( //
362
+ _mm256_cmpeq_epi32(exp_i32x8, _mm256_set1_epi32(15)), //
363
+ _mm256_cmpeq_epi32(mant_i32x8, _mm256_set1_epi32(7))); //
364
+ __m256i nan_bits = _mm256_or_si256(f32_sign_i32x8, _mm256_set1_epi32(0x7FC00000)); // F32 quiet NaN
365
+ return _mm256_blendv_ps(result, _mm256_castsi256_ps(nan_bits), _mm256_castsi256_ps(is_nan_mask));
366
+ }
367
+
368
+ /** @brief Convert 8x e5m2 → 8x f32 via bit manipulation (AVX2).
369
+ * E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mant<<21.
370
+ * Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁵⁾ × 2⁻² = mantissa ÷ 65536. */
371
+ NK_INTERNAL __m256 nk_e5m2x8_to_f32x8_haswell_(__m128i e5m2_i8x8) {
372
+ __m256i e5m2_i32x8 = _mm256_cvtepu8_epi32(e5m2_i8x8);
373
+
374
+ // Extract fields
375
+ __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e5m2_i32x8, 2), _mm256_set1_epi32(0x1F));
376
+ __m256i mant_i32x8 = _mm256_and_si256(e5m2_i32x8, _mm256_set1_epi32(0x03));
377
+
378
+ // Build F32 sign bit
379
+ __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e5m2_i32x8, 7), 31);
380
+
381
+ // Normal path: sign | ((exp+112)<<23) | (mant<<21)
382
+ __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(112)), 23);
383
+ __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 21);
384
+ __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
385
+
386
+ // Subnormal path: value = mantissa / 65536.0f, then apply sign
387
+ __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 65536.0f));
388
+ __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
389
+
390
+ // Blend: if exp==0, use subnormal result; otherwise use normal bits
391
+ __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
392
+ return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8, _mm256_castsi256_ps(exp_zero_mask));
393
+ }
394
+
395
+ /** @brief Convert 8x f32 → 8x e4m3 via bit manipulation (AVX2).
396
+ * E4M3 format: S EEEE MMM (bias=7). Handles normal, subnormal, and overflow cases.
397
+ * Subnormals (f32_exp ≤ 120): mantissa = round(abs_f32 * 512), clamped to [0,7]. */
398
+ NK_INTERNAL __m128i nk_f32x8_to_e4m3x8_haswell_(__m256 f32x8) {
399
+ __m256i bits_i32x8 = _mm256_castps_si256(f32x8);
400
+ __m256i sign_i32x8 = _mm256_srli_epi32(bits_i32x8, 31);
401
+ __m256i f32_exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 23), _mm256_set1_epi32(0xFF));
402
+
403
+ // Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
404
+ // RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
405
+ __m256i significand_i32x8 = _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)),
406
+ _mm256_set1_epi32(0x00800000)); // Add implicit 1 bit
407
+ __m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(significand_i32x8, 20), _mm256_set1_epi32(1));
408
+ __m256i rounding_bias_i32x8 = _mm256_add_epi32(_mm256_set1_epi32(0x0007FFFF), lsb_i32x8);
409
+ __m256i rounded_sig_i32x8 = _mm256_add_epi32(significand_i32x8, rounding_bias_i32x8);
410
+ __m256i carry_i32x8 = _mm256_srli_epi32(rounded_sig_i32x8, 24); // Carry into exponent if bit 24 set
411
+ __m256i f32_mantissa_i32x8 = _mm256_and_si256(_mm256_srli_epi32(rounded_sig_i32x8, 20), _mm256_set1_epi32(0x07));
412
+ // If carry, mantissa becomes 0 (we rounded up to next power of 2)
413
+ f32_mantissa_i32x8 = _mm256_andnot_si256(_mm256_slli_epi32(carry_i32x8, 31), f32_mantissa_i32x8);
414
+ __m256i e4m3_exp_i32x8 = _mm256_sub_epi32(_mm256_add_epi32(f32_exp_i32x8, carry_i32x8), _mm256_set1_epi32(120));
415
+
416
+ // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 15)
417
+ __m256i is_subnormal_i32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32(1), e4m3_exp_i32x8);
418
+ __m256i overflow_i32x8 = _mm256_cmpgt_epi32(e4m3_exp_i32x8, _mm256_set1_epi32(15));
419
+
420
+ // Normal path: clamp exp to [1,15], extract mantissa bits
421
+ // e4m3FN quirk: exp=15 with mantissa=7 is NaN (0x7F), so clamp mantissa to 6 when exp=15.
422
+ __m256i clamped_exp_i32x8 = _mm256_max_epi32(e4m3_exp_i32x8, _mm256_set1_epi32(1));
423
+ clamped_exp_i32x8 = _mm256_min_epi32(clamped_exp_i32x8, _mm256_set1_epi32(15));
424
+ __m256i is_max_exp_i32x8 = _mm256_cmpeq_epi32(clamped_exp_i32x8, _mm256_set1_epi32(15));
425
+ __m256i max_mantissa_i32x8 = _mm256_blendv_epi8(_mm256_set1_epi32(7), _mm256_set1_epi32(6), is_max_exp_i32x8);
426
+ __m256i normal_mantissa_i32x8 = _mm256_min_epi32(f32_mantissa_i32x8, max_mantissa_i32x8);
427
+ normal_mantissa_i32x8 = _mm256_blendv_epi8(normal_mantissa_i32x8, _mm256_set1_epi32(0x06), overflow_i32x8);
428
+ __m256i normal_e4m3_i32x8 = _mm256_or_si256(
429
+ _mm256_slli_epi32(sign_i32x8, 7),
430
+ _mm256_or_si256(_mm256_slli_epi32(clamped_exp_i32x8, 3), normal_mantissa_i32x8));
431
+
432
+ // Subnormal path: mantissa = round(abs_f32 * 512)
433
+ // If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
434
+ __m256 abs_f32x8 = _mm256_and_ps(f32x8, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
435
+ __m256 scaled_f32x8 = _mm256_mul_ps(abs_f32x8, _mm256_set1_ps(512.0f));
436
+ __m256i subnorm_mantissa_i32x8 = _mm256_cvtps_epi32(scaled_f32x8);
437
+ __m256i promotes_to_normal_i32x8 = _mm256_cmpgt_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(7));
438
+ subnorm_mantissa_i32x8 = _mm256_min_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(7));
439
+ subnorm_mantissa_i32x8 = _mm256_max_epi32(subnorm_mantissa_i32x8, _mm256_setzero_si256());
440
+ __m256i subnorm_e4m3_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 7), subnorm_mantissa_i32x8);
441
+ // When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
442
+ __m256i first_normal_e4m3_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 7), _mm256_set1_epi32(0x08));
443
+ subnorm_e4m3_i32x8 = _mm256_blendv_epi8(subnorm_e4m3_i32x8, first_normal_e4m3_i32x8, promotes_to_normal_i32x8);
444
+
445
+ // Blend: use subnormal result when exp <= 0, else normal
446
+ __m256i e4m3_i32x8 = _mm256_blendv_epi8(normal_e4m3_i32x8, subnorm_e4m3_i32x8, is_subnormal_i32x8);
447
+
448
+ // Pack 8 i32s to 8 unsigned i8s (use unsigned saturation to preserve values 128-255)
449
+ __m128i low_i32x4 = _mm256_castsi256_si128(e4m3_i32x8);
450
+ __m128i high_i32x4 = _mm256_extracti128_si256(e4m3_i32x8, 1);
451
+ __m128i packed_i16x8 = _mm_packus_epi32(low_i32x4, high_i32x4);
452
+ __m128i packed_i8x8 = _mm_packus_epi16(packed_i16x8, packed_i16x8);
453
+ return packed_i8x8;
454
+ }
455
+
456
+ /** @brief Convert 8x f32 → 8x e5m2 via bit manipulation (AVX2).
457
+ * E5M2 format: S EEEEE MM (bias=15). Handles normal, subnormal, and overflow cases.
458
+ * Uses RNE (round to nearest even) for mantissa rounding. */
459
+ NK_INTERNAL __m128i nk_f32x8_to_e5m2x8_haswell_(__m256 f32x8) {
460
+ __m256i bits_i32x8 = _mm256_castps_si256(f32x8);
461
+ __m256i sign_i32x8 = _mm256_srli_epi32(bits_i32x8, 31);
462
+ __m256i f32_exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 23), _mm256_set1_epi32(0xFF));
463
+
464
+ // Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
465
+ // RNE trick: add (half - 1 + lsb) where lsb is the bit that will become the new lsb after shift
466
+ __m256i significand_i32x8 = _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)),
467
+ _mm256_set1_epi32(0x00800000)); // Add implicit 1 bit
468
+ __m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(significand_i32x8, 21), _mm256_set1_epi32(1));
469
+ __m256i rounding_bias_i32x8 = _mm256_add_epi32(_mm256_set1_epi32(0x000FFFFF), lsb_i32x8); // half = 0x100000
470
+ __m256i rounded_sig_i32x8 = _mm256_add_epi32(significand_i32x8, rounding_bias_i32x8);
471
+ __m256i carry_i32x8 = _mm256_srli_epi32(rounded_sig_i32x8, 24); // Carry into exponent if bit 24 set
472
+ __m256i f32_mantissa_i32x8 = _mm256_and_si256(_mm256_srli_epi32(rounded_sig_i32x8, 21), _mm256_set1_epi32(0x03));
473
+ // If carry, mantissa becomes 0 (we rounded up to next power of 2)
474
+ f32_mantissa_i32x8 = _mm256_andnot_si256(_mm256_slli_epi32(carry_i32x8, 31), f32_mantissa_i32x8);
475
+ __m256i e5m2_exp_i32x8 = _mm256_sub_epi32(_mm256_add_epi32(f32_exp_i32x8, carry_i32x8), _mm256_set1_epi32(112));
476
+
477
+ // Detect subnormal (exp <= 0) and overflow (exp > 31)
478
+ __m256i is_subnormal_i32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32(1), e5m2_exp_i32x8);
479
+ __m256i overflow_i32x8 = _mm256_cmpgt_epi32(e5m2_exp_i32x8, _mm256_set1_epi32(31));
480
+
481
+ // Normal path: clamp exp to [1,31], on overflow return infinity (exp=31, mantissa=0 = 0x7C)
482
+ __m256i clamped_exp_i32x8 = _mm256_max_epi32(e5m2_exp_i32x8, _mm256_set1_epi32(1));
483
+ clamped_exp_i32x8 = _mm256_min_epi32(clamped_exp_i32x8, _mm256_set1_epi32(31));
484
+ __m256i normal_mantissa_i32x8 = _mm256_blendv_epi8(f32_mantissa_i32x8, _mm256_setzero_si256(), overflow_i32x8);
485
+ __m256i normal_e5m2_i32x8 = _mm256_or_si256(
486
+ _mm256_slli_epi32(sign_i32x8, 7),
487
+ _mm256_or_si256(_mm256_slli_epi32(clamped_exp_i32x8, 2), normal_mantissa_i32x8));
488
+
489
+ // Subnormal path: mantissa = round(abs_f32 * 65536)
490
+ // If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
491
+ __m256 abs_f32x8 = _mm256_and_ps(f32x8, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
492
+ __m256 scaled_f32x8 = _mm256_mul_ps(abs_f32x8, _mm256_set1_ps(65536.0f));
493
+ __m256i subnorm_mantissa_i32x8 = _mm256_cvtps_epi32(scaled_f32x8);
494
+ __m256i promotes_to_normal_i32x8 = _mm256_cmpgt_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(3));
495
+ subnorm_mantissa_i32x8 = _mm256_min_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(3));
496
+ subnorm_mantissa_i32x8 = _mm256_max_epi32(subnorm_mantissa_i32x8, _mm256_setzero_si256());
497
+ __m256i subnorm_e5m2_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 7), subnorm_mantissa_i32x8);
498
+ // When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
499
+ __m256i first_normal_e5m2_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 7), _mm256_set1_epi32(0x04));
500
+ subnorm_e5m2_i32x8 = _mm256_blendv_epi8(subnorm_e5m2_i32x8, first_normal_e5m2_i32x8, promotes_to_normal_i32x8);
501
+
502
+ // Blend: use subnormal result when exp <= 0
503
+ __m256i e5m2_i32x8 = _mm256_blendv_epi8(normal_e5m2_i32x8, subnorm_e5m2_i32x8, is_subnormal_i32x8);
504
+
505
+ // Pack 8 i32s to 8 unsigned i8s (use unsigned saturation to preserve values 128-255)
506
+ __m128i low_i32x4 = _mm256_castsi256_si128(e5m2_i32x8);
507
+ __m128i high_i32x4 = _mm256_extracti128_si256(e5m2_i32x8, 1);
508
+ __m128i packed_i16x8 = _mm_packus_epi32(low_i32x4, high_i32x4);
509
+ __m128i packed_i8x8 = _mm_packus_epi16(packed_i16x8, packed_i16x8);
510
+ return packed_i8x8;
511
+ }
512
+
513
+ /** @brief Convert 8x e2m3 → 8x f32 via bit manipulation (AVX2).
514
+ * E2M3 format: S EE MMM (bias=1). F32: sign<<31, (exp+126)<<23, mantissa<<20.
515
+ * Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁾ × 2⁻³ = mantissa ÷ 8. */
516
+ NK_INTERNAL __m256 nk_e2m3x8_to_f32x8_haswell_(__m128i e2m3_i8x8) {
517
+ __m256i e2m3_i32x8 = _mm256_cvtepu8_epi32(e2m3_i8x8);
518
+
519
+ // Extract fields (only 6 bits used: S EE MMM)
520
+ __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e2m3_i32x8, 3), _mm256_set1_epi32(0x03));
521
+ __m256i mant_i32x8 = _mm256_and_si256(e2m3_i32x8, _mm256_set1_epi32(0x07));
522
+
523
+ // Build F32 sign bit
524
+ __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e2m3_i32x8, 5), 31);
525
+
526
+ // Normal path: sign | ((exp+126)<<23) | (mant<<20)
527
+ __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(126)), 23);
528
+ __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 20);
529
+ __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
530
+
531
+ // Subnormal path: value = mantissa / 8.0f, then apply sign
532
+ __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 8.0f));
533
+ __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
534
+
535
+ // Blend: if exp==0, use subnormal result; otherwise use normal bits
536
+ __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
537
+ return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8, _mm256_castsi256_ps(exp_zero_mask));
538
+ }
539
+
540
+ /** @brief Convert 8x e3m2 → 8x f32 via bit manipulation (AVX2).
541
+ * E3M2 format: S EEE MM (bias=3). F32: sign<<31, (exp+124)<<23, mantissa<<21.
542
+ * Subnormals (exp=0): value = mantissa × 2⁽¹⁻³⁾ × 2⁻² = mantissa ÷ 16. */
543
+ NK_INTERNAL __m256 nk_e3m2x8_to_f32x8_haswell_(__m128i e3m2_i8x8) {
544
+ __m256i e3m2_i32x8 = _mm256_cvtepu8_epi32(e3m2_i8x8);
545
+
546
+ // Extract fields (only 6 bits used: S EEE MM)
547
+ __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e3m2_i32x8, 2), _mm256_set1_epi32(0x07));
548
+ __m256i mant_i32x8 = _mm256_and_si256(e3m2_i32x8, _mm256_set1_epi32(0x03));
549
+
550
+ // Build F32 sign bit
551
+ __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e3m2_i32x8, 5), 31);
552
+
553
+ // Normal path: sign | ((exp+124)<<23) | (mant<<21)
554
+ __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(124)), 23);
555
+ __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 21);
556
+ __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
557
+
558
+ // Subnormal path: value = mantissa / 16.0f, then apply sign
559
+ __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 16.0f));
560
+ __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
561
+
562
+ // Blend: if exp==0, use subnormal result; otherwise use normal bits
563
+ __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
564
+ return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8, _mm256_castsi256_ps(exp_zero_mask));
565
+ }
566
+
567
+ /** @brief Convert 8x f32 → 8x e2m3 via bit manipulation (AVX2).
568
+ * E2M3 format: S EE MMM (bias=1). Handles normal, subnormal, and overflow cases.
569
+ * Subnormals (f32_exp ≤ 126): mantissa = round(abs_f32 * 8), clamped to [0,7]. */
570
+ NK_INTERNAL __m128i nk_f32x8_to_e2m3x8_haswell_(__m256 f32x8) {
571
+ __m256i bits_i32x8 = _mm256_castps_si256(f32x8);
572
+ __m256i sign_i32x8 = _mm256_srli_epi32(bits_i32x8, 31);
573
+ __m256i f32_exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 23), _mm256_set1_epi32(0xFF));
574
+
575
+ // Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
576
+ __m256i significand_i32x8 = _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)),
577
+ _mm256_set1_epi32(0x00800000)); // Add implicit 1 bit
578
+ __m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(significand_i32x8, 20), _mm256_set1_epi32(1));
579
+ __m256i rounding_bias_i32x8 = _mm256_add_epi32(_mm256_set1_epi32(0x0007FFFF), lsb_i32x8);
580
+ __m256i rounded_sig_i32x8 = _mm256_add_epi32(significand_i32x8, rounding_bias_i32x8);
581
+ __m256i carry_i32x8 = _mm256_srli_epi32(rounded_sig_i32x8, 24); // Carry into exponent if bit 24 set
582
+ __m256i f32_mantissa_i32x8 = _mm256_and_si256(_mm256_srli_epi32(rounded_sig_i32x8, 20), _mm256_set1_epi32(0x07));
583
+ // If carry, mantissa becomes 0 (we rounded up to next power of 2)
584
+ f32_mantissa_i32x8 = _mm256_andnot_si256(_mm256_slli_epi32(carry_i32x8, 31), f32_mantissa_i32x8);
585
+ __m256i e2m3_exp_i32x8 = _mm256_sub_epi32(_mm256_add_epi32(f32_exp_i32x8, carry_i32x8), _mm256_set1_epi32(126));
586
+
587
+ // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 3)
588
+ __m256i is_subnormal_i32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32(1), e2m3_exp_i32x8);
589
+ __m256i overflow_i32x8 = _mm256_cmpgt_epi32(e2m3_exp_i32x8, _mm256_set1_epi32(3));
590
+
591
+ // Normal path: clamp exp to [1,3], extract mantissa bits
592
+ __m256i clamped_exp_i32x8 = _mm256_max_epi32(e2m3_exp_i32x8, _mm256_set1_epi32(1));
593
+ clamped_exp_i32x8 = _mm256_min_epi32(clamped_exp_i32x8, _mm256_set1_epi32(3));
594
+ __m256i normal_mantissa_i32x8 = _mm256_blendv_epi8(f32_mantissa_i32x8, _mm256_set1_epi32(0x07), overflow_i32x8);
595
+ __m256i normal_e2m3_i32x8 = _mm256_or_si256(
596
+ _mm256_slli_epi32(sign_i32x8, 5),
597
+ _mm256_or_si256(_mm256_slli_epi32(clamped_exp_i32x8, 3), normal_mantissa_i32x8));
598
+
599
+ // Subnormal path: mantissa = round(abs_f32 * 8)
600
+ // If mantissa rounds to 8 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x08
601
+ __m256 abs_f32x8 = _mm256_and_ps(f32x8, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
602
+ __m256 scaled_f32x8 = _mm256_mul_ps(abs_f32x8, _mm256_set1_ps(8.0f));
603
+ __m256i subnorm_mantissa_i32x8 = _mm256_cvtps_epi32(scaled_f32x8);
604
+ __m256i promotes_to_normal_i32x8 = _mm256_cmpgt_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(7));
605
+ subnorm_mantissa_i32x8 = _mm256_min_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(7));
606
+ subnorm_mantissa_i32x8 = _mm256_max_epi32(subnorm_mantissa_i32x8, _mm256_setzero_si256());
607
+ __m256i subnorm_e2m3_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 5), subnorm_mantissa_i32x8);
608
+ // When mantissa rounds to 8, use first normal value (0x08) instead of clamped subnormal
609
+ __m256i first_normal_e2m3_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 5), _mm256_set1_epi32(0x08));
610
+ subnorm_e2m3_i32x8 = _mm256_blendv_epi8(subnorm_e2m3_i32x8, first_normal_e2m3_i32x8, promotes_to_normal_i32x8);
611
+
612
+ // Blend: use subnormal result when exp <= 0, else normal
613
+ __m256i e2m3_i32x8 = _mm256_blendv_epi8(normal_e2m3_i32x8, subnorm_e2m3_i32x8, is_subnormal_i32x8);
614
+
615
+ // Pack 8 i32s to 8 unsigned i8s (use unsigned saturation to preserve values 128-255)
616
+ __m128i low_i32x4 = _mm256_castsi256_si128(e2m3_i32x8);
617
+ __m128i high_i32x4 = _mm256_extracti128_si256(e2m3_i32x8, 1);
618
+ __m128i packed_i16x8 = _mm_packus_epi32(low_i32x4, high_i32x4);
619
+ __m128i packed_i8x8 = _mm_packus_epi16(packed_i16x8, packed_i16x8);
620
+ return packed_i8x8;
621
+ }
622
+
623
+ /** @brief Convert 8x f32 → 8x e3m2 via bit manipulation (AVX2).
624
+ * E3M2 format: S EEE MM (bias=3). Handles normal, subnormal, and overflow cases.
625
+ * Subnormals (f32_exp ≤ 124): mantissa = round(abs_f32 * 16), clamped to [0,3]. */
626
+ NK_INTERNAL __m128i nk_f32x8_to_e3m2x8_haswell_(__m256 f32x8) {
627
+ __m256i bits_i32x8 = _mm256_castps_si256(f32x8);
628
+ __m256i sign_i32x8 = _mm256_srli_epi32(bits_i32x8, 31);
629
+ __m256i f32_exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(bits_i32x8, 23), _mm256_set1_epi32(0xFF));
630
+
631
+ // Round mantissa from 23 to 2 bits using RNE (round to nearest, ties to even)
632
+ __m256i significand_i32x8 = _mm256_or_si256(_mm256_and_si256(bits_i32x8, _mm256_set1_epi32(0x007FFFFF)),
633
+ _mm256_set1_epi32(0x00800000)); // Add implicit 1 bit
634
+ __m256i lsb_i32x8 = _mm256_and_si256(_mm256_srli_epi32(significand_i32x8, 21), _mm256_set1_epi32(1));
635
+ __m256i rounding_bias_i32x8 = _mm256_add_epi32(_mm256_set1_epi32(0x000FFFFF), lsb_i32x8);
636
+ __m256i rounded_sig_i32x8 = _mm256_add_epi32(significand_i32x8, rounding_bias_i32x8);
637
+ __m256i carry_i32x8 = _mm256_srli_epi32(rounded_sig_i32x8, 24); // Carry into exponent if bit 24 set
638
+ __m256i f32_mantissa_i32x8 = _mm256_and_si256(_mm256_srli_epi32(rounded_sig_i32x8, 21), _mm256_set1_epi32(0x03));
639
+ // If carry, mantissa becomes 0 (we rounded up to next power of 2)
640
+ f32_mantissa_i32x8 = _mm256_andnot_si256(_mm256_slli_epi32(carry_i32x8, 31), f32_mantissa_i32x8);
641
+ __m256i e3m2_exp_i32x8 = _mm256_sub_epi32(_mm256_add_epi32(f32_exp_i32x8, carry_i32x8), _mm256_set1_epi32(124));
642
+
643
+ // Detect underflow (exp <= 0, maps to subnormal/zero) and overflow (exp > 7)
644
+ __m256i is_subnormal_i32x8 = _mm256_cmpgt_epi32(_mm256_set1_epi32(1), e3m2_exp_i32x8);
645
+ __m256i overflow_i32x8 = _mm256_cmpgt_epi32(e3m2_exp_i32x8, _mm256_set1_epi32(7));
646
+
647
+ // Normal path: clamp exp to [1,7], extract mantissa bits
648
+ __m256i clamped_exp_i32x8 = _mm256_max_epi32(e3m2_exp_i32x8, _mm256_set1_epi32(1));
649
+ clamped_exp_i32x8 = _mm256_min_epi32(clamped_exp_i32x8, _mm256_set1_epi32(7));
650
+ __m256i normal_mantissa_i32x8 = _mm256_blendv_epi8(f32_mantissa_i32x8, _mm256_set1_epi32(0x03), overflow_i32x8);
651
+ __m256i normal_e3m2_i32x8 = _mm256_or_si256(
652
+ _mm256_slli_epi32(sign_i32x8, 5),
653
+ _mm256_or_si256(_mm256_slli_epi32(clamped_exp_i32x8, 2), normal_mantissa_i32x8));
654
+
655
+ // Subnormal path: mantissa = round(abs_f32 * 16)
656
+ // If mantissa rounds to 4 or higher, promote to first normal (exp_field=1, mantissa=0) = 0x04
657
+ __m256 abs_f32x8 = _mm256_and_ps(f32x8, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
658
+ __m256 scaled_f32x8 = _mm256_mul_ps(abs_f32x8, _mm256_set1_ps(16.0f));
659
+ __m256i subnorm_mantissa_i32x8 = _mm256_cvtps_epi32(scaled_f32x8);
660
+ __m256i promotes_to_normal_i32x8 = _mm256_cmpgt_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(3));
661
+ subnorm_mantissa_i32x8 = _mm256_min_epi32(subnorm_mantissa_i32x8, _mm256_set1_epi32(3));
662
+ subnorm_mantissa_i32x8 = _mm256_max_epi32(subnorm_mantissa_i32x8, _mm256_setzero_si256());
663
+ __m256i subnorm_e3m2_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 5), subnorm_mantissa_i32x8);
664
+ // When mantissa rounds to 4, use first normal value (0x04) instead of clamped subnormal
665
+ __m256i first_normal_e3m2_i32x8 = _mm256_or_si256(_mm256_slli_epi32(sign_i32x8, 5), _mm256_set1_epi32(0x04));
666
+ subnorm_e3m2_i32x8 = _mm256_blendv_epi8(subnorm_e3m2_i32x8, first_normal_e3m2_i32x8, promotes_to_normal_i32x8);
667
+
668
+ // Blend: use subnormal result when exp <= 0
669
+ __m256i e3m2_i32x8 = _mm256_blendv_epi8(normal_e3m2_i32x8, subnorm_e3m2_i32x8, is_subnormal_i32x8);
670
+
671
+ // Pack 8 i32s to 8 unsigned i8s (use unsigned saturation to preserve values 128-255)
672
+ __m128i low_i32x4 = _mm256_castsi256_si128(e3m2_i32x8);
673
+ __m128i high_i32x4 = _mm256_extracti128_si256(e3m2_i32x8, 1);
674
+ __m128i packed_i16x8 = _mm_packus_epi32(low_i32x4, high_i32x4);
675
+ __m128i packed_i8x8 = _mm_packus_epi16(packed_i16x8, packed_i16x8);
676
+ return packed_i8x8;
677
+ }
678
+
679
+ #pragma endregion - Vectorized Conversions
680
+
681
+ #pragma region - Converting Loads and Stores
682
+
683
+ /** @brief Full load for f16 elements (8) with conversion to f32 via F16C. */
684
+ NK_INTERNAL void nk_load_f16x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
685
+ dst->ymm_ps = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)src));
686
+ }
687
+
688
+ /** @brief Partial load for f16 elements (up to 8) with conversion to f32 via F16C. */
689
+ NK_INTERNAL void nk_partial_load_f16x8_to_f32x8_haswell_(nk_f16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
690
+ nk_b128_vec_t vec;
691
+ nk_partial_load_b16x8_serial_(src, &vec, n);
692
+ dst->ymm_ps = _mm256_cvtph_ps(vec.xmm);
693
+ }
694
+
695
+ /** @brief Full load for bf16 elements (8) with conversion to f32. */
696
+ NK_INTERNAL void nk_load_bf16x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
697
+ dst->ymm_ps = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)src));
698
+ }
699
+
700
+ /** @brief Partial load for bf16 elements (up to 8) with conversion to f32. */
701
+ NK_INTERNAL void nk_partial_load_bf16x8_to_f32x8_haswell_(nk_bf16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
702
+ nk_b128_vec_t vec;
703
+ nk_partial_load_b16x8_serial_(src, &vec, n);
704
+ dst->ymm_ps = nk_bf16x8_to_f32x8_haswell_(vec.xmm);
705
+ }
706
+
707
+ /** @brief Full load for e4m3 elements (8) with conversion to f32. */
708
+ NK_INTERNAL void nk_load_e4m3x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
709
+ dst->ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)src));
710
+ }
711
+
712
+ /** @brief Partial load for e4m3 elements (up to 8) with conversion to f32. */
713
+ NK_INTERNAL void nk_partial_load_e4m3x8_to_f32x8_haswell_(nk_e4m3_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
714
+ nk_b64_vec_t vec;
715
+ nk_partial_load_b8x8_serial_(src, &vec, n);
716
+ dst->ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
717
+ }
718
+
719
+ /** @brief Full load for e5m2 elements (8) with conversion to f32. */
720
+ NK_INTERNAL void nk_load_e5m2x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
721
+ dst->ymm_ps = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)src));
722
+ }
723
+
724
+ /** @brief Partial load for e5m2 elements (up to 8) with conversion to f32. */
725
+ NK_INTERNAL void nk_partial_load_e5m2x8_to_f32x8_haswell_(nk_e5m2_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
726
+ nk_b64_vec_t vec;
727
+ nk_partial_load_b8x8_serial_(src, &vec, n);
728
+ dst->ymm_ps = nk_e5m2x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
729
+ }
730
+
731
+ /** @brief Full load for e2m3 elements (8) with conversion to f32. */
732
+ NK_INTERNAL void nk_load_e2m3x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
733
+ dst->ymm_ps = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)src));
734
+ }
735
+
736
+ /** @brief Partial load for e2m3 elements (up to 8) with conversion to f32. */
737
+ NK_INTERNAL void nk_partial_load_e2m3x8_to_f32x8_haswell_(nk_e2m3_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
738
+ nk_b64_vec_t vec;
739
+ nk_partial_load_b8x8_serial_(src, &vec, n);
740
+ dst->ymm_ps = nk_e2m3x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
741
+ }
742
+
743
+ /** @brief Full load for e3m2 elements (8) with conversion to f32. */
744
+ NK_INTERNAL void nk_load_e3m2x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
745
+ dst->ymm_ps = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)src));
746
+ }
747
+
748
+ /** @brief Partial load for e3m2 elements (up to 8) with conversion to f32. */
749
+ NK_INTERNAL void nk_partial_load_e3m2x8_to_f32x8_haswell_(nk_e3m2_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
750
+ nk_b64_vec_t vec;
751
+ nk_partial_load_b8x8_serial_(src, &vec, n);
752
+ dst->ymm_ps = nk_e3m2x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
753
+ }
754
+
755
+ /** @brief Partial load for i8 elements (up to 8) with conversion to f32. */
756
+ NK_INTERNAL void nk_partial_load_i8x8_to_f32x8_haswell_(nk_i8_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
757
+ nk_b64_vec_t vec;
758
+ nk_partial_load_b8x8_serial_(src, &vec, n);
759
+ dst->ymm_ps = nk_i8x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
760
+ }
761
+
762
+ /** @brief Partial load for u8 elements (up to 8) with conversion to f32. */
763
+ NK_INTERNAL void nk_partial_load_u8x8_to_f32x8_haswell_(nk_u8_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
764
+ nk_b64_vec_t vec;
765
+ nk_partial_load_b8x8_serial_(src, &vec, n);
766
+ dst->ymm_ps = nk_u8x8_to_f32x8_haswell_(_mm_cvtsi64_si128(vec.u64));
767
+ }
768
+
769
+ /** @brief Partial load for i16 elements (up to 8) with conversion to f32. */
770
+ NK_INTERNAL void nk_partial_load_i16x8_to_f32x8_haswell_(nk_i16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
771
+ nk_b128_vec_t vec;
772
+ nk_partial_load_b16x8_serial_(src, &vec, n);
773
+ dst->ymm_ps = nk_i16x8_to_f32x8_haswell_(vec.xmm);
774
+ }
775
+
776
+ /** @brief Partial load for u16 elements (up to 8) with conversion to f32. */
777
+ NK_INTERNAL void nk_partial_load_u16x8_to_f32x8_haswell_(nk_u16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
778
+ nk_b128_vec_t vec;
779
+ nk_partial_load_b16x8_serial_(src, &vec, n);
780
+ dst->ymm_ps = nk_u16x8_to_f32x8_haswell_(vec.xmm);
781
+ }
782
+
783
+ /** @brief Partial load for i32 elements (up to 8) with conversion to f32. */
784
+ NK_INTERNAL void nk_partial_load_i32x8_to_f32x8_haswell_(nk_i32_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
785
+ nk_b256_vec_t vec;
786
+ nk_partial_load_b32x8_serial_(src, &vec, n);
787
+ dst->ymm_ps = nk_i32x8_to_f32x8_haswell_(vec.ymm);
788
+ }
789
+
790
+ /** @brief Partial load for u32 elements (up to 8) with conversion to f32. */
791
+ NK_INTERNAL void nk_partial_load_u32x8_to_f32x8_haswell_(nk_u32_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
792
+ nk_b256_vec_t vec;
793
+ nk_partial_load_b32x8_serial_(src, &vec, n);
794
+ dst->ymm_ps = nk_u32x8_to_f32x8_haswell_(vec.ymm);
795
+ }
796
+
797
+ #pragma endregion - Converting Loads and Stores
798
+
799
+ #pragma region - Public API
800
+
801
+ NK_PUBLIC void nk_cast_haswell(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
802
+ // Same-type fast path
803
+ if (from_type == to_type) {
804
+ nk_size_t size_bits = nk_dtype_bits(from_type);
805
+ if (size_bits > 0) nk_copy_bytes_(to, from, nk_size_divide_round_up_(n * size_bits, NK_BITS_PER_BYTE));
806
+ return;
807
+ }
808
+
809
+ // Supported types: floats (f32, f16, bf16, e4m3, e5m2, e2m3, e3m2) and integers (i8, u8, i16, u16, i32, u32)
810
+ int from_supported = (from_type == nk_f32_k || from_type == nk_f16_k || from_type == nk_bf16_k ||
811
+ from_type == nk_e4m3_k || from_type == nk_e5m2_k || from_type == nk_e2m3_k ||
812
+ from_type == nk_e3m2_k || from_type == nk_i8_k || from_type == nk_u8_k ||
813
+ from_type == nk_i16_k || from_type == nk_u16_k || from_type == nk_i32_k ||
814
+ from_type == nk_u32_k);
815
+ int to_supported = (to_type == nk_f32_k || to_type == nk_f16_k || to_type == nk_bf16_k || to_type == nk_e4m3_k ||
816
+ to_type == nk_e5m2_k || to_type == nk_e2m3_k || to_type == nk_e3m2_k || to_type == nk_i8_k ||
817
+ to_type == nk_u8_k || to_type == nk_i16_k || to_type == nk_u16_k || to_type == nk_i32_k ||
818
+ to_type == nk_u32_k);
819
+ if (!from_supported || !to_supported) {
820
+ nk_cast_serial(from, from_type, n, to, to_type);
821
+ return;
822
+ }
823
+
824
+ // Fall back to serial for i32/u32↔i32/u32 (f32 intermediate loses precision for large values)
825
+ int from_32bit_int = (from_type == nk_i32_k || from_type == nk_u32_k);
826
+ int to_32bit_int = (to_type == nk_i32_k || to_type == nk_u32_k);
827
+ if (from_32bit_int && to_32bit_int) {
828
+ nk_cast_serial(from, from_type, n, to, to_type);
829
+ return;
830
+ }
831
+
832
+ // Byte steps per 8 elements
833
+ nk_size_t from_step = 8 * nk_dtype_bits(from_type) / NK_BITS_PER_BYTE;
834
+ nk_size_t to_step = 8 * nk_dtype_bits(to_type) / NK_BITS_PER_BYTE;
835
+
836
+ nk_u8_t const *from_ptr = (nk_u8_t const *)from;
837
+ nk_u8_t *to_ptr = (nk_u8_t *)to;
838
+ nk_size_t batches = n / 8;
839
+ nk_size_t tail = n % 8;
840
+ nk_b256_vec_t hub;
841
+
842
+ for (nk_size_t idx = 0; idx < batches; ++idx, from_ptr += from_step, to_ptr += to_step) {
843
+ // Upcast to f32x8
844
+ if (from_type == nk_f32_k) hub.ymm_ps = _mm256_loadu_ps((float const *)from_ptr);
845
+ else if (from_type == nk_f16_k) hub.ymm_ps = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const *)from_ptr));
846
+ else if (from_type == nk_bf16_k)
847
+ hub.ymm_ps = nk_bf16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)from_ptr));
848
+ else if (from_type == nk_e4m3_k)
849
+ hub.ymm_ps = nk_e4m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
850
+ else if (from_type == nk_e5m2_k)
851
+ hub.ymm_ps = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
852
+ else if (from_type == nk_e2m3_k)
853
+ hub.ymm_ps = nk_e2m3x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
854
+ else if (from_type == nk_e3m2_k)
855
+ hub.ymm_ps = nk_e3m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
856
+ else if (from_type == nk_i8_k)
857
+ hub.ymm_ps = nk_i8x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
858
+ else if (from_type == nk_u8_k)
859
+ hub.ymm_ps = nk_u8x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)from_ptr));
860
+ else if (from_type == nk_i16_k)
861
+ hub.ymm_ps = nk_i16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)from_ptr));
862
+ else if (from_type == nk_u16_k)
863
+ hub.ymm_ps = nk_u16x8_to_f32x8_haswell_(_mm_loadu_si128((__m128i const *)from_ptr));
864
+ else if (from_type == nk_i32_k)
865
+ hub.ymm_ps = nk_i32x8_to_f32x8_haswell_(_mm256_loadu_si256((__m256i const *)from_ptr));
866
+ else if (from_type == nk_u32_k)
867
+ hub.ymm_ps = nk_u32x8_to_f32x8_haswell_(_mm256_loadu_si256((__m256i const *)from_ptr));
868
+
869
+ // Downcast from f32x8
870
+ if (to_type == nk_f32_k) _mm256_storeu_ps((float *)to_ptr, hub.ymm_ps);
871
+ else if (to_type == nk_f16_k)
872
+ _mm_storeu_si128((__m128i *)to_ptr, _mm256_cvtps_ph(hub.ymm_ps, _MM_FROUND_TO_NEAREST_INT));
873
+ else if (to_type == nk_bf16_k) _mm_storeu_si128((__m128i *)to_ptr, nk_f32x8_to_bf16x8_haswell_(hub.ymm_ps));
874
+ else if (to_type == nk_e4m3_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_e4m3x8_haswell_(hub.ymm_ps));
875
+ else if (to_type == nk_e5m2_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_e5m2x8_haswell_(hub.ymm_ps));
876
+ else if (to_type == nk_e2m3_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_e2m3x8_haswell_(hub.ymm_ps));
877
+ else if (to_type == nk_e3m2_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_e3m2x8_haswell_(hub.ymm_ps));
878
+ else if (to_type == nk_i8_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_i8x8_haswell_(hub.ymm_ps));
879
+ else if (to_type == nk_u8_k) _mm_storel_epi64((__m128i *)to_ptr, nk_f32x8_to_u8x8_haswell_(hub.ymm_ps));
880
+ else if (to_type == nk_i16_k) _mm_storeu_si128((__m128i *)to_ptr, nk_f32x8_to_i16x8_haswell_(hub.ymm_ps));
881
+ else if (to_type == nk_u16_k) _mm_storeu_si128((__m128i *)to_ptr, nk_f32x8_to_u16x8_haswell_(hub.ymm_ps));
882
+ else if (to_type == nk_i32_k) _mm256_storeu_si256((__m256i *)to_ptr, nk_f32x8_to_i32x8_haswell_(hub.ymm_ps));
883
+ else if (to_type == nk_u32_k) _mm256_storeu_si256((__m256i *)to_ptr, nk_f32x8_to_u32x8_haswell_(hub.ymm_ps));
884
+ }
885
+
886
+ // Handle tail with partial loads/stores
887
+ if (tail) {
888
+ // Upcast tail to f32x8
889
+ if (from_type == nk_f32_k) nk_partial_load_b32x8_serial_(from_ptr, &hub, tail);
890
+ else if (from_type == nk_f16_k) nk_partial_load_f16x8_to_f32x8_haswell_((nk_f16_t const *)from_ptr, &hub, tail);
891
+ else if (from_type == nk_bf16_k)
892
+ nk_partial_load_bf16x8_to_f32x8_haswell_((nk_bf16_t const *)from_ptr, &hub, tail);
893
+ else if (from_type == nk_e4m3_k)
894
+ nk_partial_load_e4m3x8_to_f32x8_haswell_((nk_e4m3_t const *)from_ptr, &hub, tail);
895
+ else if (from_type == nk_e5m2_k)
896
+ nk_partial_load_e5m2x8_to_f32x8_haswell_((nk_e5m2_t const *)from_ptr, &hub, tail);
897
+ else if (from_type == nk_e2m3_k)
898
+ nk_partial_load_e2m3x8_to_f32x8_haswell_((nk_e2m3_t const *)from_ptr, &hub, tail);
899
+ else if (from_type == nk_e3m2_k)
900
+ nk_partial_load_e3m2x8_to_f32x8_haswell_((nk_e3m2_t const *)from_ptr, &hub, tail);
901
+ else if (from_type == nk_i8_k) nk_partial_load_i8x8_to_f32x8_haswell_((nk_i8_t const *)from_ptr, &hub, tail);
902
+ else if (from_type == nk_u8_k) nk_partial_load_u8x8_to_f32x8_haswell_((nk_u8_t const *)from_ptr, &hub, tail);
903
+ else if (from_type == nk_i16_k) nk_partial_load_i16x8_to_f32x8_haswell_((nk_i16_t const *)from_ptr, &hub, tail);
904
+ else if (from_type == nk_u16_k) nk_partial_load_u16x8_to_f32x8_haswell_((nk_u16_t const *)from_ptr, &hub, tail);
905
+ else if (from_type == nk_i32_k) nk_partial_load_i32x8_to_f32x8_haswell_((nk_i32_t const *)from_ptr, &hub, tail);
906
+ else if (from_type == nk_u32_k) nk_partial_load_u32x8_to_f32x8_haswell_((nk_u32_t const *)from_ptr, &hub, tail);
907
+
908
+ // Downcast and store tail
909
+ if (to_type == nk_f32_k) nk_partial_store_b32x8_serial_(&hub, to_ptr, tail);
910
+ else if (to_type == nk_f16_k) {
911
+ hub.xmms[0] = _mm256_cvtps_ph(hub.ymm_ps, _MM_FROUND_TO_NEAREST_INT);
912
+ nk_partial_store_b16x8_serial_((nk_b128_vec_t *)&hub, to_ptr, tail);
913
+ }
914
+ else if (to_type == nk_bf16_k) {
915
+ hub.xmms[0] = nk_f32x8_to_bf16x8_haswell_(hub.ymm_ps);
916
+ nk_partial_store_b16x8_serial_((nk_b128_vec_t *)&hub, to_ptr, tail);
917
+ }
918
+ else if (to_type == nk_e4m3_k) {
919
+ hub.xmms[0] = nk_f32x8_to_e4m3x8_haswell_(hub.ymm_ps);
920
+ nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
921
+ }
922
+ else if (to_type == nk_e5m2_k) {
923
+ hub.xmms[0] = nk_f32x8_to_e5m2x8_haswell_(hub.ymm_ps);
924
+ nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
925
+ }
926
+ else if (to_type == nk_e2m3_k) {
927
+ hub.xmms[0] = nk_f32x8_to_e2m3x8_haswell_(hub.ymm_ps);
928
+ nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
929
+ }
930
+ else if (to_type == nk_e3m2_k) {
931
+ hub.xmms[0] = nk_f32x8_to_e3m2x8_haswell_(hub.ymm_ps);
932
+ nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
933
+ }
934
+ else if (to_type == nk_i8_k) {
935
+ hub.xmms[0] = nk_f32x8_to_i8x8_haswell_(hub.ymm_ps);
936
+ nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
937
+ }
938
+ else if (to_type == nk_u8_k) {
939
+ hub.xmms[0] = nk_f32x8_to_u8x8_haswell_(hub.ymm_ps);
940
+ nk_partial_store_b8x8_serial_((nk_b64_vec_t *)&hub, to_ptr, tail);
941
+ }
942
+ else if (to_type == nk_i16_k) {
943
+ hub.xmms[0] = nk_f32x8_to_i16x8_haswell_(hub.ymm_ps);
944
+ nk_partial_store_b16x8_serial_((nk_b128_vec_t *)&hub, to_ptr, tail);
945
+ }
946
+ else if (to_type == nk_u16_k) {
947
+ hub.xmms[0] = nk_f32x8_to_u16x8_haswell_(hub.ymm_ps);
948
+ nk_partial_store_b16x8_serial_((nk_b128_vec_t *)&hub, to_ptr, tail);
949
+ }
950
+ else if (to_type == nk_i32_k) {
951
+ hub.ymm = nk_f32x8_to_i32x8_haswell_(hub.ymm_ps);
952
+ nk_partial_store_b32x8_serial_(&hub, to_ptr, tail);
953
+ }
954
+ else if (to_type == nk_u32_k) {
955
+ hub.ymm = nk_f32x8_to_u32x8_haswell_(hub.ymm_ps);
956
+ nk_partial_store_b32x8_serial_(&hub, to_ptr, tail);
957
+ }
958
+ }
959
+ }
960
+
961
+ #pragma endregion - Public API
962
+
963
+ #if defined(__clang__)
964
+ #pragma clang attribute pop
965
+ #elif defined(__GNUC__)
966
+ #pragma GCC pop_options
967
+ #endif
968
+
969
+ #if defined(__cplusplus)
970
+ } // extern "C"
971
+ #endif
972
+
973
+ #endif // NK_TARGET_HASWELL
974
+ #endif // NK_TARGET_X86_
975
+ #endif // NK_CAST_HASWELL_H