numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -54,7 +54,7 @@
54
54
  extern "C" {
55
55
  #endif
56
56
 
57
- #pragma region - Register-to-Register Helpers
57
+ #pragma region Register to Register Helpers
58
58
 
59
59
  /**
60
60
  * @brief Convert bf16 (m1) to f32 (m2) register-to-register.
@@ -90,7 +90,11 @@ NK_INTERNAL vuint16m1_t nk_f32m2_to_bf16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_
90
90
  * F16 format: S EEEEE MMMMMMMMMM (1 sign, 5 exponent bits with bias=15, 10 mantissa bits)
91
91
  * F32 format: S EEEEEEEE MMMMMMMMMMMMMMMMMMMMMMM (1 sign, 8 exponent bits with bias=127, 23 mantissa bits)
92
92
  *
93
- * Handles all IEEE-754 edge cases: ±zero, denormals, normals, ±inf, NaN.
93
+ * Uses the Giesen magic-multiply trick: treat the magnitude bits as a denormal f32 and
94
+ * multiply by 2^112 to rebias the exponent. This correctly handles ±zero, denormals,
95
+ * and normals in a single FP multiply; only inf/NaN needs a fixup compare+merge.
96
+ *
97
+ * https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
94
98
  */
95
99
  NK_INTERNAL vfloat32m2_t nk_f16m1_to_f32m2_rvv_(vuint16m1_t f16_u16m1, nk_size_t vector_length) {
96
100
  // Widen to 32-bit for manipulation
@@ -98,45 +102,31 @@ NK_INTERNAL vfloat32m2_t nk_f16m1_to_f32m2_rvv_(vuint16m1_t f16_u16m1, nk_size_t
98
102
  // Extract sign: (raw >> 15) << 31
99
103
  vuint32m2_t sign_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vsrl_vx_u32m2(bits_u32m2, 15, vector_length), 31,
100
104
  vector_length);
101
- // Extract exponent: (raw >> 10) & 0x1F
102
- vuint32m2_t exponent_u32m2 = __riscv_vand_vx_u32m2(__riscv_vsrl_vx_u32m2(bits_u32m2, 10, vector_length), 0x1F,
105
+ // Strip sign, shift magnitude into f32 mantissa position.
106
+ // For a normal f16 with exp E, this places E into the f32 exponent field,
107
+ // creating a tiny f32 whose value is proportional to the f16 magnitude.
108
+ vuint32m2_t nonsign_u32m2 = __riscv_vand_vx_u32m2(bits_u32m2, 0x7FFF, vector_length);
109
+ vuint32m2_t shifted_u32m2 = __riscv_vsll_vx_u32m2(nonsign_u32m2, 13, vector_length);
110
+ // Multiply by 2^112 (= magic 0x77800000 as f32) to rebias the exponent.
111
+ // This single multiply correctly handles zero, denormals, and normals:
112
+ // zero: 0.0 × 2^112 = 0.0
113
+ // denormal: (M × 2^-136) × 2^112 = M × 2^-24 (correct f16 denormal value)
114
+ // normal: (2^(E-127) × …) × 2^112 = 2^(E-15) × … (correct rebiased value)
115
+ vfloat32m2_t magic_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
116
+ __riscv_vmv_v_x_u32m2(((nk_u32_t)(254 - 15) << 23), vector_length));
117
+ vfloat32m2_t result_f32m2 = __riscv_vfmul_vv_f32m2(__riscv_vreinterpret_v_u32m2_f32m2(shifted_u32m2), magic_f32m2,
103
118
  vector_length);
104
- // Extract mantissa: raw & 0x3FF
105
- vuint32m2_t mantissa_u32m2 = __riscv_vand_vx_u32m2(bits_u32m2, 0x3FF, vector_length);
106
-
107
- // Normal path: rebias exponent (15 → 127): add 112, combine
108
- vuint32m2_t f32_exponent_u32m2 = __riscv_vadd_vx_u32m2(exponent_u32m2, 112, vector_length);
109
- vuint32m2_t normal_u32m2 = __riscv_vor_vv_u32m2(
110
- sign_u32m2,
111
- __riscv_vor_vv_u32m2(__riscv_vsll_vx_u32m2(f32_exponent_u32m2, 23, vector_length),
112
- __riscv_vsll_vx_u32m2(mantissa_u32m2, 13, vector_length), vector_length),
113
- vector_length);
114
-
115
- // Special case: exponent == 0 (zero or denormal)
116
- // Zero: sign | 0. Denormal: mantissa × 2^(-24), handled via FPU normalization trick.
117
- // For denormals, convert mantissa to float and subtract 0x0C000000 (24 from exponent),
118
- // matching the serial implementation. For zeros (mantissa==0), (float)0 - bias = 0.
119
- vbool16_t is_exp_zero = __riscv_vmseq_vx_u32m2_b16(exponent_u32m2, 0, vector_length);
120
- vfloat32m2_t mantissa_f32m2 = __riscv_vfcvt_f_xu_v_f32m2(mantissa_u32m2, vector_length);
121
- vuint32m2_t denorm_bits_u32m2 = __riscv_vsub_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(mantissa_f32m2),
122
- 0x0C000000, vector_length);
123
- vuint32m2_t zero_or_denorm_u32m2 = __riscv_vor_vv_u32m2(sign_u32m2, denorm_bits_u32m2, vector_length);
124
- // For true zeros (mantissa==0), the FPU converts 0 to 0x00000000, minus bias wraps,
125
- // so force to sign-only.
126
- vbool16_t is_true_zero = __riscv_vmand_mm_b16(
127
- is_exp_zero, __riscv_vmseq_vx_u32m2_b16(mantissa_u32m2, 0, vector_length), vector_length);
128
- zero_or_denorm_u32m2 = __riscv_vmerge_vvm_u32m2(zero_or_denorm_u32m2, sign_u32m2, is_true_zero, vector_length);
129
-
130
- // Special case: exponent == 31 (infinity or NaN)
131
- // sign | 0x7F800000 | (mantissa << 13)
132
- vbool16_t is_exp_max = __riscv_vmseq_vx_u32m2_b16(exponent_u32m2, 31, vector_length);
133
- vuint32m2_t inf_nan_u32m2 = __riscv_vor_vv_u32m2(__riscv_vor_vx_u32m2(sign_u32m2, 0x7F800000, vector_length),
134
- __riscv_vsll_vx_u32m2(mantissa_u32m2, 13, vector_length),
135
- vector_length);
136
-
137
- // Select: exp==0 → zero_or_denorm, exp==31 → inf_nan, else → normal
138
- vuint32m2_t result_u32m2 = __riscv_vmerge_vvm_u32m2(normal_u32m2, zero_or_denorm_u32m2, is_exp_zero, vector_length);
139
- result_u32m2 = __riscv_vmerge_vvm_u32m2(result_u32m2, inf_nan_u32m2, is_exp_max, vector_length);
119
+ // Inf/NaN fixup: the multiply maps f16 exp=31 to a large finite f32.
120
+ // Detect those lanes and force the f32 exponent to 255 (inf/NaN).
121
+ // Threshold 0x47800000 = 2^16; any f16 with exp=31 exceeds it after scaling.
122
+ vfloat32m2_t infnan_threshold_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(
123
+ __riscv_vmv_v_x_u32m2(((nk_u32_t)(127 + 16) << 23), vector_length));
124
+ vbool16_t is_infnan = __riscv_vmfge_vv_f32m2_b16(result_f32m2, infnan_threshold_f32m2, vector_length);
125
+ vuint32m2_t result_u32m2 = __riscv_vreinterpret_v_f32m2_u32m2(result_f32m2);
126
+ vuint32m2_t fixed_u32m2 = __riscv_vor_vx_u32m2(result_u32m2, 0x7F800000, vector_length);
127
+ result_u32m2 = __riscv_vmerge_vvm_u32m2(result_u32m2, fixed_u32m2, is_infnan, vector_length);
128
+ // Restore sign
129
+ result_u32m2 = __riscv_vor_vv_u32m2(result_u32m2, sign_u32m2, vector_length);
140
130
  return __riscv_vreinterpret_v_u32m2_f32m2(result_u32m2);
141
131
  }
