numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -6,17 +6,17 @@ All conversions use round-to-nearest-even (RNE) for narrowing and exact widening
6
6
 
7
7
  BFloat16 relates to Float32 by truncation with rounding:
8
8
 
9
- ```math
9
+ $$
10
10
  \text{bf16} \approx \text{f32} \gg 16
11
- ```
11
+ $$
12
12
 
13
13
  With RNE tie-breaking to preserve the least significant bit of the truncated result.
14
14
 
15
15
  Float16 range and precision:
16
16
 
17
- ```math
17
+ $$
18
18
  \text{f16} \in [-65504, 65504], \quad \text{min positive normal} = 2^{-14}
19
- ```
19
+ $$
20
20
 
21
21
  Reformulating as Python pseudocode:
22
22
 
@@ -194,69 +194,69 @@ Measured with Wasmtime v42 (Cranelift backend).
194
194
  | __f32 ↔ e4m3__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░ |
195
195
  | `nk_cast_serial` | ? gb/s | ? gb/s | 0.239 gb/s | ? gb/s | ? gb/s | 0.746 gb/s |
196
196
 
197
- ### Apple M4
197
+ ### Apple M5
198
198
 
199
199
  #### Native
200
200
 
201
201
  | Kernel | ↓ 256 | ↓ 1K | ↓ 4K | ↑ 256 | ↑ 1K | ↑ 4K |
202
202
  | :--------------- | -----------: | -----------: | -----------: | -----------: | -----------: | -----------: |
203
203
  | __f32 ↔ bf16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
204
- | `nk_cast_serial` | 10.2 gb/s | 10.6 gb/s | 10.7 gb/s | 8.06 gb/s | 8.34 gb/s | 8.32 gb/s |
205
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
204
+ | `nk_cast_serial` | 1.37 gb/s | 1.35 gb/s | 1.41 gb/s | 1.37 gb/s | 1.34 gb/s | 1.38 gb/s |
205
+ | `nk_cast_neon` | 19.3 gb/s | 23.7 gb/s | 23.2 gb/s | 59.4 gb/s | 58.9 gb/s | 57.3 gb/s |
206
206
  | __f32 ↔ f16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
207
- | `nk_cast_serial` | 10.9 gb/s | 11.3 gb/s | 11.4 gb/s | 8.40 gb/s | 8.62 gb/s | 8.70 gb/s |
208
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
207
+ | `nk_cast_serial` | 1.37 gb/s | 1.31 gb/s | 1.32 gb/s | 1.37 gb/s | 1.31 gb/s | 1.40 gb/s |
208
+ | `nk_cast_neon` | 20.1 gb/s | 21.9 gb/s | 25.0 gb/s | 52.1 gb/s | 60.2 gb/s | 70.2 gb/s |
209
209
  | __f32 ↔ e5m2__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
210
- | `nk_cast_serial` | 1.65 gb/s | 1.52 gb/s | 1.36 gb/s | 5.96 gb/s | 6.08 gb/s | 6.11 gb/s |
211
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
210
+ | `nk_cast_serial` | 0.681 gb/s | 0.621 gb/s | 0.600 gb/s | 1.17 gb/s | 1.17 gb/s | 1.23 gb/s |
211
+ | `nk_cast_neon` | 8.50 gb/s | 8.45 gb/s | 8.35 gb/s | 40.6 gb/s | 46.5 gb/s | 46.5 gb/s |
212
212
  | __f32 ↔ e4m3__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
213
- | `nk_cast_serial` | 1.49 gb/s | 1.36 gb/s | 1.24 gb/s | 4.96 gb/s | 5.05 gb/s | 4.81 gb/s |
214
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
213
+ | `nk_cast_serial` | 0.683 gb/s | 0.618 gb/s | 0.586 gb/s | 1.02 gb/s | 1.01 gb/s | 1.02 gb/s |
214
+ | `nk_cast_neon` | 7.85 gb/s | 7.91 gb/s | 7.66 gb/s | 18.9 gb/s | 19.2 gb/s | 18.3 gb/s |
215
215
  | __f32 ↔ e3m2__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
216
- | `nk_cast_serial` | 2.17 gb/s | 2.13 gb/s | 1.97 gb/s | 5.90 gb/s | 6.02 gb/s | 6.07 gb/s |
217
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
216
+ | `nk_cast_serial` | 0.702 gb/s | 0.632 gb/s | 0.596 gb/s | 1.17 gb/s | 1.13 gb/s | 1.15 gb/s |
217
+ | `nk_cast_neon` | 8.94 gb/s | 9.02 gb/s | 8.91 gb/s | 24.9 gb/s | 25.0 gb/s | 24.4 gb/s |
218
218
  | __f32 ↔ e2m3__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
219
- | `nk_cast_serial` | 2.54 gb/s | 2.45 gb/s | 2.23 gb/s | 5.88 gb/s | 6.11 gb/s | 6.10 gb/s |
220
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
219
+ | `nk_cast_serial` | 0.921 gb/s | 0.843 gb/s | 0.715 gb/s | 1.21 gb/s | 1.21 gb/s | 1.26 gb/s |
220
+ | `nk_cast_neon` | 8.89 gb/s | 9.03 gb/s | 8.82 gb/s | 24.9 gb/s | 25.1 gb/s | 24.6 gb/s |
221
221
  | __f32 ↔ i16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
222
- | `nk_cast_serial` | 6.13 gb/s | 5.99 gb/s | 6.10 gb/s | 8.29 gb/s | 8.53 gb/s | 8.58 gb/s |
223
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
222
+ | `nk_cast_serial` | 0.785 gb/s | 0.679 gb/s | 0.678 gb/s | 1.44 gb/s | 1.39 gb/s | 1.49 gb/s |
223
+ | `nk_cast_neon` | 19.4 gb/s | 22.6 gb/s | 23.9 gb/s | 19.9 gb/s | 23.2 gb/s | 25.9 gb/s |
224
224
  | __f32 ↔ u16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
225
- | `nk_cast_serial` | 5.36 gb/s | 5.01 gb/s | 4.49 gb/s | 8.43 gb/s | 8.64 gb/s | 8.76 gb/s |
226
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
225
+ | `nk_cast_serial` | 0.916 gb/s | 0.822 gb/s | 0.726 gb/s | 1.37 gb/s | 1.36 gb/s | 1.48 gb/s |
226
+ | `nk_cast_neon` | 20.3 gb/s | 20.6 gb/s | 22.1 gb/s | 15.6 gb/s | 18.5 gb/s | 17.4 gb/s |
227
227
  | __f32 ↔ i8__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
228
- | `nk_cast_serial` | 4.83 gb/s | 4.89 gb/s | 5.09 gb/s | 6.67 gb/s | 6.92 gb/s | 7.08 gb/s |
229
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
228
+ | `nk_cast_serial` | 0.725 gb/s | 0.616 gb/s | 0.578 gb/s | 1.21 gb/s | 1.21 gb/s | 1.28 gb/s |
229
+ | `nk_cast_neon` | 18.2 gb/s | 24.5 gb/s | 21.7 gb/s | 16.3 gb/s | 18.9 gb/s | 19.8 gb/s |
230
230
  | __f32 ↔ u8__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
231
- | `nk_cast_serial` | 4.31 gb/s | 4.10 gb/s | 3.62 gb/s | 7.03 gb/s | 7.21 gb/s | 7.28 gb/s |
232
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
231
+ | `nk_cast_serial` | 0.967 gb/s | 0.795 gb/s | 0.723 gb/s | 1.29 gb/s | 1.25 gb/s | 1.40 gb/s |
232
+ | `nk_cast_neon` | 17.5 gb/s | 19.8 gb/s | 19.4 gb/s | 13.8 gb/s | 17.8 gb/s | 15.1 gb/s |
233
233
  | __f64 ↔ f32__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
234
- | `nk_cast_serial` | 17.3 gb/s | 17.8 gb/s | 18.1 gb/s | 17.9 gb/s | 18.5 gb/s | 18.5 gb/s |
235
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
234
+ | `nk_cast_serial` | 2.65 gb/s | 2.60 gb/s | 2.70 gb/s | 2.59 gb/s | 2.55 gb/s | 2.65 gb/s |
235
+ | `nk_cast_neon` | 2.87 gb/s | 2.60 gb/s | 2.73 gb/s | 2.64 gb/s | 2.63 gb/s | 2.57 gb/s |
236
236
  | __f64 ↔ i64__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
237
- | `nk_cast_serial` | 16.8 gb/s | 17.2 gb/s | 17.0 gb/s | 23.9 gb/s | 24.7 gb/s | 24.8 gb/s |
238
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
237
+ | `nk_cast_serial` | 2.42 gb/s | 2.00 gb/s | 1.86 gb/s | 3.79 gb/s | 3.61 gb/s | 4.03 gb/s |
238
+ | `nk_cast_neon` | 2.51 gb/s | 1.94 gb/s | 1.78 gb/s | 3.83 gb/s | 3.68 gb/s | 3.79 gb/s |
239
239
  | __f64 ↔ u64__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