142
132
 
@@ -162,13 +152,8 @@ NK_INTERNAL vuint16m1_t nk_f32m2_to_f16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_t
162
152
  exponent_i32m2 = __riscv_vmax_vx_i32m2(exponent_i32m2, 0, vector_length);
163
153
  vuint32m2_t f16_exponent_u32m2 = __riscv_vreinterpret_v_i32m2_u32m2(
164
154
  __riscv_vmin_vx_i32m2(exponent_i32m2, 31, vector_length));
165
- // Round mantissa: add 0x1000 (half of truncated bits) then shift.
166
- // If rounding overflows the mantissa (bit 23 set), carry into exponent.
155
+ // Round mantissa: add 0x1000 (half of truncated bits) then shift
167
156
  vuint32m2_t rounded_mantissa_u32m2 = __riscv_vadd_vx_u32m2(mantissa_u32m2, 0x1000, vector_length);
168
- vbool16_t mantissa_overflow_b16 = __riscv_vmsne_vx_u32m2_b16(
169
- __riscv_vand_vx_u32m2(rounded_mantissa_u32m2, 0x800000, vector_length), 0, vector_length);
170
- f16_exponent_u32m2 = __riscv_vadd_vx_u32m2_mu(mantissa_overflow_b16, f16_exponent_u32m2, f16_exponent_u32m2, 1,
171
- vector_length);
172
157
  vuint32m2_t f16_mantissa_u32m2 = __riscv_vsrl_vx_u32m2(rounded_mantissa_u32m2, 13, vector_length);
173
158
  f16_mantissa_u32m2 = __riscv_vand_vx_u32m2(f16_mantissa_u32m2, 0x3FF, vector_length);
174
159
  // Combine: sign | (exponent << 10) | mantissa
@@ -181,242 +166,206 @@ NK_INTERNAL vuint16m1_t nk_f32m2_to_f16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_t
181
166
  }
182
167
 
183
168
  /**
184
- * @brief Convert e4m3 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
185
- * E4M3FN: sign = bit 7, magnitude = bits 6:0 (128 entries). Sign bit 7 → f32 bit 31 (<<24).
169
+ * @brief Convert e4m3 (m1) to f32 (m4) via Giesen magic-multiply.
170
+ * Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
171
+ * Handles zero, subnormals, and normals in a single vfmul. NaN fixup for magnitude 0x7F.
172
+ * https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
186
173
  */
187
174
  NK_INTERNAL vfloat32m4_t nk_e4m3m1_to_f32m4_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
188
- static nk_u32_t const nk_e4m3_mag_to_f32_lut_[128] = {
189
- 0x00000000u, 0x3B000000u, 0x3B800000u, 0x3BC00000u,
190
- 0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u, /* [ 0.. 7] */
191
- 0x3C800000u, 0x3C900000u, 0x3CA00000u, 0x3CB00000u,
192
- 0x3CC00000u, 0x3CD00000u, 0x3CE00000u, 0x3CF00000u, /* [ 8.. 15] */
193
- 0x3D000000u, 0x3D100000u, 0x3D200000u, 0x3D300000u,
194
- 0x3D400000u, 0x3D500000u, 0x3D600000u, 0x3D700000u, /* [ 16.. 23] */
195
- 0x3D800000u, 0x3D900000u, 0x3DA00000u, 0x3DB00000u,
196
- 0x3DC00000u, 0x3DD00000u, 0x3DE00000u, 0x3DF00000u, /* [ 24.. 31] */
197
- 0x3E000000u, 0x3E100000u, 0x3E200000u, 0x3E300000u,
198
- 0x3E400000u, 0x3E500000u, 0x3E600000u, 0x3E700000u, /* [ 32.. 39] */
199
- 0x3E800000u, 0x3E900000u, 0x3EA00000u, 0x3EB00000u,
200
- 0x3EC00000u, 0x3ED00000u, 0x3EE00000u, 0x3EF00000u, /* [ 40.. 47] */
201
- 0x3F000000u, 0x3F100000u, 0x3F200000u, 0x3F300000u,
202
- 0x3F400000u, 0x3F500000u, 0x3F600000u, 0x3F700000u, /* [ 48.. 55] */
203
- 0x3F800000u, 0x3F900000u, 0x3FA00000u, 0x3FB00000u,
204
- 0x3FC00000u, 0x3FD00000u, 0x3FE00000u, 0x3FF00000u, /* [ 56.. 63] */
205
- 0x40000000u, 0x40100000u, 0x40200000u, 0x40300000u,
206
- 0x40400000u, 0x40500000u, 0x40600000u, 0x40700000u, /* [ 64.. 71] */
207
- 0x40800000u, 0x40900000u, 0x40A00000u, 0x40B00000u,
208
- 0x40C00000u, 0x40D00000u, 0x40E00000u, 0x40F00000u, /* [ 72.. 79] */
209
- 0x41000000u, 0x41100000u, 0x41200000u, 0x41300000u,
210
- 0x41400000u, 0x41500000u, 0x41600000u, 0x41700000u, /* [ 80.. 87] */
211
- 0x41800000u, 0x41900000u, 0x41A00000u, 0x41B00000u,
212
- 0x41C00000u, 0x41D00000u, 0x41E00000u, 0x41F00000u, /* [ 88.. 95] */
213
- 0x42000000u, 0x42100000u, 0x42200000u, 0x42300000u,
214
- 0x42400000u, 0x42500000u, 0x42600000u, 0x42700000u, /* [ 96..103] */
215
- 0x42800000u, 0x42900000u, 0x42A00000u, 0x42B00000u,
216
- 0x42C00000u, 0x42D00000u, 0x42E00000u, 0x42F00000u, /* [104..111] */
217
- 0x43000000u, 0x43100000u, 0x43200000u, 0x43300000u,
218
- 0x43400000u, 0x43500000u, 0x43600000u, 0x43700000u, /* [112..119] */
219
- 0x43800000u, 0x43900000u, 0x43A00000u, 0x43B00000u,
220
- 0x43C00000u, 0x43D00000u, 0x43E00000u, 0x7FC00000u /* [120..127] */
221
- };
222
- vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
223
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
224
- vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
225
- vector_length);
226
- vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e4m3_mag_to_f32_lut_, offsets_u32m4, vector_length);
227
- vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 24,
228
- vector_length);
229
- return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
175
+ // Extract sign: (raw & 0x80) → bit 7, shift to bit 31
176
+ vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(
177
+ __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length), vector_length), 24,
178
+ vector_length);
179
+ // Strip sign to get 7-bit magnitude, widen to u32, shift left by 20
180
+ vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
181
+ vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
182
+ vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 20, vector_length);
183
+
184
+ // Magic multiply: reinterpret as f32 × 2^120 rebiases from E4M3 (bias=7) to f32 (bias=127).
185
+ vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(
186
+ __riscv_vmv_v_x_u32m4(0x7B800000, vector_length)); // 2^120 = (254-7)<<23
187
+ vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
188
+ vector_length);
189
+
190
+ // NaN fixup: masked OR writes sign|0x7FC00000 only into NaN lanes
191
+ vbool8_t is_nan = __riscv_vmseq_vx_u8m1_b8(nonsign_u8m1, 0x7F, vector_length);
192
+ vuint32m4_t result_u32m4 = __riscv_vor_vx_u32m4_mu(is_nan, __riscv_vreinterpret_v_f32m4_u32m4(result_f32m4),
193
+ sign_u32m4, 0x7FC00000, vector_length);
194
+
195
+ // Restore sign
196
+ result_u32m4 = __riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length);
197
+ return __riscv_vreinterpret_v_u32m4_f32m4(result_u32m4);
230
198
  }
231
199
 
232
200
  /**
233
- * @brief Convert e5m2 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
234
- * E5M2: sign = bit 7, magnitude = bits 6:0 (128 entries). Sign bit 7 → f32 bit 31 (<<24).
201
+ * @brief Convert e5m2 (m1) to f32 (m4) via Giesen magic-multiply.
202
+ * Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
203
+ * Handles zero, subnormals, and normals in a single vfmul. Inf/NaN fixup for exp=31.
204
+ * https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
235
205
  */
236
206
  NK_INTERNAL vfloat32m4_t nk_e5m2m1_to_f32m4_rvv_(vuint8m1_t e5m2_u8m1, nk_size_t vector_length) {
237
- static nk_u32_t const nk_e5m2_mag_to_f32_lut_[128] = {
238
- 0x00000000u, 0x37800000u, 0x38000000u, 0x38400000u,
239
- 0x38800000u, 0x38A00000u, 0x38C00000u, 0x38E00000u, /* [ 0.. 7] */
240
- 0x39000000u, 0x39200000u, 0x39400000u, 0x39600000u,
241
- 0x39800000u, 0x39A00000u, 0x39C00000u, 0x39E00000u, /* [ 8.. 15] */
242
- 0x3A000000u, 0x3A200000u, 0x3A400000u, 0x3A600000u,
243
- 0x3A800000u, 0x3AA00000u, 0x3AC00000u, 0x3AE00000u, /* [ 16.. 23] */
244
- 0x3B000000u, 0x3B200000u, 0x3B400000u, 0x3B600000u,
245
- 0x3B800000u, 0x3BA00000u, 0x3BC00000u, 0x3BE00000u, /* [ 24.. 31] */
246
- 0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u,
247
- 0x3C800000u, 0x3CA00000u, 0x3CC00000u, 0x3CE00000u, /* [ 32.. 39] */
248
- 0x3D000000u, 0x3D200000u, 0x3D400000u, 0x3D600000u,
249
- 0x3D800000u, 0x3DA00000u, 0x3DC00000u, 0x3DE00000u, /* [ 40.. 47] */
250
- 0x3E000000u, 0x3E200000u, 0x3E400000u, 0x3E600000u,
251
- 0x3E800000u, 0x3EA00000u, 0x3EC00000u, 0x3EE00000u, /* [ 48.. 55] */
252
- 0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u,
253
- 0x3F800000u, 0x3FA00000u, 0x3FC00000u, 0x3FE00000u, /* [ 56.. 63] */
254
- 0x40000000u, 0x40200000u, 0x40400000u, 0x40600000u,
255
- 0x40800000u, 0x40A00000u, 0x40C00000u, 0x40E00000u, /* [ 64.. 71] */
256
- 0x41000000u, 0x41200000u, 0x41400000u, 0x41600000u,
257
- 0x41800000u, 0x41A00000u, 0x41C00000u, 0x41E00000u, /* [ 72.. 79] */
258
- 0x42000000u, 0x42200000u, 0x42400000u, 0x42600000u,
259
- 0x42800000u, 0x42A00000u, 0x42C00000u, 0x42E00000u, /* [ 80.. 87] */
260
- 0x43000000u, 0x43200000u, 0x43400000u, 0x43600000u,
261
- 0x43800000u, 0x43A00000u, 0x43C00000u, 0x43E00000u, /* [ 88.. 95] */
262
- 0x44000000u, 0x44200000u, 0x44400000u, 0x44600000u,
263
- 0x44800000u, 0x44A00000u, 0x44C00000u, 0x44E00000u, /* [ 96..103] */
264
- 0x45000000u, 0x45200000u, 0x45400000u, 0x45600000u,
265
- 0x45800000u, 0x45A00000u, 0x45C00000u, 0x45E00000u, /* [104..111] */
266
- 0x46000000u, 0x46200000u, 0x46400000u, 0x46600000u,
267
- 0x46800000u, 0x46A00000u, 0x46C00000u, 0x46E00000u, /* [112..119] */
268
- 0x47000000u, 0x47200000u, 0x47400000u, 0x47600000u,
269
- 0x7F800000u, 0x7FC00000u, 0x7FC00000u, 0x7FC00000u /* [120..127] */
270
- };
271
- vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length);
272
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length);
273
- vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
274
- vector_length);
275
- vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e5m2_mag_to_f32_lut_, offsets_u32m4, vector_length);
276
- vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 24,
277
- vector_length);
278
- return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
207
+ // Extract sign: (raw & 0x80) → bit 7, shift to bit 31
208
+ vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(
209
+ __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length), vector_length), 24,
210
+ vector_length);
211
+ // Strip sign to get 7-bit magnitude, widen to u32, shift left by 21
212
+ vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length),
213
+ vector_length);
214
+ vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 21, vector_length);
215
+
216
+ // Magic multiply: reinterpret as f32 × 2^112 rebiases from E5M2 (bias=15) to f32 (bias=127).
217
+ vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(
218
+ __riscv_vmv_v_x_u32m4(0x77800000, vector_length)); // 2^112 = (254-15)<<23
219
+ vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
220
+ vector_length);
221
+
222
+ // Inf/NaN fixup: masked OR writes 0x7F800000 only into inf/NaN lanes (nonsign > 123)
223
+ vbool8_t is_infnan = __riscv_vmsgtu_vx_u32m4_b8(nonsign_u32m4, 123, vector_length);
224
+ vuint32m4_t result_u32m4 = __riscv_vor_vx_u32m4_mu(is_infnan, __riscv_vreinterpret_v_f32m4_u32m4(result_f32m4),
225
+ __riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 0x7F800000,
226
+ vector_length);
227
+
228
+ // Restore sign
229
+ result_u32m4 = __riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length);
230
+ return __riscv_vreinterpret_v_u32m4_f32m4(result_u32m4);
279
231
  }