240
- | `nk_cast_serial` | 13.5 gb/s | 12.8 gb/s | 11.3 gb/s | 24.4 gb/s | 25.0 gb/s | 25.1 gb/s |
241
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
240
+ | `nk_cast_serial` | 2.56 gb/s | 2.19 gb/s | 2.06 gb/s | 3.71 gb/s | 3.50 gb/s | 3.87 gb/s |
241
+ | `nk_cast_neon` | 2.68 gb/s | 2.10 gb/s | 1.97 gb/s | 3.68 gb/s | 3.61 gb/s | 3.58 gb/s |
242
242
  | __f64 ↔ i32__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
243
- | `nk_cast_serial` | 12.1 gb/s | 12.4 gb/s | 12.6 gb/s | 18.2 gb/s | 18.9 gb/s | 19.2 gb/s |
244
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
243
+ | `nk_cast_serial` | 1.58 gb/s | 1.32 gb/s | 1.29 gb/s | 2.65 gb/s | 2.58 gb/s | 2.84 gb/s |
244
+ | `nk_cast_neon` | 1.61 gb/s | 1.33 gb/s | 1.24 gb/s | 2.73 gb/s | 2.63 gb/s | 2.66 gb/s |
245
245
  | __f64 ↔ u32__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
246
- | `nk_cast_serial` | 10.9 gb/s | 10.6 gb/s | 9.58 gb/s | 17.6 gb/s | 18.0 gb/s | 18.1 gb/s |
247
- | `nk_cast_neon` | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s | ? gb/s |
246
+ | `nk_cast_serial` | 1.83 gb/s | 1.53 gb/s | 1.47 gb/s | 2.55 gb/s | 2.48 gb/s | 2.69 gb/s |
247
+ | `nk_cast_neon` | 1.89 gb/s | 1.53 gb/s | 1.38 gb/s | 2.56 gb/s | 2.54 gb/s | 2.59 gb/s |
248
248
 
249
249
  #### WASM
250
250
 
251
- Measured with Wasmtime v42 (Cranelift backend).
251
+ Measured with Wasmtime v43 (Cranelift backend).
252
252
 
253
253
  | Kernel | ↓ 256 | ↓ 1K | ↓ 4K | ↑ 256 | ↑ 1K | ↑ 4K |
254
254
  | :--------------- | -----------: | -----------: | -----------: | -----------: | -----------: | -----------: |
255
255
  | __f32 ↔ bf16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
256
- | `nk_cast_serial` | 3.39 gb/s | 7.29 gb/s | 11.2 gb/s | 3.08 gb/s | 6.20 gb/s | 8.81 gb/s |
256
+ | `nk_cast_serial` | 0.514 gb/s | 0.522 gb/s | 0.538 gb/s | 0.511 gb/s | 0.526 gb/s | 0.519 gb/s |
257
257
  | __f32 ↔ f16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
258
- | `nk_cast_serial` | 0.605 gb/s | 0.952 gb/s | 1.22 gb/s | 2.36 gb/s | 4.71 gb/s | 7.31 gb/s |
258
+ | `nk_cast_serial` | 0.368 gb/s | 0.363 gb/s | 0.360 gb/s | 0.490 gb/s | 0.480 gb/s | 0.489 gb/s |
259
259
  | __f32 ↔ e5m2__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
260
- | `nk_cast_serial` | 0.752 gb/s | 1.84 gb/s | 1.80 gb/s | 2.24 gb/s | 6.32 gb/s | 6.31 gb/s |
260
+ | `nk_cast_serial` | 0.323 gb/s | 0.312 gb/s | 0.304 gb/s | 0.423 gb/s | 0.425 gb/s | 0.425 gb/s |
261
261
  | __f32 ↔ e4m3__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