280
232
 
281
233
  /**
282
- * @brief Convert e2m3 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
283
- * E2M3FN: sign = bit 5, magnitude = bits 4:0 (32 entries). Sign bit 5 f32 bit 31 (<<26).
234
+ * @brief Convert e2m3 (m1) to f32 (m4) via Giesen magic-multiply.
235
+ * Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
236
+ * Handles zero, subnormals, and normals in a single vfmul. No inf/NaN in E2M3FN.
237
+ * https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
284
238
  */
285
239
  NK_INTERNAL vfloat32m4_t nk_e2m3m1_to_f32m4_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
286
- static nk_u32_t const nk_e2m3_mag_to_f32_lut_[32] = {
287
- 0x00000000u, 0x3E000000u, 0x3E800000u, 0x3EC00000u,
288
- 0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u, /* [ 0.. 7] */
289
- 0x3F800000u, 0x3F900000u, 0x3FA00000u, 0x3FB00000u,
290
- 0x3FC00000u, 0x3FD00000u, 0x3FE00000u, 0x3FF00000u, /* [ 8.. 15] */
291
- 0x40000000u, 0x40100000u, 0x40200000u, 0x40300000u,
292
- 0x40400000u, 0x40500000u, 0x40600000u, 0x40700000u, /* [ 16.. 23] */
293
- 0x40800000u, 0x40900000u, 0x40A00000u, 0x40B00000u,
294
- 0x40C00000u, 0x40D00000u, 0x40E00000u, 0x40F00000u /* [ 24.. 31] */
295
- };
296
- vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
297
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
298
- vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
299
- vector_length);
300
- vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e2m3_mag_to_f32_lut_, offsets_u32m4, vector_length);
301
- vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 26,
302
- vector_length);
303
- return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
240
+ // Extract sign: bit 5 → bit 31
241
+ vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(
242
+ __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length), vector_length), 26,
243
+ vector_length);
244
+ // Strip sign to get 5-bit magnitude, widen to u32, shift left by 20
245
+ vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length),
246
+ vector_length);
247
+ vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 20, vector_length);
248
+
249
+ // Magic multiply: reinterpret as f32 × 2^126 rebiases from E2M3 (bias=1) to f32 (bias=127).
250
+ vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(
251
+ __riscv_vmv_v_x_u32m4(0x7E800000, vector_length)); // 2^126 = (254-1)<<23
252
+ vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
253
+ vector_length);
254
+
255
+ // Restore sign (no inf/NaN fixup needed for E2M3FN)
256
+ vuint32m4_t result_u32m4 = __riscv_vor_vv_u32m4(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), sign_u32m4,
257
+ vector_length);
258
+ return __riscv_vreinterpret_v_u32m4_f32m4(result_u32m4);
304
259
  }
305
260
 
306
261
  /**
307
- * @brief Convert e3m2 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
308
- * E3M2FN: sign = bit 5, magnitude = bits 4:0 (32 entries). Sign bit 5 f32 bit 31 (<<26).
262
+ * @brief Convert e3m2 (m1) to f32 (m4) via Giesen magic-multiply.
263
+ * Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
264
+ * Handles zero, subnormals, and normals in a single vfmul. No inf/NaN in E3M2FN.
265
+ * https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/
309
266
  */
310
267
  NK_INTERNAL vfloat32m4_t nk_e3m2m1_to_f32m4_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
311
- static nk_u32_t const nk_e3m2_mag_to_f32_lut_[32] = {
312
- 0x00000000u, 0x3D800000u, 0x3E000000u, 0x3E400000u,
313
- 0x3E800000u, 0x3EA00000u, 0x3EC00000u, 0x3EE00000u, /* [ 0.. 7] */
314
- 0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u,
315
- 0x3F800000u, 0x3FA00000u, 0x3FC00000u, 0x3FE00000u, /* [ 8.. 15] */
316
- 0x40000000u, 0x40200000u, 0x40400000u, 0x40600000u,
317
- 0x40800000u, 0x40A00000u, 0x40C00000u, 0x40E00000u, /* [ 16.. 23] */
318
- 0x41000000u, 0x41200000u, 0x41400000u, 0x41600000u,
319
- 0x41800000u, 0x41A00000u, 0x41C00000u, 0x41E00000u /* [ 24.. 31] */
320
- };
321
- vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
322
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
323
- vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
324
- vector_length);
325
- vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e3m2_mag_to_f32_lut_, offsets_u32m4, vector_length);
326
- vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 26,
327
- vector_length);
328
- return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
268
+ // Extract sign: bit 5 → bit 31
269
+ vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(
270
+ __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length), vector_length), 26,
271
+ vector_length);
272
+ // Strip sign to get 5-bit magnitude, widen to u32, shift left by 21
273
+ vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(__riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length),
274
+ vector_length);
275
+ vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 21, vector_length);
276
+
277
+ // Magic multiply: reinterpret as f32 × 2^124 rebiases from E3M2 (bias=3) to f32 (bias=127).
278
+ vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(
279
+ __riscv_vmv_v_x_u32m4(0x7D800000, vector_length)); // 2^124 = (254-3)<<23
280
+ vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
281
+ vector_length);
282
+
283
+ // Restore sign (no inf/NaN fixup needed for E3M2FN)
284
+ vuint32m4_t result_u32m4 = __riscv_vor_vv_u32m4(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), sign_u32m4,
285
+ vector_length);
286
+ return __riscv_vreinterpret_v_u32m4_f32m4(result_u32m4);
329
287
  }