262
- | `nk_cast_serial` | 0.623 gb/s | 1.61 gb/s | 1.50 gb/s | 1.68 gb/s | 4.35 gb/s | 4.28 gb/s |
262
+ | `nk_cast_serial` | 0.315 gb/s | 0.304 gb/s | 0.295 gb/s | 0.396 gb/s | 0.396 gb/s | 0.397 gb/s |
@@ -0,0 +1,64 @@
1
+ /**
2
+ * @brief SIMD-accelerated Type Conversions for Diamond Rapids.
3
+ * @file include/numkong/cast/diamond.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/cast/icelake.h
8
+ *
9
+ * Uses VCVTHF82PH for native 1-instruction E4M3→FP16 conversion; E5M2→FP16 goes through the
10
+ * `_mm512_cvtbf8_ph` intrinsic, which has no dedicated instruction. Both conversions are exact (no rounding needed).
11
+ */
12
+ #ifndef NK_CAST_DIAMOND_H
13
+ #define NK_CAST_DIAMOND_H
14
+
15
+ #if NK_TARGET_X86_
16
+ #if NK_TARGET_DIAMOND
17
+
18
+ #include "numkong/types.h"
19
+
20
+ #if defined(__cplusplus)
21
+ extern "C" {
22
+ #endif
23
+
24
+ #if defined(__clang__)
25
+ #pragma clang attribute push( \
26
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx10.2-512,f16c,fma,bmi,bmi2"))), \
27
+ apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx10.2-512", "f16c", "fma", \
31
+ "bmi", "bmi2")
32
+ #endif
33
+
34
+ NK_INTERNAL void nk_load_e4m3x32_to_f16x32_diamond_(nk_e4m3_t const *src, nk_b512_vec_t *dst) {
35
+ dst->zmm_ph = _mm512_cvthf8_ph(_mm256_loadu_epi8(src));
36
+ }
37
+
38
+ NK_INTERNAL void nk_partial_load_e4m3x32_to_f16x32_diamond_(nk_e4m3_t const *src, nk_b512_vec_t *dst, nk_size_t count) {
39
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count);
40
+ dst->zmm_ph = _mm512_cvthf8_ph(_mm256_maskz_loadu_epi8(mask, src));
41
+ }
42
+
43
+ NK_INTERNAL void nk_load_e5m2x32_to_f16x32_diamond_(nk_e5m2_t const *src, nk_b512_vec_t *dst) {
44
+ dst->zmm_ph = _mm512_cvtbf8_ph(_mm256_loadu_epi8(src));
45
+ }
46
+
47
+ NK_INTERNAL void nk_partial_load_e5m2x32_to_f16x32_diamond_(nk_e5m2_t const *src, nk_b512_vec_t *dst, nk_size_t count) {
48
+ __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count);
49
+ dst->zmm_ph = _mm512_cvtbf8_ph(_mm256_maskz_loadu_epi8(mask, src));
50
+ }
51
+
52
+ #if defined(__clang__)
53
+ #pragma clang attribute pop
54
+ #elif defined(__GNUC__)
55
+ #pragma GCC pop_options
56
+ #endif
57
+
58
+ #if defined(__cplusplus)
59
+ } // extern "C"
60
+ #endif
61
+
62
+ #endif // NK_TARGET_DIAMOND
63
+ #endif // NK_TARGET_X86_
64
+ #endif // NK_CAST_DIAMOND_H
@@ -6,12 +6,12 @@
6
6
  *
7
7
  * @section haswell_cast_instructions Key F16C/AVX2 Conversion Instructions
8
8
  *
9
- * Intrinsic Instruction Latency Throughput Ports
10
- * _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy 1/cy p01
11
- * _mm256_cvtps_ph VCVTPS2PH (XMM, YMM, I8) 4cy 1/cy p01+p5
12
- * _mm256_cvtepi16_epi32 VPMOVSXWD (YMM, XMM) 3cy 1/cy p5
13
- * _mm256_slli_epi32 VPSLLD (YMM, YMM, I8) 1cy 0.5/cy p01
14
- * _mm256_blendv_ps VBLENDVPS (YMM, YMM, YMM, YMM) 2cy 1/cy p015
9
+ * Intrinsic Instruction Haswell Genoa
10
+ * _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy @ p01 4cy @ p12+p23
11
+ * _mm256_cvtps_ph VCVTPS2PH (XMM, YMM, I8) 5cy @ p01 4cy @ p12+p23
12
+ * _mm256_cvtepi16_epi32 VPMOVSXWD (YMM, XMM) 1cy @ p5 2cy @ p12
13
+ * _mm256_slli_epi32 VPSLLD (YMM, YMM, I8) 1cy @ p0 1cy @ p23
14
+ * _mm256_blendv_ps VBLENDVPS (YMM, YMM, YMM, YMM) 2cy @ p015 1cy @ p01
15
15
  *
16
16
  * F16C provides hardware F16<->F32 conversion. BF16 lacks hardware support and is emulated via
17
17
  * bit manipulation (shift upper 16 bits). FP8 formats (E4M3/E5M2) use lookup tables for subnormal
@@ -38,14 +38,14 @@ extern "C" {
38
38
  #endif
39
39
 
40
40
  NK_PUBLIC void nk_f32_to_f16_haswell(nk_f32_t const *from, nk_f16_t *to) {
41
- *to = _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(*from), _MM_FROUND_TO_NEAREST_INT));
41
+ *(nk_u16_t *)to = (nk_u16_t)_mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(*from), _MM_FROUND_TO_NEAREST_INT));
42
42
  }
43
43
 
44
44
  NK_PUBLIC void nk_f16_to_f32_haswell(nk_f16_t const *from, nk_f32_t *to) {
45
- *to = _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(*from)));
45
+ *to = _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(*(nk_u16_t const *)from)));
46
46
  }
47
47
 
48
- #pragma region - Type Punned Loads and Stores
48
+ #pragma region Type Punned Loads and Stores
49
49
 
50
50
  /** @brief Type-agnostic 256-bit full load (Haswell AVX2). */
51
51
  NK_INTERNAL void nk_load_b256_haswell_(void const *src, nk_b256_vec_t *dst) {
@@ -99,9 +99,9 @@ NK_INTERNAL void nk_partial_store_b64x4_haswell_(nk_b256_vec_t const *src, void
99
99
  _mm256_maskstore_pd((double *)dst, mask_i64x4, _mm256_castsi256_pd(src->ymm));
100
100
  }
101
101
 
102
- #pragma endregion - Type Punned Loads and Stores
102
+ #pragma endregion Type Punned Loads and Stores
103
103
 
104
- #pragma region - Vectorized Conversions
104
+ #pragma region Vectorized Conversions
105
105
 
106
106
  /** @brief Convert 8x bf16 → 8x f32 by shifting left 16 bits (AVX2). */
107
107
  NK_INTERNAL __m256 nk_bf16x8_to_f32x8_haswell_(__m128i bf16_i16x8) {
@@ -116,9 +116,9 @@ NK_INTERNAL __m128i nk_f32x8_to_bf16x8_haswell_(__m256 f32x8) {
116
116
  __m256i rounded_i32x8 = _mm256_add_epi32(bits_i32x8, _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), lsb_i32x8));
117
117
  __m256i bf16_i32x8 = _mm256_srli_epi32(rounded_i32x8, 16);
118
118
  // Pack 8x i32 to 8x i16
119
- __m128i lo_i32x4 = _mm256_castsi256_si128(bf16_i32x8);
120
- __m128i hi_i32x4 = _mm256_extracti128_si256(bf16_i32x8, 1);
121
- return _mm_packus_epi32(lo_i32x4, hi_i32x4);
119
+ __m128i low_i32x4 = _mm256_castsi256_si128(bf16_i32x8);
120
+ __m128i high_i32x4 = _mm256_extracti128_si256(bf16_i32x8, 1);
121
+ return _mm_packus_epi32(low_i32x4, high_i32x4);
122
122
  }
123
123
 
124
124
  /** @brief Integer upcasts to f32x8 (AVX2). */
@@ -132,10 +132,10 @@ NK_INTERNAL __m256 nk_u16x8_to_f32x8_haswell_(__m128i u16x8) {
132
132
  }
133
133
  NK_INTERNAL __m256 nk_i32x8_to_f32x8_haswell_(__m256i i32x8) { return _mm256_cvtepi32_ps(i32x8); }
134
134
  NK_INTERNAL __m256 nk_u32x8_to_f32x8_haswell_(__m256i u32x8) {
135
- __m256i lo_i32x8 = _mm256_and_si256(u32x8, _mm256_set1_epi32(0xFFFF));
136
- __m256i hi_i32x8 = _mm256_srli_epi32(u32x8, 16);
137
- return _mm256_add_ps(_mm256_cvtepi32_ps(lo_i32x8),
138
- _mm256_mul_ps(_mm256_cvtepi32_ps(hi_i32x8), _mm256_set1_ps(65536.0f)));
135
+ __m256i low_i32x8 = _mm256_and_si256(u32x8, _mm256_set1_epi32(0xFFFF));
136
+ __m256i high_i32x8 = _mm256_srli_epi32(u32x8, 16);
137
+ return _mm256_add_ps(_mm256_cvtepi32_ps(low_i32x8),
138
+ _mm256_mul_ps(_mm256_cvtepi32_ps(high_i32x8), _mm256_set1_ps(65536.0f)));
139
139
  }
140
140
 
141
141
  /** @brief Saturating f32x8 downcasts to integers (AVX2). */
@@ -172,167 +172,10 @@ NK_INTERNAL __m128i nk_f32x8_to_u8x8_haswell_(__m256 f32x8) {
172
172
  return _mm_packus_epi16(u16x8, _mm_setzero_si128());
173
173
  }
174
174
 