330
288
 
331
- /** @brief Convert e4m3 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 7 → bf16 bit 15 (<<8). */
289
+ /** @brief Convert e4m3 (m1) to bf16 (m2) via Giesen magic-multiply.
290
+ * Magic-multiply to f32, truncate upper 16 bits to bf16. NaN fixup for magnitude 0x7F. */
332
291
  NK_INTERNAL vuint16m2_t nk_e4m3m1_to_bf16m2_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
333
- static nk_u16_t const nk_e4m3_mag_to_bf16_lut_[128] = {
334
- 0x0000u, 0x3B00u, 0x3B80u, 0x3BC0u, 0x3C00u, 0x3C20u, 0x3C40u, 0x3C60u, /* [ 0.. 7] */
335
- 0x3C80u, 0x3C90u, 0x3CA0u, 0x3CB0u, 0x3CC0u, 0x3CD0u, 0x3CE0u, 0x3CF0u, /* [ 8.. 15] */
336
- 0x3D00u, 0x3D10u, 0x3D20u, 0x3D30u, 0x3D40u, 0x3D50u, 0x3D60u, 0x3D70u, /* [ 16.. 23] */
337
- 0x3D80u, 0x3D90u, 0x3DA0u, 0x3DB0u, 0x3DC0u, 0x3DD0u, 0x3DE0u, 0x3DF0u, /* [ 24.. 31] */
338
- 0x3E00u, 0x3E10u, 0x3E20u, 0x3E30u, 0x3E40u, 0x3E50u, 0x3E60u, 0x3E70u, /* [ 32.. 39] */
339
- 0x3E80u, 0x3E90u, 0x3EA0u, 0x3EB0u, 0x3EC0u, 0x3ED0u, 0x3EE0u, 0x3EF0u, /* [ 40.. 47] */
340
- 0x3F00u, 0x3F10u, 0x3F20u, 0x3F30u, 0x3F40u, 0x3F50u, 0x3F60u, 0x3F70u, /* [ 48.. 55] */
341
- 0x3F80u, 0x3F90u, 0x3FA0u, 0x3FB0u, 0x3FC0u, 0x3FD0u, 0x3FE0u, 0x3FF0u, /* [ 56.. 63] */
342
- 0x4000u, 0x4010u, 0x4020u, 0x4030u, 0x4040u, 0x4050u, 0x4060u, 0x4070u, /* [ 64.. 71] */
343
- 0x4080u, 0x4090u, 0x40A0u, 0x40B0u, 0x40C0u, 0x40D0u, 0x40E0u, 0x40F0u, /* [ 72.. 79] */
344
- 0x4100u, 0x4110u, 0x4120u, 0x4130u, 0x4140u, 0x4150u, 0x4160u, 0x4170u, /* [ 80.. 87] */
345
- 0x4180u, 0x4190u, 0x41A0u, 0x41B0u, 0x41C0u, 0x41D0u, 0x41E0u, 0x41F0u, /* [ 88.. 95] */
346
- 0x4200u, 0x4210u, 0x4220u, 0x4230u, 0x4240u, 0x4250u, 0x4260u, 0x4270u, /* [ 96..103] */
347
- 0x4280u, 0x4290u, 0x42A0u, 0x42B0u, 0x42C0u, 0x42D0u, 0x42E0u, 0x42F0u, /* [104..111] */
348
- 0x4300u, 0x4310u, 0x4320u, 0x4330u, 0x4340u, 0x4350u, 0x4360u, 0x4370u, /* [112..119] */
349
- 0x4380u, 0x4390u, 0x43A0u, 0x43B0u, 0x43C0u, 0x43D0u, 0x43E0u, 0x7FC0u /* [120..127] */
350
- };
351
292
  vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
352
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
353
- vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
293
+ vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
294
+ vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
295
+ vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 20, vector_length);
296
+ // Magic multiply: reinterpret as f32 × 2^120
297
+ vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vmv_v_x_u32m4(0x7B800000, vector_length));
298
+ vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
299
+ vector_length);
300
+ // Truncate f32 → bf16 (right shift 16, exact for all e4m3 values)
301
+ vuint16m2_t result_u16m2 = __riscv_vnsrl_wx_u16m2(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 16,
354
302
  vector_length);
355
- vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e4m3_mag_to_bf16_lut_, offsets_u16m2, vector_length);
303
+ // NaN fixup: magnitude 0x7F → bf16 quiet NaN 0x7FC0
304
+ vbool8_t is_nan = __riscv_vmseq_vx_u8m1_b8(nonsign_u8m1, 0x7F, vector_length);
305
+ result_u16m2 = __riscv_vmerge_vxm_u16m2(result_u16m2, 0x7FC0, is_nan, vector_length);
306
+ // Restore sign: bit 7 → bf16 bit 15 (<<8)
356
307
  vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
357
308
  return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
358
309
  }
359
310
 