175
- /** @brief Convert 16x e4m3 → 16x bf16 via arithmetic + small LUT for subnormals (AVX2).
176
- * E4M3 format: S EEEE MMM (bias=7). BF16: S EEEEEEEE MMMMMMM (bias=127).
177
- * Normal values: BF16 = sign | ((lower7 << 4) + 0x3C00).
178
- * Subnormals (8 values): looked up via vpshufb from an 8-entry LUT.
179
- * Handles all corner cases: zero, subnormals, normals, and NaN. */
180
- NK_INTERNAL __m256i nk_e4m3x16_to_bf16x16_haswell_(__m128i e4m3x16) {
181
- __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3x16);
182
- __m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
183
- __m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
184
-
185
- // Normal path: BF16 = ((lower7 << 4) + 0x3C00) | (sign << 8)
186
- __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 4), _mm256_set1_epi16(0x3C00));
187
- sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
188
- __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
189
-
190
- // Subnormal LUT via shuffle_epi8 (8 entries: mantissa 0-7 → BF16)
191
- // E4M3 subnormal BF16 values: 0x0000, 0x3B00, 0x3B80, 0x3BC0, 0x3C00, 0x3C20, 0x3C40, 0x3C60
192
- // Split into low bytes and high bytes for reconstruction
193
- __m256i const lo_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
194
- 0x60, 0x40, 0x20, 0x00, (char)0xC0, (char)0x80, 0x00, 0x00, //
195
- 0x60, 0x40, 0x20, 0x00, (char)0xC0, (char)0x80, 0x00, 0x00)); //
196
- __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
197
- 0x3C, 0x3C, 0x3C, 0x3C, 0x3B, 0x3B, 0x3B, 0x00, //
198
- 0x3C, 0x3C, 0x3C, 0x3C, 0x3B, 0x3B, 0x3B, 0x00)); //
199
-
200
- // Extract mantissa (bits 0-2) as byte indices for shuffle
201
- __m256i byte_idx_i8x32 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi8(0x07));
202
- __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
203
- __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
204
-
205
- // Combine low and high bytes into 16-bit values
206
- __m256i subnorm_abs_i16x16 = _mm256_or_si256( //
207
- _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
208
- _mm256_slli_epi16(hi_bytes_i8x32, 8)); //
209
- __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
210
-
211
- // Blend: if exponent == 0, use subnormal result; else use normal result
212
- __m256i exp_bits_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x78));
213
- __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
214
- __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
215
-
216
- // Handle NaN: E4M3 index 127 (0x7F) → BF16 NaN (0x7FC0)
217
- __m256i is_nan_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7F));
218
- __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7FC0));
219
- return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
220
- }
221
-
222
- /** @brief Convert 16x e5m2 → 16x bf16 via arithmetic + small LUT for subnormals (AVX2).
223
- * E5M2 format: S EEEEE MM (bias=15). BF16: S EEEEEEEE MMMMMMM (bias=127).
224
- * Normal values: BF16 = sign | ((lower7 << 5) + 0x3800).
225
- * Subnormals (4 values): looked up via vpshufb from a 4-entry LUT.
226
- * Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
227
- NK_INTERNAL __m256i nk_e5m2x16_to_bf16x16_haswell_(__m128i e5m2x16) {
228
- __m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2x16);
229
- __m256i sign_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16((short)0x80));
230
- __m256i lower7_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F));
231
-
232
- // Normal path: BF16 = ((lower7 << 5) + 0x3800) | (sign << 8)
233
- __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 5), _mm256_set1_epi16(0x3800));
234
- sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
235
- __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
236
-
237
- // Subnormal LUT via shuffle_epi8 (4 entries: mantissa 0-3 → BF16)
238
- // E5M2 subnormal BF16 values: 0x0000, 0x3780, 0x3800, 0x3840
239
- __m256i const lo_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
240
- 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, (char)0x80, 0x00, //
241
- 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, (char)0x80, 0x00)); //
242
- __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
243
- 0x00, 0x00, 0x00, 0x00, 0x38, 0x38, 0x37, 0x00, //
244
- 0x00, 0x00, 0x00, 0x00, 0x38, 0x38, 0x37, 0x00)); //
245
-
246
- // Extract mantissa (bits 0-1) as byte indices for shuffle
247
- __m256i byte_idx_i8x32 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi8(0x03));
248
- __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
249
- __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
250
-
251
- // Combine low and high bytes into 16-bit values
252
- __m256i subnorm_abs_i16x16 = _mm256_or_si256( //
253
- _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
254
- _mm256_slli_epi16(hi_bytes_i8x32, 8)); //
255
- __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
256
-
257
- // Blend: if exponent == 0, use subnormal result; else use normal result
258
- __m256i exp_bits_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7C));
259
- __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
260
- __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
261
-
262
- // Handle Inf (0x7C) and NaN (0x7D-0x7F)
263
- __m256i is_inf_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7C));
264
- __m256i is_nan_i16x16 = _mm256_cmpgt_epi16(lower7_i16x16, _mm256_set1_epi16(0x7C));
265
- __m256i inf_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7F80));
266
- __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7FC0));
267
- result_i16x16 = _mm256_blendv_epi8(result_i16x16, inf_i16x16, is_inf_i16x16);
268
- return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
269
- }
270
-
271
- /** @brief Convert 16x e4m3 → 16x f16 via arithmetic + small LUT for subnormals (AVX2).
272
- * E4M3 format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
273
- * Normal values: F16 = sign | ((lower7 << 7) + 0x2000).
274
- * Subnormals (8 values): looked up via vpshufb from an 8-entry LUT.
275
- * Handles all corner cases: zero, subnormals, normals, and NaN. */
276
- NK_INTERNAL __m256i nk_e4m3x16_to_f16x16_haswell_(__m128i e4m3x16) {
277
- __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3x16);
278
- __m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
279
- __m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
280
-
281
- // Normal path: F16 = ((lower7 << 7) + 0x2000) | (sign << 8)
282
- __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 7), _mm256_set1_epi16(0x2000));
283
- sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
284
- __m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
285
-
286
- // Subnormal LUT via shuffle_epi8 (8 entries: mantissa 0-7 → F16)
287
- // E4M3 subnormal F16 values: 0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300
288
- // All low bytes are 0x00, high bytes: 0x00, 0x18, 0x1C, 0x1E, 0x20, 0x21, 0x22, 0x23
289
- // _mm_set_epi8 order: b15..b8 (unused), b7=idx7, b6=idx6, ..., b0=idx0
290
- __m256i const lo_lut_i8x32 = _mm256_setzero_si256();
291
- __m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
292
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
293
- 0x23, 0x22, 0x21, 0x20, 0x1E, 0x1C, 0x18, 0x00)); //
294
-
295
- // Extract mantissa (bits 0-2) as byte indices for shuffle
296
- __m256i byte_idx_i8x32 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi8(0x07));
297
- __m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
298
- __m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
299
-
300
- // Combine low and high bytes into 16-bit values
301
- __m256i subnorm_abs_i16x16 = _mm256_or_si256( //
302
- _mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
303
- _mm256_slli_epi16(hi_bytes_i8x32, 8)); //
304
- __m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
305
-
306
- // Blend: if exponent == 0, use subnormal result; else use normal result
307
- __m256i exp_bits_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x78));
308
- __m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
309
- __m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
310
-
311
- // Handle NaN: E4M3 index 127 (0x7F) → F16 NaN (0x7E00)
312
- __m256i is_nan_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7F));
313
- __m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7E00));
314
- return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
315
- }
316
-
317
- /** @brief Convert 16x e5m2 → 16x f16 via simple bit shift (AVX2).
318
- * E5M2 format: S EEEEE MM (bias=15). F16: S EEEEE MMMMMMMMMM (bias=15).
319
- * Same exponent bias means F16 = (lower7 << 8) | (sign << 15).
320
- * Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
321
- NK_INTERNAL __m256i nk_e5m2x16_to_f16x16_haswell_(__m128i e5m2x16) {
322
- __m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2x16);
323
- __m256i sign_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16((short)0x80));
324
- __m256i lower7_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F));
325
-
326
- // F16 = (lower7 << 8) | (sign << 15)
327
- // Works for all cases: subnormals, normals, infinity, and NaN
328
- __m256i result_i16x16 = _mm256_slli_epi16(lower7_i16x16, 8);
329
- sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
330
- return _mm256_or_si256(result_i16x16, sign_i16x16);
331
- }
332
-
333
175
  /** @brief Convert 8x e4m3 → 8x f32 via bit manipulation (AVX2).
334
176
  * E4M3 format: S EEEE MMM (bias=7). F32: sign<<31, (exp+120)<<23, mant<<20.
335
- * Subnormals (exp=0): value = mantissa × 2⁽¹⁻⁷⁾ × 2⁻³ = mantissa ÷ 512. */
177
+ * Subnormals (exp=0): looked up via vpermps from an 8-entry register LUT.
178
+ * NaN detection uses a single comparison on the 7-bit magnitude (0x7F). */
336
179
  NK_INTERNAL __m256 nk_e4m3x8_to_f32x8_haswell_(__m128i e4m3_i8x8) {
337
180
  __m256i e4m3_i32x8 = _mm256_cvtepu8_epi32(e4m3_i8x8);
338
181
 
@@ -348,21 +191,26 @@ NK_INTERNAL __m256 nk_e4m3x8_to_f32x8_haswell_(__m128i e4m3_i8x8) {
348
191
  __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 20);
349
192
  __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
350
193
 
351
- // Subnormal path: value = mantissa / 512.0f, then apply sign
352
- __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 512.0f));
353
- __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
194
+ // Subnormal path: vpermps from 8-entry register LUT (3 cy latency, no memory access)
195
+ __m256 subnorm_lut_f32x8 = _mm256_setr_ps(0, 1.0f / 512, 2.0f / 512, 3.0f / 512, //
196
+ 4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512);
197
+ __m256i subnorm_bits_i32x8 = _mm256_or_si256( //
198
+ _mm256_castps_si256(_mm256_permutevar8x32_ps(subnorm_lut_f32x8, mant_i32x8)), f32_sign_i32x8);
354
199
 