360
- /** @brief Convert e5m2 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 7 → bf16 bit 15 (<<8). */
311
+ /** @brief Convert e5m2 (m1) to bf16 (m2) via Giesen magic-multiply.
312
+ * Magic-multiply to f32, inf/NaN fixup, truncate upper 16 bits to bf16. */
361
313
  NK_INTERNAL vuint16m2_t nk_e5m2m1_to_bf16m2_rvv_(vuint8m1_t e5m2_u8m1, nk_size_t vector_length) {
362
- static nk_u16_t const nk_e5m2_mag_to_bf16_lut_[128] = {
363
- 0x0000u, 0x3780u, 0x3800u, 0x3840u, 0x3880u, 0x38A0u, 0x38C0u, 0x38E0u, /* [ 0.. 7] */
364
- 0x3900u, 0x3920u, 0x3940u, 0x3960u, 0x3980u, 0x39A0u, 0x39C0u, 0x39E0u, /* [ 8.. 15] */
365
- 0x3A00u, 0x3A20u, 0x3A40u, 0x3A60u, 0x3A80u, 0x3AA0u, 0x3AC0u, 0x3AE0u, /* [ 16.. 23] */
366
- 0x3B00u, 0x3B20u, 0x3B40u, 0x3B60u, 0x3B80u, 0x3BA0u, 0x3BC0u, 0x3BE0u, /* [ 24.. 31] */
367
- 0x3C00u, 0x3C20u, 0x3C40u, 0x3C60u, 0x3C80u, 0x3CA0u, 0x3CC0u, 0x3CE0u, /* [ 32.. 39] */
368
- 0x3D00u, 0x3D20u, 0x3D40u, 0x3D60u, 0x3D80u, 0x3DA0u, 0x3DC0u, 0x3DE0u, /* [ 40.. 47] */
369
- 0x3E00u, 0x3E20u, 0x3E40u, 0x3E60u, 0x3E80u, 0x3EA0u, 0x3EC0u, 0x3EE0u, /* [ 48.. 55] */
370
- 0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, 0x3F80u, 0x3FA0u, 0x3FC0u, 0x3FE0u, /* [ 56.. 63] */
371
- 0x4000u, 0x4020u, 0x4040u, 0x4060u, 0x4080u, 0x40A0u, 0x40C0u, 0x40E0u, /* [ 64.. 71] */
372
- 0x4100u, 0x4120u, 0x4140u, 0x4160u, 0x4180u, 0x41A0u, 0x41C0u, 0x41E0u, /* [ 72.. 79] */
373
- 0x4200u, 0x4220u, 0x4240u, 0x4260u, 0x4280u, 0x42A0u, 0x42C0u, 0x42E0u, /* [ 80.. 87] */
374
- 0x4300u, 0x4320u, 0x4340u, 0x4360u, 0x4380u, 0x43A0u, 0x43C0u, 0x43E0u, /* [ 88.. 95] */
375
- 0x4400u, 0x4420u, 0x4440u, 0x4460u, 0x4480u, 0x44A0u, 0x44C0u, 0x44E0u, /* [ 96..103] */
376
- 0x4500u, 0x4520u, 0x4540u, 0x4560u, 0x4580u, 0x45A0u, 0x45C0u, 0x45E0u, /* [104..111] */
377
- 0x4600u, 0x4620u, 0x4640u, 0x4660u, 0x4680u, 0x46A0u, 0x46C0u, 0x46E0u, /* [112..119] */
378
- 0x4700u, 0x4720u, 0x4740u, 0x4760u, 0x7F80u, 0x7FC0u, 0x7FC0u, 0x7FC0u /* [120..127] */
379
- };
380
314
  vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length);
381
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length);
382
- vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
383
- vector_length);
384
- vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e5m2_mag_to_bf16_lut_, offsets_u16m2, vector_length);
315
+ vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length);
316
+ vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
317
+ vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 21, vector_length);
318
+ // Magic multiply: reinterpret as f32 × 2^112
319
+ vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vmv_v_x_u32m4(0x77800000, vector_length));
320
+ vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
321
+ vector_length);
322
+ // Inf/NaN fixup: masked OR writes 0x7F800000 only into inf/NaN lanes (nonsign > 123)
323
+ vbool8_t is_infnan = __riscv_vmsgtu_vx_u32m4_b8(nonsign_u32m4, 123, vector_length);
324
+ vuint32m4_t f32_bits = __riscv_vor_vx_u32m4_mu(is_infnan, __riscv_vreinterpret_v_f32m4_u32m4(result_f32m4),
325
+ __riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 0x7F800000,
326
+ vector_length);
327
+ // Truncate f32 → bf16 (right shift 16, exact for all e5m2 values)
328
+ vuint16m2_t result_u16m2 = __riscv_vnsrl_wx_u16m2(f32_bits, 16, vector_length);
329
+ // Restore sign: bit 7 → bf16 bit 15 (<<8)
385
330
  vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
386
331
  return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
387
332
  }
388
333
 
389
- /** @brief Convert e2m3 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → bf16 bit 15 (<<10). */
334
+ /** @brief Convert e2m3 (m1) to bf16 (m2) via Giesen magic-multiply.
335
+ * Magic-multiply to f32, truncate upper 16 bits to bf16. No inf/NaN in E2M3FN. */
390
336
  NK_INTERNAL vuint16m2_t nk_e2m3m1_to_bf16m2_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
391
- static nk_u16_t const nk_e2m3_mag_to_bf16_lut_[32] = {
392
- 0x0000u, 0x3E00u, 0x3E80u, 0x3EC0u, 0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, /* [ 0.. 7] */
393
- 0x3F80u, 0x3F90u, 0x3FA0u, 0x3FB0u, 0x3FC0u, 0x3FD0u, 0x3FE0u, 0x3FF0u, /* [ 8.. 15] */
394
- 0x4000u, 0x4010u, 0x4020u, 0x4030u, 0x4040u, 0x4050u, 0x4060u, 0x4070u, /* [ 16.. 23] */
395
- 0x4080u, 0x4090u, 0x40A0u, 0x40B0u, 0x40C0u, 0x40D0u, 0x40E0u, 0x40F0u /* [ 24.. 31] */
396
- };
397
337
  vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
398
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
399
- vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
338
+ vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
339
+ vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
340
+ vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 20, vector_length);
341
+ // Magic multiply: reinterpret as f32 × 2^126
342
+ vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vmv_v_x_u32m4(0x7E800000, vector_length));
343
+ vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
344
+ vector_length);
345
+ // Truncate f32 → bf16 (right shift 16, exact for all e2m3 values)
346
+ vuint16m2_t result_u16m2 = __riscv_vnsrl_wx_u16m2(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 16,
400
347
  vector_length);
401
- vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e2m3_mag_to_bf16_lut_, offsets_u16m2, vector_length);
348
+ // Restore sign: bit 5 → bf16 bit 15 (<<10)
402
349
  vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
403
350
  vector_length);
404
351
  return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
405
352
  }
406
353
 
407
- /** @brief Convert e3m2 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → bf16 bit 15 (<<10). */
354
+ /** @brief Convert e3m2 (m1) to bf16 (m2) via Giesen magic-multiply.
355
+ * Magic-multiply to f32, truncate upper 16 bits to bf16. No inf/NaN in E3M2FN. */
408
356
  NK_INTERNAL vuint16m2_t nk_e3m2m1_to_bf16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
409
- static nk_u16_t const nk_e3m2_mag_to_bf16_lut_[32] = {
410
- 0x0000u, 0x3D80u, 0x3E00u, 0x3E40u, 0x3E80u, 0x3EA0u, 0x3EC0u, 0x3EE0u, /* [ 0.. 7] */
411
- 0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, 0x3F80u, 0x3FA0u, 0x3FC0u, 0x3FE0u, /* [ 8.. 15] */
412
- 0x4000u, 0x4020u, 0x4040u, 0x4060u, 0x4080u, 0x40A0u, 0x40C0u, 0x40E0u, /* [ 16.. 23] */
413
- 0x4100u, 0x4120u, 0x4140u, 0x4160u, 0x4180u, 0x41A0u, 0x41C0u, 0x41E0u /* [ 24.. 31] */
414
- };
415
357
  vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
416
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
417
- vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
358
+ vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
359
+ vuint32m4_t nonsign_u32m4 = __riscv_vzext_vf4_u32m4(nonsign_u8m1, vector_length);
360
+ vuint32m4_t shifted_u32m4 = __riscv_vsll_vx_u32m4(nonsign_u32m4, 21, vector_length);
361
+ // Magic multiply: reinterpret as f32 × 2^124
362
+ vfloat32m4_t magic_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vmv_v_x_u32m4(0x7D800000, vector_length));
363
+ vfloat32m4_t result_f32m4 = __riscv_vfmul_vv_f32m4(__riscv_vreinterpret_v_u32m4_f32m4(shifted_u32m4), magic_f32m4,
364
+ vector_length);
365
+ // Truncate f32 → bf16 (right shift 16, exact for all e3m2 values)
366
+ vuint16m2_t result_u16m2 = __riscv_vnsrl_wx_u16m2(__riscv_vreinterpret_v_f32m4_u32m4(result_f32m4), 16,
418
367
  vector_length);
419
- vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_mag_to_bf16_lut_, offsets_u16m2, vector_length);
368
+ // Restore sign: bit 5 → bf16 bit 15 (<<10)
420
369
  vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
421
370
  vector_length);
422
371
  return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
@@ -443,8 +392,8 @@ NK_INTERNAL vuint16m2_t nk_e4m3m1_to_f16m2_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t
443
392
  0x5C00u, 0x5C80u, 0x5D00u, 0x5D80u, 0x5E00u, 0x5E80u, 0x5F00u, 0x7E00u /* [120..127] */
444
393
  };
445
394
  vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
446
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
447
- vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
395
+ vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
396
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(nonsign_u8m1, vector_length), 1,
448
397
  vector_length);
449
398
  vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e4m3_mag_to_f16_lut_, offsets_u16m2, vector_length);
450
399
  vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
@@ -460,8 +409,8 @@ NK_INTERNAL vuint16m2_t nk_e2m3m1_to_f16m2_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t
460
409
  0x4400u, 0x4480u, 0x4500u, 0x4580u, 0x4600u, 0x4680u, 0x4700u, 0x4780u /* [ 24.. 31] */
461
410
  };
462
411
  vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
463
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
464
- vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
412
+ vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
413
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(nonsign_u8m1, vector_length), 1,
465
414
  vector_length);
466
415
  vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e2m3_mag_to_f16_lut_, offsets_u16m2, vector_length);
467
416
  vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
@@ -478,8 +427,8 @@ NK_INTERNAL vuint16m2_t nk_e3m2m1_to_f16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t
478
427
  0x4800u, 0x4900u, 0x4A00u, 0x4B00u, 0x4C00u, 0x4D00u, 0x4E00u, 0x4F00u /* [ 24.. 31] */
479
428
  };
480
429
  vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
481
- vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
482
- vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
430
+ vuint8m1_t nonsign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
431
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(nonsign_u8m1, vector_length), 1,
483
432
  vector_length);
484
433
  vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_mag_to_f16_lut_, offsets_u16m2, vector_length);
485
434
  vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
@@ -501,18 +450,18 @@ NK_INTERNAL vuint16m2_t nk_e3m2m1_to_f16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t
501
450
  */
502
451
  NK_INTERNAL vint8m1x2_t nk_i4m1_to_i8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t vector_length) {
503
452
  // Extract high nibble (even indices in output)
504
- vuint8m1_t hi_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
453
+ vuint8m1_t high_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
505
454
  // Sign extend: (x ^ 8) - 8
506
- vint8m1_t hi_i8m1 = __riscv_vsub_vx_i8m1(
507
- __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(hi_u8m1), 8, vector_length), 8, vector_length);
455
+ vint8m1_t high_i8m1 = __riscv_vsub_vx_i8m1(
456
+ __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(high_u8m1), 8, vector_length), 8, vector_length);
508
457
 
509
458
  // Extract low nibble (odd indices in output)
510
- vuint8m1_t lo_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
459
+ vuint8m1_t low_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
511
460
  // Sign extend: (x ^ 8) - 8
512
- vint8m1_t lo_i8m1 = __riscv_vsub_vx_i8m1(
513
- __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(lo_u8m1), 8, vector_length), 8, vector_length);
461
+ vint8m1_t low_i8m1 = __riscv_vsub_vx_i8m1(
462
+ __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(low_u8m1), 8, vector_length), 8, vector_length);
514
463
 
515
- return __riscv_vcreate_v_i8m1x2(hi_i8m1, lo_i8m1);
464
+ return __riscv_vcreate_v_i8m1x2(high_i8m1, low_i8m1);
516
465
  }
517
466
 
518
467
  /**
@@ -522,12 +471,12 @@ NK_INTERNAL vint8m1x2_t nk_i4m1_to_i8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t v
522
471
  */
523
472
  NK_INTERNAL vuint8m1x2_t nk_u4m1_to_u8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t vector_length) {
524
473
  // Extract high nibble (even indices in output)
525
- vuint8m1_t hi_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
474
+ vuint8m1_t high_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
526
475
 
527
476
  // Extract low nibble (odd indices in output)
528
- vuint8m1_t lo_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
477
+ vuint8m1_t low_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
529
478
 
530
- return __riscv_vcreate_v_u8m1x2(hi_u8m1, lo_u8m1);
479
+ return __riscv_vcreate_v_u8m1x2(high_u8m1, low_u8m1);
531
480
  }
532
481
 