355
- // Blend: if exp==0, use subnormal result; otherwise use normal bits
200
+ // Bitwise select: if exp==0, use subnormal; otherwise use normal
356
201
  __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
357
- __m256 result = _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8,
358
- _mm256_castsi256_ps(exp_zero_mask));
202
+ __m256i result_i32x8 = _mm256_or_si256( //
203
+ _mm256_and_si256(exp_zero_mask, subnorm_bits_i32x8), //
204
+ _mm256_andnot_si256(exp_zero_mask, normal_bits_i32x8));
359
205
 
360
- // NaN path: E4M3FN has NaN only when exp=15 AND mant=7 (0x7F or 0xFF)
361
- __m256i is_nan_mask = _mm256_and_si256( //
362
- _mm256_cmpeq_epi32(exp_i32x8, _mm256_set1_epi32(15)), //
363
- _mm256_cmpeq_epi32(mant_i32x8, _mm256_set1_epi32(7))); //
364
- __m256i nan_bits = _mm256_or_si256(f32_sign_i32x8, _mm256_set1_epi32(0x7FC00000)); // F32 quiet NaN
365
- return _mm256_blendv_ps(result, _mm256_castsi256_ps(nan_bits), _mm256_castsi256_ps(is_nan_mask));
206
+ // NaN: E4M3FN has NaN only at magnitude 0x7F (exp=15, mant=7)
207
+ __m256i lower7_i32x8 = _mm256_and_si256(e4m3_i32x8, _mm256_set1_epi32(0x7F));
208
+ __m256i is_nan_mask = _mm256_cmpeq_epi32(lower7_i32x8, _mm256_set1_epi32(0x7F));
209
+ __m256i nan_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_set1_epi32(0x7FC00000));
210
+ result_i32x8 = _mm256_or_si256( //
211
+ _mm256_and_si256(is_nan_mask, nan_i32x8), //
212
+ _mm256_andnot_si256(is_nan_mask, result_i32x8));
213
+ return _mm256_castsi256_ps(result_i32x8);
366
214
  }