533
482
  /**
@@ -536,17 +485,17 @@ NK_INTERNAL vuint8m1x2_t nk_u4m1_to_u8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t
536
485
  * Takes a tuple of two m1 vectors (high nibbles, low nibbles from segment load).
537
486
  * Values are clamped to [-8, 7] before packing.
538
487
  */
539
- NK_INTERNAL vuint8m1_t nk_i8m2_to_i4m1_rvv_(vint8m1_t hi_i8m1, vint8m1_t lo_i8m1, nk_size_t vector_length) {
488
+ NK_INTERNAL vuint8m1_t nk_i8m2_to_i4m1_rvv_(vint8m1_t high_i8m1, vint8m1_t low_i8m1, nk_size_t vector_length) {
540
489
  // Clamp to [-8, 7]
541
- hi_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(hi_i8m1, 7, vector_length), -8, vector_length);
542
- lo_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(lo_i8m1, 7, vector_length), -8, vector_length);
490
+ high_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(high_i8m1, 7, vector_length), -8, vector_length);
491
+ low_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(low_i8m1, 7, vector_length), -8, vector_length);
543
492
 
544
493
  // Convert to unsigned nibbles: value & 0x0F
545
- vuint8m1_t hi_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(hi_i8m1), 0x0F, vector_length);
546
- vuint8m1_t lo_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(lo_i8m1), 0x0F, vector_length);
494
+ vuint8m1_t high_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(high_i8m1), 0x0F, vector_length);
495
+ vuint8m1_t low_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(low_i8m1), 0x0F, vector_length);
547
496
 
548
497
  // Pack: (hi << 4) | lo
549
- return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(hi_u4m1, 4, vector_length), lo_u4m1, vector_length);
498
+ return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(high_u4m1, 4, vector_length), low_u4m1, vector_length);
550
499
  }
551
500
 
552
501
  /**
@@ -555,13 +504,13 @@ NK_INTERNAL vuint8m1_t nk_i8m2_to_i4m1_rvv_(vint8m1_t hi_i8m1, vint8m1_t lo_i8m1
555
504
  * Takes a tuple of two m1 vectors (high nibbles, low nibbles from segment load).
556
505
  * Values are clamped to [0, 15] before packing.
557
506
  */
558
- NK_INTERNAL vuint8m1_t nk_u8m2_to_u4m1_rvv_(vuint8m1_t hi_u8m1, vuint8m1_t lo_u8m1, nk_size_t vector_length) {
507
+ NK_INTERNAL vuint8m1_t nk_u8m2_to_u4m1_rvv_(vuint8m1_t high_u8m1, vuint8m1_t low_u8m1, nk_size_t vector_length) {
559
508
  // Clamp to [0, 15]
560
- hi_u8m1 = __riscv_vminu_vx_u8m1(hi_u8m1, 15, vector_length);
561
- lo_u8m1 = __riscv_vminu_vx_u8m1(lo_u8m1, 15, vector_length);
509
+ high_u8m1 = __riscv_vminu_vx_u8m1(high_u8m1, 15, vector_length);
510
+ low_u8m1 = __riscv_vminu_vx_u8m1(low_u8m1, 15, vector_length);
562
511
 
563
512
  // Pack: (hi << 4) | lo
564
- return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(hi_u8m1, 4, vector_length), lo_u8m1, vector_length);
513
+ return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(high_u8m1, 4, vector_length), low_u8m1, vector_length);
565
514
  }
566
515
 
567
516
  /**
@@ -721,9 +670,9 @@ NK_INTERNAL vuint8m1_t nk_f32m4_to_e5m2m1_rvv_(vfloat32m4_t f32_f32m4, nk_size_t
721
670
  return __riscv_vncvt_x_x_w_u8m1(result_u16m2, vector_length);
722
671
  }
723
672
 
724
- #pragma endregion - Register - to - Register Helpers
673
+ #pragma endregion Register - to - Register Helpers
725
674
 
726
- #pragma region - Unified Cast Dispatcher
675
+ #pragma region Unified Cast Dispatcher
727
676
 
728
677
  NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t count, void *to, nk_dtype_t to_type) {
729
678
  // bf16 → f32
@@ -975,9 +924,9 @@ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t cou
975
924
  n_bytes -= vector_length, source += vector_length * 2, destination += vector_length) {
976
925
  vector_length = __riscv_vsetvl_e8m1(n_bytes);
977
926
  vint8m1x2_t loaded_i8m1x2 = __riscv_vlseg2e8_v_i8m1x2(source, vector_length);
978
- vint8m1_t hi_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 0);
979
- vint8m1_t lo_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 1);
980
- vuint8m1_t packed_u8m1 = nk_i8m2_to_i4m1_rvv_(hi_i8m1, lo_i8m1, vector_length);
927
+ vint8m1_t high_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 0);
928
+ vint8m1_t low_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 1);
929
+ vuint8m1_t packed_u8m1 = nk_i8m2_to_i4m1_rvv_(high_i8m1, low_i8m1, vector_length);
981
930
  __riscv_vse8_v_u8m1((nk_u8_t *)destination, packed_u8m1, vector_length);
982
931
  }
983
932
  return;
@@ -992,9 +941,9 @@ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t cou
992
941
  n_bytes -= vector_length, source += vector_length * 2, destination += vector_length) {
993
942
  vector_length = __riscv_vsetvl_e8m1(n_bytes);
994
943
  vuint8m1x2_t loaded_u8m1x2 = __riscv_vlseg2e8_v_u8m1x2(source, vector_length);
995
- vuint8m1_t hi_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 0);
996
- vuint8m1_t lo_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 1);
997
- vuint8m1_t packed_u8m1 = nk_u8m2_to_u4m1_rvv_(hi_u8m1, lo_u8m1, vector_length);
944
+ vuint8m1_t high_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 0);
945
+ vuint8m1_t low_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 1);
946
+ vuint8m1_t packed_u8m1 = nk_u8m2_to_u4m1_rvv_(high_u8m1, low_u8m1, vector_length);
998
947
  __riscv_vse8_v_u8m1((nk_u8_t *)destination, packed_u8m1, vector_length);
999
948
  }
1000
949
  return;
@@ -1004,7 +953,7 @@ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t cou
1004
953
  nk_cast_serial(from, from_type, count, to, to_type);
1005
954
  }
1006
955
 
1007
- #pragma endregion - Unified Cast Dispatcher
956
+ #pragma endregion Unified Cast Dispatcher
1008
957
 
1009
958
  #if defined(__cplusplus)
1010
959
  } // extern "C"