367
215
 
368
216
  /** @brief Convert 8x e5m2 → 8x f32 via bit manipulation (AVX2).
@@ -676,9 +524,9 @@ NK_INTERNAL __m128i nk_f32x8_to_e3m2x8_haswell_(__m256 f32x8) {
676
524
  return packed_i8x8;
677
525
  }
678
526
 
679
- #pragma endregion - Vectorized Conversions
527
+ #pragma endregion Vectorized Conversions
680
528
 
681
- #pragma region - Converting Loads and Stores
529
+ #pragma region Converting Loads and Stores
682
530
 
683
531
  /** @brief Full load for f16 elements (8) with conversion to f32 via F16C. */
684
532
  NK_INTERNAL void nk_load_f16x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
@@ -794,9 +642,9 @@ NK_INTERNAL void nk_partial_load_u32x8_to_f32x8_haswell_(nk_u32_t const *src, nk
794
642
  dst->ymm_ps = nk_u32x8_to_f32x8_haswell_(vec.ymm);
795
643
  }
796
644
 
797
- #pragma endregion - Converting Loads and Stores
645
+ #pragma endregion Converting Loads and Stores
798
646
 
799
- #pragma region - Public API
647
+ #pragma region Public API
800
648
 
801
649
  NK_PUBLIC void nk_cast_haswell(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
802
650
  // Same-type fast path
@@ -958,7 +806,7 @@ NK_PUBLIC void nk_cast_haswell(void const *from, nk_dtype_t from_type, nk_size_t
958
806
  }
959
807
  }
960
808
 
961
- #pragma endregion - Public API
809
+ #pragma endregion Public API
962
810
 
963
811
  #if defined(__clang__)
964
812
  #pragma clang attribute pop