numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,1021 @@
1
+ /**
2
+ * @brief SIMD-accelerated Type Conversions for RISC-V.
3
+ * @file include/numkong/cast/rvv.h
4
+ * @author Ash Vardanian
5
+ * @date January 13, 2026
6
+ *
7
+ * @sa include/numkong/cast.h
8
+ *
9
+ * SpacemiT K1 and similar chips implement RVA22 profile with base RVV 1.0.
10
+ * This file provides vectorized type conversions for:
11
+ * - BF16 ↔ F32 (bit manipulation, no hardware support)
12
+ * - F16 ↔ F32 (bit manipulation, no hardware support)
13
+ * - E4M3 ↔ F32 (FP8 format for ML inference)
14
+ * - E5M2 ↔ F32 (FP8 format for ML training)
15
+ * - i4/u4 unpacking to i8/u8
16
+ *
17
+ * Mini-float conversions use sign-symmetric magnitude LUTs: every mini-float
18
+ * format is sign|magnitude, so we store only the positive-half (magnitude)
19
+ * entries and extract the sign bit separately. This cuts LUT memory by 50-87%
20
+ * and fixes the E2M3FN NaN bug (E2M3FN has NO NaN; index 31 is +7.5, not NaN).
21
+ *
22
+ * 8-bit formats (e4m3, e5m2): sign = bit 7, magnitude = bits 6:0 (128 entries)
23
+ * 6-bit formats (e2m3, e3m2): sign = bit 5, magnitude = bits 4:0 (32 entries)
24
+ *
25
+ * @section rvv_cast_instructions Key RVV Cast Instructions
26
+ *
27
+ * Intrinsic Purpose
28
+ * vzext_vf4_u32m4 Zero-extend u8 → u32 (4x widening)
29
+ * vsext_vf4_i32m4 Sign-extend i8 → i32 (4x widening)
30
+ * vsll_vx / vsrl_vx Bit shifts for field extraction
31
+ * vand_vx Bit masking
32
+ * vor_vv Combining bit fields
33
+ * vfcvt_f_xu_v Unsigned int → float
34
+ * vmseq_vx Compare for conditional selection
35
+ * vmerge_vvm Conditional select (blend)
36
+ */
37
+ #ifndef NK_CAST_RVV_H
38
+ #define NK_CAST_RVV_H
39
+
40
+ #if NK_TARGET_RISCV_
41
+ #if NK_TARGET_RVV
42
+
43
+ #include "numkong/types.h"
44
+ #include "numkong/cast/serial.h" // `nk_cast_serial`
45
+
46
+ #if defined(__clang__)
47
+ #pragma clang attribute push(__attribute__((target("arch=+v"))), apply_to = function)
48
+ #elif defined(__GNUC__)
49
+ #pragma GCC push_options
50
+ #pragma GCC target("arch=+v")
51
+ #endif
52
+
53
+ #if defined(__cplusplus)
54
+ extern "C" {
55
+ #endif
56
+
57
+ #pragma region - Register-to-Register Helpers
58
+
59
/**
 * @brief Widen bf16 bit patterns (u16 lanes, LMUL=1) into f32 values (LMUL=2).
 *
 * A bf16 number is bit-for-bit the high half of the f32 with the same sign, exponent and leading
 * 7 mantissa bits, so widening is exact: zero-extend each lane to 32 bits and move it into the
 * upper half. No rounding, no special-casing of zeros, infinities or NaNs is required.
 */
NK_INTERNAL vfloat32m2_t nk_bf16m1_to_f32m2_rvv_(vuint16m1_t bf16_u16m1, nk_size_t vector_length) {
    vuint32m2_t const widened_u32m2 = __riscv_vzext_vf2_u32m2(bf16_u16m1, vector_length);
    vuint32m2_t const high_half_u32m2 = __riscv_vsll_vx_u32m2(widened_u32m2, 16, vector_length);
    vfloat32m2_t const result_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(high_half_u32m2);
    return result_f32m2;
}
70
+
71
/**
 * @brief Convert f32 (m2) to bf16 (m1) register-to-register.
 *
 * Finite values and infinities are rounded to nearest, ties to even: the bias (0x7FFF + lsb) is added,
 * where lsb is bit 16 (the lowest kept bit), and the upper half is taken. Finite values that round
 * past the largest bf16 correctly become ±inf through the carry into the exponent.
 *
 * NaN inputs must NOT go through the rounding add. The bias can carry a NaN into the infinity
 * encoding (0x7F800001 -> 0x7F80) or wrap it around to zero (0xFFFFFFFF -> 0x0000). Instead, NaN
 * lanes are truncated and the quiet bit (bf16 bit 6) is forced, so every NaN stays a (quiet) NaN
 * with its sign preserved.
 */
NK_INTERNAL vuint16m1_t nk_f32m2_to_bf16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_t vector_length) {
    vuint32m2_t bits_u32m2 = __riscv_vreinterpret_v_f32m2_u32m2(f32_f32m2);
    vuint32m2_t truncated_u32m2 = __riscv_vsrl_vx_u32m2(bits_u32m2, 16, vector_length);

    // Round-to-nearest-even: add 0x7FFF plus the LSB of the kept half (bit 16) before truncating.
    vuint32m2_t lsb_u32m2 = __riscv_vand_vx_u32m2(truncated_u32m2, 1, vector_length);
    vuint32m2_t rounding_u32m2 = __riscv_vadd_vx_u32m2(lsb_u32m2, 0x7FFF, vector_length);
    vuint32m2_t rounded_u32m2 = __riscv_vadd_vv_u32m2(bits_u32m2, rounding_u32m2, vector_length);
    vuint32m2_t result_u32m2 = __riscv_vsrl_vx_u32m2(rounded_u32m2, 16, vector_length);

    // NaN lanes (|bits| > +inf): truncate and set the quiet bit instead of rounding.
    vbool16_t is_nan_b16 = __riscv_vmsgtu_vx_u32m2_b16(__riscv_vand_vx_u32m2(bits_u32m2, 0x7FFFFFFFu, vector_length),
                                                       0x7F800000u, vector_length);
    vuint32m2_t quiet_nan_u32m2 = __riscv_vor_vx_u32m2(truncated_u32m2, 0x0040, vector_length);
    result_u32m2 = __riscv_vmerge_vvm_u32m2(result_u32m2, quiet_nan_u32m2, is_nan_b16, vector_length);

    return __riscv_vncvt_x_x_w_u16m1(result_u32m2, vector_length);
}
86
+
87
/**
 * @brief Widen IEEE-754 binary16 bit patterns (u16 lanes, LMUL=1) into f32 values (LMUL=2).
 *
 * Layouts: binary16 = s | eeeee | mmmmmmmmmm (bias 15); binary32 = s | eeeeeeee | m x 23 (bias 127).
 * Every class is reproduced exactly:
 * - normals: exponent rebiased by +112, fraction moved up by 13 bits;
 * - infinities / NaNs: all-ones f32 exponent, fraction (payload) moved up by 13 bits;
 * - subnormals: value = fraction x 2^-24, built by converting the integer fraction to f32 and
 *   lowering its exponent field by 24;
 * - zeros: only the sign survives.
 */
NK_INTERNAL vfloat32m2_t nk_f16m1_to_f32m2_rvv_(vuint16m1_t f16_u16m1, nk_size_t vector_length) {
    vuint32m2_t const half_u32m2 = __riscv_vzext_vf2_u32m2(f16_u16m1, vector_length);

    // Field views of the half-precision input.
    vuint32m2_t const sign_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vand_vx_u32m2(half_u32m2, 0x8000, vector_length), 16,
                                                         vector_length);
    vuint32m2_t const exponent_field_u32m2 = __riscv_vand_vx_u32m2(half_u32m2, 0x7C00, vector_length);
    vuint32m2_t const fraction_u32m2 = __riscv_vand_vx_u32m2(half_u32m2, 0x03FF, vector_length);

    // Exponent and fraction (bits 14:0) moved into binary32 position with a single shift.
    vuint32m2_t const shifted_u32m2 = __riscv_vsll_vx_u32m2(__riscv_vand_vx_u32m2(half_u32m2, 0x7FFF, vector_length), 13,
                                                            vector_length);

    // Normals: adding 112 << 23 (0x38000000) rebiases the exponent from 15 to 127.
    vuint32m2_t result_u32m2 = __riscv_vadd_vx_u32m2(shifted_u32m2, 0x38000000u, vector_length);

    // Infinities and NaNs: force the all-ones exponent while keeping the shifted fraction.
    vbool16_t const is_special_b16 = __riscv_vmseq_vx_u32m2_b16(exponent_field_u32m2, 0x7C00, vector_length);
    vuint32m2_t const special_u32m2 = __riscv_vor_vx_u32m2(shifted_u32m2, 0x7F800000u, vector_length);
    result_u32m2 = __riscv_vmerge_vvm_u32m2(result_u32m2, special_u32m2, is_special_b16, vector_length);

    // Subnormals: the integer fraction (<= 1023) converts to f32 exactly; subtracting 24 << 23
    // from its bit pattern divides it by 2^24.
    vfloat32m2_t const fraction_f32m2 = __riscv_vfcvt_f_xu_v_f32m2(fraction_u32m2, vector_length);
    vuint32m2_t const subnormal_u32m2 = __riscv_vsub_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(fraction_f32m2),
                                                              0x0C000000u, vector_length);
    vbool16_t const is_subnormal_b16 = __riscv_vmand_mm_b16(
        __riscv_vmseq_vx_u32m2_b16(exponent_field_u32m2, 0, vector_length),
        __riscv_vmsne_vx_u32m2_b16(fraction_u32m2, 0, vector_length), vector_length);
    result_u32m2 = __riscv_vmerge_vvm_u32m2(result_u32m2, subnormal_u32m2, is_subnormal_b16, vector_length);

    // Signed zeros: with all magnitude bits clear, only the sign bit is kept.
    vbool16_t const is_zero_b16 = __riscv_vmseq_vx_u32m2_b16(shifted_u32m2, 0, vector_length);
    result_u32m2 = __riscv_vmerge_vxm_u32m2(result_u32m2, 0, is_zero_b16, vector_length);

    return __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vor_vv_u32m2(result_u32m2, sign_u32m2, vector_length));
}
142
+
143
/**
 * @brief Convert f32 (m2) to IEEE-754 binary16 bit patterns (u16, m1) register-to-register.
 *
 * Every lane is rounded to the nearest binary16 value, ties to even, with the sign preserved:
 * - |x| >= 65536 (including +inf) -> ±inf (0x7C00); finite values in [65520, 65536) reach ±inf
 *   through the rounding carry of the normal path;
 * - NaN -> canonical quiet NaN 0x7E00 (with the input sign);
 * - 2^-14 <= |x| < 65536 -> exponent rebiased 127 -> 15 and the 13 dropped mantissa bits rounded
 *   in the integer domain (a carry out of the mantissa correctly bumps the exponent);
 * - |x| < 2^-14 -> binary16 subnormal or ±0, rounded by the FPU: adding 0.5f aligns the f32 ULP
 *   with the binary16 subnormal ULP (2^-24), then 0.5f's bit pattern is subtracted.
 *
 * The subnormal path uses `vfadd`, so it relies on the dynamic rounding mode (`frm`) being
 * round-to-nearest-even (the RISC-V default). The add is masked to subnormal lanes so NaN /
 * infinity lanes never reach the FPU.
 */
NK_INTERNAL vuint16m1_t nk_f32m2_to_f16m1_rvv_(vfloat32m2_t f32_f32m2, nk_size_t vector_length) {
    vuint32m2_t bits_u32m2 = __riscv_vreinterpret_v_f32m2_u32m2(f32_f32m2);
    // Sign moved from f32 bit 31 to f16 bit 15.
    vuint32m2_t sign_u32m2 = __riscv_vsrl_vx_u32m2(__riscv_vand_vx_u32m2(bits_u32m2, 0x80000000u, vector_length), 16,
                                                   vector_length);
    vuint32m2_t abs_u32m2 = __riscv_vand_vx_u32m2(bits_u32m2, 0x7FFFFFFFu, vector_length);

    // Lane classes on the magnitude bit pattern.
    vbool16_t is_subnormal_b16 = __riscv_vmsltu_vx_u32m2_b16(abs_u32m2, 0x38800000u, vector_length); // < 2^-14
    vbool16_t is_overflow_b16 = __riscv_vmsgeu_vx_u32m2_b16(abs_u32m2, 0x47800000u, vector_length);  // >= 2^16
    vbool16_t is_nan_b16 = __riscv_vmsgtu_vx_u32m2_b16(abs_u32m2, 0x7F800000u, vector_length);       // > +inf

    // Normal path: add ((15 - 127) << 23) + 0xFFF (mod 2^32) to rebias and apply the round-half-down
    // bias, then add the LSB of the kept mantissa to turn it into round-half-to-even.
    vuint32m2_t odd_u32m2 = __riscv_vand_vx_u32m2(__riscv_vsrl_vx_u32m2(abs_u32m2, 13, vector_length), 1,
                                                  vector_length);
    vuint32m2_t normal_u32m2 = __riscv_vadd_vx_u32m2(abs_u32m2, 0xC8000FFFu, vector_length);
    normal_u32m2 = __riscv_vadd_vv_u32m2(normal_u32m2, odd_u32m2, vector_length);
    normal_u32m2 = __riscv_vsrl_vx_u32m2(normal_u32m2, 13, vector_length);

    // Subnormal path: FPU rounding via 0.5f alignment (0.5f == 0x3F000000), only on subnormal lanes.
    vfloat32m2_t abs_f32m2 = __riscv_vreinterpret_v_u32m2_f32m2(abs_u32m2);
    vfloat32m2_t aligned_f32m2 = __riscv_vfadd_vf_f32m2_mu(is_subnormal_b16, abs_f32m2, abs_f32m2, 0.5f, vector_length);
    vuint32m2_t subnormal_u32m2 = __riscv_vsub_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(aligned_f32m2),
                                                        0x3F000000u, vector_length);

    // Select per class; the NaN merge comes last because NaN lanes are also overflow lanes.
    vuint32m2_t result_u32m2 = __riscv_vmerge_vvm_u32m2(normal_u32m2, subnormal_u32m2, is_subnormal_b16, vector_length);
    result_u32m2 = __riscv_vmerge_vxm_u32m2(result_u32m2, 0x7C00u, is_overflow_b16, vector_length);
    result_u32m2 = __riscv_vmerge_vxm_u32m2(result_u32m2, 0x7E00u, is_nan_b16, vector_length);
    result_u32m2 = __riscv_vor_vv_u32m2(result_u32m2, sign_u32m2, vector_length);
    return __riscv_vncvt_x_x_w_u16m1(result_u32m2, vector_length);
}
182
+
183
/**
 * @brief Decode E4M3FN bytes (u8 lanes, LMUL=1) into f32 values (LMUL=4) with a magnitude table.
 *
 * Bit 7 is the sign and bits 6:0 the magnitude. The 128-entry table stores the f32 bit pattern for
 * each non-negative magnitude (index 127 is NaN; E4M3FN has no infinities). The sign is re-applied by
 * moving input bit 7 onto f32 bit 31.
 */
NK_INTERNAL vfloat32m4_t nk_e4m3m1_to_f32m4_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
    static nk_u32_t const nk_e4m3_mag_to_f32_lut_[128] = {
        0x00000000u, 0x3B000000u, 0x3B800000u, 0x3BC00000u,
        0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u, /* [  0..  7] */
        0x3C800000u, 0x3C900000u, 0x3CA00000u, 0x3CB00000u,
        0x3CC00000u, 0x3CD00000u, 0x3CE00000u, 0x3CF00000u, /* [  8.. 15] */
        0x3D000000u, 0x3D100000u, 0x3D200000u, 0x3D300000u,
        0x3D400000u, 0x3D500000u, 0x3D600000u, 0x3D700000u, /* [ 16.. 23] */
        0x3D800000u, 0x3D900000u, 0x3DA00000u, 0x3DB00000u,
        0x3DC00000u, 0x3DD00000u, 0x3DE00000u, 0x3DF00000u, /* [ 24.. 31] */
        0x3E000000u, 0x3E100000u, 0x3E200000u, 0x3E300000u,
        0x3E400000u, 0x3E500000u, 0x3E600000u, 0x3E700000u, /* [ 32.. 39] */
        0x3E800000u, 0x3E900000u, 0x3EA00000u, 0x3EB00000u,
        0x3EC00000u, 0x3ED00000u, 0x3EE00000u, 0x3EF00000u, /* [ 40.. 47] */
        0x3F000000u, 0x3F100000u, 0x3F200000u, 0x3F300000u,
        0x3F400000u, 0x3F500000u, 0x3F600000u, 0x3F700000u, /* [ 48.. 55] */
        0x3F800000u, 0x3F900000u, 0x3FA00000u, 0x3FB00000u,
        0x3FC00000u, 0x3FD00000u, 0x3FE00000u, 0x3FF00000u, /* [ 56.. 63] */
        0x40000000u, 0x40100000u, 0x40200000u, 0x40300000u,
        0x40400000u, 0x40500000u, 0x40600000u, 0x40700000u, /* [ 64.. 71] */
        0x40800000u, 0x40900000u, 0x40A00000u, 0x40B00000u,
        0x40C00000u, 0x40D00000u, 0x40E00000u, 0x40F00000u, /* [ 72.. 79] */
        0x41000000u, 0x41100000u, 0x41200000u, 0x41300000u,
        0x41400000u, 0x41500000u, 0x41600000u, 0x41700000u, /* [ 80.. 87] */
        0x41800000u, 0x41900000u, 0x41A00000u, 0x41B00000u,
        0x41C00000u, 0x41D00000u, 0x41E00000u, 0x41F00000u, /* [ 88.. 95] */
        0x42000000u, 0x42100000u, 0x42200000u, 0x42300000u,
        0x42400000u, 0x42500000u, 0x42600000u, 0x42700000u, /* [ 96..103] */
        0x42800000u, 0x42900000u, 0x42A00000u, 0x42B00000u,
        0x42C00000u, 0x42D00000u, 0x42E00000u, 0x42F00000u, /* [104..111] */
        0x43000000u, 0x43100000u, 0x43200000u, 0x43300000u,
        0x43400000u, 0x43500000u, 0x43600000u, 0x43700000u, /* [112..119] */
        0x43800000u, 0x43900000u, 0x43A00000u, 0x43B00000u,
        0x43C00000u, 0x43D00000u, 0x43E00000u, 0x7FC00000u  /* [120..127] */
    };
    vuint32m4_t const wide_u32m4 = __riscv_vzext_vf4_u32m4(e4m3_u8m1, vector_length);
    // Byte offset of the table entry: (bits 6:0) * sizeof(nk_u32_t), i.e. (x << 2) & 0x1FC.
    vuint32m4_t const byte_offsets_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsll_vx_u32m4(wide_u32m4, 2, vector_length),
                                                                 0x1FC, vector_length);
    // Input bit 7 lands on f32 bit 31; all other input bits are masked away.
    vuint32m4_t const sign_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsll_vx_u32m4(wide_u32m4, 24, vector_length),
                                                         0x80000000u, vector_length);
    vuint32m4_t const magnitude_u32m4 = __riscv_vluxei32_v_u32m4(nk_e4m3_mag_to_f32_lut_, byte_offsets_u32m4,
                                                                 vector_length);
    return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(magnitude_u32m4, sign_u32m4, vector_length));
}
231
+
232
/**
 * @brief Decode E5M2 bytes (u8 lanes, LMUL=1) into f32 values (LMUL=4) with a magnitude table.
 *
 * Bit 7 is the sign and bits 6:0 the magnitude. The 128-entry table stores the f32 bit pattern for
 * each non-negative magnitude (index 124 is +inf, 125..127 are NaN). The sign is re-applied by moving
 * input bit 7 onto f32 bit 31.
 */
NK_INTERNAL vfloat32m4_t nk_e5m2m1_to_f32m4_rvv_(vuint8m1_t e5m2_u8m1, nk_size_t vector_length) {
    static nk_u32_t const nk_e5m2_mag_to_f32_lut_[128] = {
        0x00000000u, 0x37800000u, 0x38000000u, 0x38400000u,
        0x38800000u, 0x38A00000u, 0x38C00000u, 0x38E00000u, /* [  0..  7] */
        0x39000000u, 0x39200000u, 0x39400000u, 0x39600000u,
        0x39800000u, 0x39A00000u, 0x39C00000u, 0x39E00000u, /* [  8.. 15] */
        0x3A000000u, 0x3A200000u, 0x3A400000u, 0x3A600000u,
        0x3A800000u, 0x3AA00000u, 0x3AC00000u, 0x3AE00000u, /* [ 16.. 23] */
        0x3B000000u, 0x3B200000u, 0x3B400000u, 0x3B600000u,
        0x3B800000u, 0x3BA00000u, 0x3BC00000u, 0x3BE00000u, /* [ 24.. 31] */
        0x3C000000u, 0x3C200000u, 0x3C400000u, 0x3C600000u,
        0x3C800000u, 0x3CA00000u, 0x3CC00000u, 0x3CE00000u, /* [ 32.. 39] */
        0x3D000000u, 0x3D200000u, 0x3D400000u, 0x3D600000u,
        0x3D800000u, 0x3DA00000u, 0x3DC00000u, 0x3DE00000u, /* [ 40.. 47] */
        0x3E000000u, 0x3E200000u, 0x3E400000u, 0x3E600000u,
        0x3E800000u, 0x3EA00000u, 0x3EC00000u, 0x3EE00000u, /* [ 48.. 55] */
        0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u,
        0x3F800000u, 0x3FA00000u, 0x3FC00000u, 0x3FE00000u, /* [ 56.. 63] */
        0x40000000u, 0x40200000u, 0x40400000u, 0x40600000u,
        0x40800000u, 0x40A00000u, 0x40C00000u, 0x40E00000u, /* [ 64.. 71] */
        0x41000000u, 0x41200000u, 0x41400000u, 0x41600000u,
        0x41800000u, 0x41A00000u, 0x41C00000u, 0x41E00000u, /* [ 72.. 79] */
        0x42000000u, 0x42200000u, 0x42400000u, 0x42600000u,
        0x42800000u, 0x42A00000u, 0x42C00000u, 0x42E00000u, /* [ 80.. 87] */
        0x43000000u, 0x43200000u, 0x43400000u, 0x43600000u,
        0x43800000u, 0x43A00000u, 0x43C00000u, 0x43E00000u, /* [ 88.. 95] */
        0x44000000u, 0x44200000u, 0x44400000u, 0x44600000u,
        0x44800000u, 0x44A00000u, 0x44C00000u, 0x44E00000u, /* [ 96..103] */
        0x45000000u, 0x45200000u, 0x45400000u, 0x45600000u,
        0x45800000u, 0x45A00000u, 0x45C00000u, 0x45E00000u, /* [104..111] */
        0x46000000u, 0x46200000u, 0x46400000u, 0x46600000u,
        0x46800000u, 0x46A00000u, 0x46C00000u, 0x46E00000u, /* [112..119] */
        0x47000000u, 0x47200000u, 0x47400000u, 0x47600000u,
        0x7F800000u, 0x7FC00000u, 0x7FC00000u, 0x7FC00000u  /* [120..127] */
    };
    vuint32m4_t const wide_u32m4 = __riscv_vzext_vf4_u32m4(e5m2_u8m1, vector_length);
    // Byte offset of the table entry: (bits 6:0) * sizeof(nk_u32_t), i.e. (x << 2) & 0x1FC.
    vuint32m4_t const byte_offsets_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsll_vx_u32m4(wide_u32m4, 2, vector_length),
                                                                 0x1FC, vector_length);
    // Input bit 7 lands on f32 bit 31; all other input bits are masked away.
    vuint32m4_t const sign_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsll_vx_u32m4(wide_u32m4, 24, vector_length),
                                                         0x80000000u, vector_length);
    vuint32m4_t const magnitude_u32m4 = __riscv_vluxei32_v_u32m4(nk_e5m2_mag_to_f32_lut_, byte_offsets_u32m4,
                                                                 vector_length);
    return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(magnitude_u32m4, sign_u32m4, vector_length));
}
280
+
281
/**
 * @brief Decode E2M3FN values (6 bits in u8 lanes, LMUL=1) into f32 values (LMUL=4).
 *
 * Bit 5 is the sign and bits 4:0 the magnitude; bits 7:6 are ignored. The 32-entry table holds the
 * f32 bit pattern of every non-negative magnitude (0 .. 7.5; E2M3FN has no NaN or infinity). The
 * sign is re-applied by moving input bit 5 onto f32 bit 31.
 */
NK_INTERNAL vfloat32m4_t nk_e2m3m1_to_f32m4_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
    static nk_u32_t const nk_e2m3_mag_to_f32_lut_[32] = {
        0x00000000u, 0x3E000000u, 0x3E800000u, 0x3EC00000u,
        0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u, /* [  0..  7] */
        0x3F800000u, 0x3F900000u, 0x3FA00000u, 0x3FB00000u,
        0x3FC00000u, 0x3FD00000u, 0x3FE00000u, 0x3FF00000u, /* [  8.. 15] */
        0x40000000u, 0x40100000u, 0x40200000u, 0x40300000u,
        0x40400000u, 0x40500000u, 0x40600000u, 0x40700000u, /* [ 16.. 23] */
        0x40800000u, 0x40900000u, 0x40A00000u, 0x40B00000u,
        0x40C00000u, 0x40D00000u, 0x40E00000u, 0x40F00000u  /* [ 24.. 31] */
    };
    vuint32m4_t const wide_u32m4 = __riscv_vzext_vf4_u32m4(e2m3_u8m1, vector_length);
    // Byte offset of the table entry: (bits 4:0) * sizeof(nk_u32_t), i.e. (x << 2) & 0x7C.
    vuint32m4_t const byte_offsets_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsll_vx_u32m4(wide_u32m4, 2, vector_length),
                                                                 0x7C, vector_length);
    // Input bit 5 lands on f32 bit 31; bits 7:6 shift out of the 32-bit lane, the rest are masked away.
    vuint32m4_t const sign_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsll_vx_u32m4(wide_u32m4, 26, vector_length),
                                                         0x80000000u, vector_length);
    vuint32m4_t const magnitude_u32m4 = __riscv_vluxei32_v_u32m4(nk_e2m3_mag_to_f32_lut_, byte_offsets_u32m4,
                                                                 vector_length);
    return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(magnitude_u32m4, sign_u32m4, vector_length));
}
305
+
306
+ /**
307
+ * @brief Convert e3m2 (m1) to f32 (m4) via sign-symmetric magnitude LUT.
308
+ * E3M2FN: sign = bit 5, magnitude = bits 4:0 (32 entries). Sign bit 5 → f32 bit 31 (<<26).
309
+ */
310
+ NK_INTERNAL vfloat32m4_t nk_e3m2m1_to_f32m4_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
311
+ static nk_u32_t const nk_e3m2_mag_to_f32_lut_[32] = {
312
+ 0x00000000u, 0x3D800000u, 0x3E000000u, 0x3E400000u,
313
+ 0x3E800000u, 0x3EA00000u, 0x3EC00000u, 0x3EE00000u, /* [ 0.. 7] */
314
+ 0x3F000000u, 0x3F200000u, 0x3F400000u, 0x3F600000u,
315
+ 0x3F800000u, 0x3FA00000u, 0x3FC00000u, 0x3FE00000u, /* [ 8.. 15] */
316
+ 0x40000000u, 0x40200000u, 0x40400000u, 0x40600000u,
317
+ 0x40800000u, 0x40A00000u, 0x40C00000u, 0x40E00000u, /* [ 16.. 23] */
318
+ 0x41000000u, 0x41200000u, 0x41400000u, 0x41600000u,
319
+ 0x41800000u, 0x41A00000u, 0x41C00000u, 0x41E00000u /* [ 24.. 31] */
320
+ };
321
+ vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
322
+ vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
323
+ vuint32m4_t offsets_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(mag_u8m1, vector_length), 2,
324
+ vector_length);
325
+ vuint32m4_t result_u32m4 = __riscv_vluxei32_v_u32m4(nk_e3m2_mag_to_f32_lut_, offsets_u32m4, vector_length);
326
+ vuint32m4_t sign_u32m4 = __riscv_vsll_vx_u32m4(__riscv_vzext_vf4_u32m4(sign_u8m1, vector_length), 26,
327
+ vector_length);
328
+ return __riscv_vreinterpret_v_u32m4_f32m4(__riscv_vor_vv_u32m4(result_u32m4, sign_u32m4, vector_length));
329
+ }
330
+
331
+ /** @brief Convert e4m3 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 7 → bf16 bit 15 (<<8). */
332
+ NK_INTERNAL vuint16m2_t nk_e4m3m1_to_bf16m2_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
333
+ static nk_u16_t const nk_e4m3_mag_to_bf16_lut_[128] = {
334
+ 0x0000u, 0x3B00u, 0x3B80u, 0x3BC0u, 0x3C00u, 0x3C20u, 0x3C40u, 0x3C60u, /* [ 0.. 7] */
335
+ 0x3C80u, 0x3C90u, 0x3CA0u, 0x3CB0u, 0x3CC0u, 0x3CD0u, 0x3CE0u, 0x3CF0u, /* [ 8.. 15] */
336
+ 0x3D00u, 0x3D10u, 0x3D20u, 0x3D30u, 0x3D40u, 0x3D50u, 0x3D60u, 0x3D70u, /* [ 16.. 23] */
337
+ 0x3D80u, 0x3D90u, 0x3DA0u, 0x3DB0u, 0x3DC0u, 0x3DD0u, 0x3DE0u, 0x3DF0u, /* [ 24.. 31] */
338
+ 0x3E00u, 0x3E10u, 0x3E20u, 0x3E30u, 0x3E40u, 0x3E50u, 0x3E60u, 0x3E70u, /* [ 32.. 39] */
339
+ 0x3E80u, 0x3E90u, 0x3EA0u, 0x3EB0u, 0x3EC0u, 0x3ED0u, 0x3EE0u, 0x3EF0u, /* [ 40.. 47] */
340
+ 0x3F00u, 0x3F10u, 0x3F20u, 0x3F30u, 0x3F40u, 0x3F50u, 0x3F60u, 0x3F70u, /* [ 48.. 55] */
341
+ 0x3F80u, 0x3F90u, 0x3FA0u, 0x3FB0u, 0x3FC0u, 0x3FD0u, 0x3FE0u, 0x3FF0u, /* [ 56.. 63] */
342
+ 0x4000u, 0x4010u, 0x4020u, 0x4030u, 0x4040u, 0x4050u, 0x4060u, 0x4070u, /* [ 64.. 71] */
343
+ 0x4080u, 0x4090u, 0x40A0u, 0x40B0u, 0x40C0u, 0x40D0u, 0x40E0u, 0x40F0u, /* [ 72.. 79] */
344
+ 0x4100u, 0x4110u, 0x4120u, 0x4130u, 0x4140u, 0x4150u, 0x4160u, 0x4170u, /* [ 80.. 87] */
345
+ 0x4180u, 0x4190u, 0x41A0u, 0x41B0u, 0x41C0u, 0x41D0u, 0x41E0u, 0x41F0u, /* [ 88.. 95] */
346
+ 0x4200u, 0x4210u, 0x4220u, 0x4230u, 0x4240u, 0x4250u, 0x4260u, 0x4270u, /* [ 96..103] */
347
+ 0x4280u, 0x4290u, 0x42A0u, 0x42B0u, 0x42C0u, 0x42D0u, 0x42E0u, 0x42F0u, /* [104..111] */
348
+ 0x4300u, 0x4310u, 0x4320u, 0x4330u, 0x4340u, 0x4350u, 0x4360u, 0x4370u, /* [112..119] */
349
+ 0x4380u, 0x4390u, 0x43A0u, 0x43B0u, 0x43C0u, 0x43D0u, 0x43E0u, 0x7FC0u /* [120..127] */
350
+ };
351
+ vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
352
+ vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
353
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
354
+ vector_length);
355
+ vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e4m3_mag_to_bf16_lut_, offsets_u16m2, vector_length);
356
+ vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
357
+ return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
358
+ }
359
+
360
+ /** @brief Convert e5m2 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 7 → bf16 bit 15 (<<8). */
361
+ NK_INTERNAL vuint16m2_t nk_e5m2m1_to_bf16m2_rvv_(vuint8m1_t e5m2_u8m1, nk_size_t vector_length) {
362
+ static nk_u16_t const nk_e5m2_mag_to_bf16_lut_[128] = {
363
+ 0x0000u, 0x3780u, 0x3800u, 0x3840u, 0x3880u, 0x38A0u, 0x38C0u, 0x38E0u, /* [ 0.. 7] */
364
+ 0x3900u, 0x3920u, 0x3940u, 0x3960u, 0x3980u, 0x39A0u, 0x39C0u, 0x39E0u, /* [ 8.. 15] */
365
+ 0x3A00u, 0x3A20u, 0x3A40u, 0x3A60u, 0x3A80u, 0x3AA0u, 0x3AC0u, 0x3AE0u, /* [ 16.. 23] */
366
+ 0x3B00u, 0x3B20u, 0x3B40u, 0x3B60u, 0x3B80u, 0x3BA0u, 0x3BC0u, 0x3BE0u, /* [ 24.. 31] */
367
+ 0x3C00u, 0x3C20u, 0x3C40u, 0x3C60u, 0x3C80u, 0x3CA0u, 0x3CC0u, 0x3CE0u, /* [ 32.. 39] */
368
+ 0x3D00u, 0x3D20u, 0x3D40u, 0x3D60u, 0x3D80u, 0x3DA0u, 0x3DC0u, 0x3DE0u, /* [ 40.. 47] */
369
+ 0x3E00u, 0x3E20u, 0x3E40u, 0x3E60u, 0x3E80u, 0x3EA0u, 0x3EC0u, 0x3EE0u, /* [ 48.. 55] */
370
+ 0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, 0x3F80u, 0x3FA0u, 0x3FC0u, 0x3FE0u, /* [ 56.. 63] */
371
+ 0x4000u, 0x4020u, 0x4040u, 0x4060u, 0x4080u, 0x40A0u, 0x40C0u, 0x40E0u, /* [ 64.. 71] */
372
+ 0x4100u, 0x4120u, 0x4140u, 0x4160u, 0x4180u, 0x41A0u, 0x41C0u, 0x41E0u, /* [ 72.. 79] */
373
+ 0x4200u, 0x4220u, 0x4240u, 0x4260u, 0x4280u, 0x42A0u, 0x42C0u, 0x42E0u, /* [ 80.. 87] */
374
+ 0x4300u, 0x4320u, 0x4340u, 0x4360u, 0x4380u, 0x43A0u, 0x43C0u, 0x43E0u, /* [ 88.. 95] */
375
+ 0x4400u, 0x4420u, 0x4440u, 0x4460u, 0x4480u, 0x44A0u, 0x44C0u, 0x44E0u, /* [ 96..103] */
376
+ 0x4500u, 0x4520u, 0x4540u, 0x4560u, 0x4580u, 0x45A0u, 0x45C0u, 0x45E0u, /* [104..111] */
377
+ 0x4600u, 0x4620u, 0x4640u, 0x4660u, 0x4680u, 0x46A0u, 0x46C0u, 0x46E0u, /* [112..119] */
378
+ 0x4700u, 0x4720u, 0x4740u, 0x4760u, 0x7F80u, 0x7FC0u, 0x7FC0u, 0x7FC0u /* [120..127] */
379
+ };
380
+ vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x80, vector_length);
381
+ vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e5m2_u8m1, 0x7F, vector_length);
382
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
383
+ vector_length);
384
+ vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e5m2_mag_to_bf16_lut_, offsets_u16m2, vector_length);
385
+ vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
386
+ return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
387
+ }
388
+
389
+ /** @brief Convert e2m3 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → bf16 bit 15 (<<10). */
390
+ NK_INTERNAL vuint16m2_t nk_e2m3m1_to_bf16m2_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
391
+ static nk_u16_t const nk_e2m3_mag_to_bf16_lut_[32] = {
392
+ 0x0000u, 0x3E00u, 0x3E80u, 0x3EC0u, 0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, /* [ 0.. 7] */
393
+ 0x3F80u, 0x3F90u, 0x3FA0u, 0x3FB0u, 0x3FC0u, 0x3FD0u, 0x3FE0u, 0x3FF0u, /* [ 8.. 15] */
394
+ 0x4000u, 0x4010u, 0x4020u, 0x4030u, 0x4040u, 0x4050u, 0x4060u, 0x4070u, /* [ 16.. 23] */
395
+ 0x4080u, 0x4090u, 0x40A0u, 0x40B0u, 0x40C0u, 0x40D0u, 0x40E0u, 0x40F0u /* [ 24.. 31] */
396
+ };
397
+ vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
398
+ vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
399
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
400
+ vector_length);
401
+ vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e2m3_mag_to_bf16_lut_, offsets_u16m2, vector_length);
402
+ vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
403
+ vector_length);
404
+ return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
405
+ }
406
+
407
+ /** @brief Convert e3m2 (m1) to bf16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → bf16 bit 15 (<<10). */
408
+ NK_INTERNAL vuint16m2_t nk_e3m2m1_to_bf16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
409
+ static nk_u16_t const nk_e3m2_mag_to_bf16_lut_[32] = {
410
+ 0x0000u, 0x3D80u, 0x3E00u, 0x3E40u, 0x3E80u, 0x3EA0u, 0x3EC0u, 0x3EE0u, /* [ 0.. 7] */
411
+ 0x3F00u, 0x3F20u, 0x3F40u, 0x3F60u, 0x3F80u, 0x3FA0u, 0x3FC0u, 0x3FE0u, /* [ 8.. 15] */
412
+ 0x4000u, 0x4020u, 0x4040u, 0x4060u, 0x4080u, 0x40A0u, 0x40C0u, 0x40E0u, /* [ 16.. 23] */
413
+ 0x4100u, 0x4120u, 0x4140u, 0x4160u, 0x4180u, 0x41A0u, 0x41C0u, 0x41E0u /* [ 24.. 31] */
414
+ };
415
+ vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
416
+ vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
417
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
418
+ vector_length);
419
+ vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_mag_to_bf16_lut_, offsets_u16m2, vector_length);
420
+ vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
421
+ vector_length);
422
+ return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
423
+ }
424
+
425
+ /** @brief Convert e4m3 (m1) to f16 (m2) via sign-symmetric magnitude LUT. Sign bit 7 → f16 bit 15 (<<8). */
426
+ NK_INTERNAL vuint16m2_t nk_e4m3m1_to_f16m2_rvv_(vuint8m1_t e4m3_u8m1, nk_size_t vector_length) {
427
+ static nk_u16_t const nk_e4m3_mag_to_f16_lut_[128] = {
428
+ 0x0000u, 0x1800u, 0x1C00u, 0x1E00u, 0x2000u, 0x2100u, 0x2200u, 0x2300u, /* [ 0.. 7] */
429
+ 0x2400u, 0x2480u, 0x2500u, 0x2580u, 0x2600u, 0x2680u, 0x2700u, 0x2780u, /* [ 8.. 15] */
430
+ 0x2800u, 0x2880u, 0x2900u, 0x2980u, 0x2A00u, 0x2A80u, 0x2B00u, 0x2B80u, /* [ 16.. 23] */
431
+ 0x2C00u, 0x2C80u, 0x2D00u, 0x2D80u, 0x2E00u, 0x2E80u, 0x2F00u, 0x2F80u, /* [ 24.. 31] */
432
+ 0x3000u, 0x3080u, 0x3100u, 0x3180u, 0x3200u, 0x3280u, 0x3300u, 0x3380u, /* [ 32.. 39] */
433
+ 0x3400u, 0x3480u, 0x3500u, 0x3580u, 0x3600u, 0x3680u, 0x3700u, 0x3780u, /* [ 40.. 47] */
434
+ 0x3800u, 0x3880u, 0x3900u, 0x3980u, 0x3A00u, 0x3A80u, 0x3B00u, 0x3B80u, /* [ 48.. 55] */
435
+ 0x3C00u, 0x3C80u, 0x3D00u, 0x3D80u, 0x3E00u, 0x3E80u, 0x3F00u, 0x3F80u, /* [ 56.. 63] */
436
+ 0x4000u, 0x4080u, 0x4100u, 0x4180u, 0x4200u, 0x4280u, 0x4300u, 0x4380u, /* [ 64.. 71] */
437
+ 0x4400u, 0x4480u, 0x4500u, 0x4580u, 0x4600u, 0x4680u, 0x4700u, 0x4780u, /* [ 72.. 79] */
438
+ 0x4800u, 0x4880u, 0x4900u, 0x4980u, 0x4A00u, 0x4A80u, 0x4B00u, 0x4B80u, /* [ 80.. 87] */
439
+ 0x4C00u, 0x4C80u, 0x4D00u, 0x4D80u, 0x4E00u, 0x4E80u, 0x4F00u, 0x4F80u, /* [ 88.. 95] */
440
+ 0x5000u, 0x5080u, 0x5100u, 0x5180u, 0x5200u, 0x5280u, 0x5300u, 0x5380u, /* [ 96..103] */
441
+ 0x5400u, 0x5480u, 0x5500u, 0x5580u, 0x5600u, 0x5680u, 0x5700u, 0x5780u, /* [104..111] */
442
+ 0x5800u, 0x5880u, 0x5900u, 0x5980u, 0x5A00u, 0x5A80u, 0x5B00u, 0x5B80u, /* [112..119] */
443
+ 0x5C00u, 0x5C80u, 0x5D00u, 0x5D80u, 0x5E00u, 0x5E80u, 0x5F00u, 0x7E00u /* [120..127] */
444
+ };
445
+ vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x80, vector_length);
446
+ vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e4m3_u8m1, 0x7F, vector_length);
447
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
448
+ vector_length);
449
+ vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e4m3_mag_to_f16_lut_, offsets_u16m2, vector_length);
450
+ vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 8, vector_length);
451
+ return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
452
+ }
453
+
454
+ /** @brief Convert e2m3 (m1) to f16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → f16 bit 15 (<<10). */
455
+ NK_INTERNAL vuint16m2_t nk_e2m3m1_to_f16m2_rvv_(vuint8m1_t e2m3_u8m1, nk_size_t vector_length) {
456
+ static nk_u16_t const nk_e2m3_mag_to_f16_lut_[32] = {
457
+ 0x0000u, 0x3000u, 0x3400u, 0x3600u, 0x3800u, 0x3900u, 0x3A00u, 0x3B00u, /* [ 0.. 7] */
458
+ 0x3C00u, 0x3C80u, 0x3D00u, 0x3D80u, 0x3E00u, 0x3E80u, 0x3F00u, 0x3F80u, /* [ 8.. 15] */
459
+ 0x4000u, 0x4080u, 0x4100u, 0x4180u, 0x4200u, 0x4280u, 0x4300u, 0x4380u, /* [ 16.. 23] */
460
+ 0x4400u, 0x4480u, 0x4500u, 0x4580u, 0x4600u, 0x4680u, 0x4700u, 0x4780u /* [ 24.. 31] */
461
+ };
462
+ vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x20, vector_length);
463
+ vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e2m3_u8m1, 0x1F, vector_length);
464
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
465
+ vector_length);
466
+ vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e2m3_mag_to_f16_lut_, offsets_u16m2, vector_length);
467
+ vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
468
+ vector_length);
469
+ return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
470
+ }
471
+
472
+ /** @brief Convert e3m2 (m1) to f16 (m2) via sign-symmetric magnitude LUT. Sign bit 5 → f16 bit 15 (<<10). */
473
+ NK_INTERNAL vuint16m2_t nk_e3m2m1_to_f16m2_rvv_(vuint8m1_t e3m2_u8m1, nk_size_t vector_length) {
474
+ static nk_u16_t const nk_e3m2_mag_to_f16_lut_[32] = {
475
+ 0x0000u, 0x2C00u, 0x3000u, 0x3200u, 0x3400u, 0x3500u, 0x3600u, 0x3700u, /* [ 0.. 7] */
476
+ 0x3800u, 0x3900u, 0x3A00u, 0x3B00u, 0x3C00u, 0x3D00u, 0x3E00u, 0x3F00u, /* [ 8.. 15] */
477
+ 0x4000u, 0x4100u, 0x4200u, 0x4300u, 0x4400u, 0x4500u, 0x4600u, 0x4700u, /* [ 16.. 23] */
478
+ 0x4800u, 0x4900u, 0x4A00u, 0x4B00u, 0x4C00u, 0x4D00u, 0x4E00u, 0x4F00u /* [ 24.. 31] */
479
+ };
480
+ vuint8m1_t sign_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x20, vector_length);
481
+ vuint8m1_t mag_u8m1 = __riscv_vand_vx_u8m1(e3m2_u8m1, 0x1F, vector_length);
482
+ vuint16m2_t offsets_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(mag_u8m1, vector_length), 1,
483
+ vector_length);
484
+ vuint16m2_t result_u16m2 = __riscv_vluxei16_v_u16m2(nk_e3m2_mag_to_f16_lut_, offsets_u16m2, vector_length);
485
+ vuint16m2_t sign_u16m2 = __riscv_vsll_vx_u16m2(__riscv_vzext_vf2_u16m2(sign_u8m1, vector_length), 10,
486
+ vector_length);
487
+ return __riscv_vor_vv_u16m2(result_u16m2, sign_u16m2, vector_length);
488
+ }
489
+
490
+ /**
491
+ * @brief Unpack i4 (m1) nibbles to i8 (m2) register-to-register.
492
+ *
493
+ * Packed format: byte[i] contains two nibbles:
494
+ * - High nibble (bits [7:4]) → output[i*2]
495
+ * - Low nibble (bits [3:0]) → output[i*2+1]
496
+ *
497
+ * Sign extension: 4-bit signed value [-8,7] extended to 8-bit.
498
+ * Trick: (x ^ 8) - 8 sign-extends a 4-bit value to larger type.
499
+ *
500
+ * Returns a tuple of two m1 vectors (high nibbles, low nibbles) for segment store.
501
+ */
502
+ NK_INTERNAL vint8m1x2_t nk_i4m1_to_i8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t vector_length) {
503
+ // Extract high nibble (even indices in output)
504
+ vuint8m1_t hi_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
505
+ // Sign extend: (x ^ 8) - 8
506
+ vint8m1_t hi_i8m1 = __riscv_vsub_vx_i8m1(
507
+ __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(hi_u8m1), 8, vector_length), 8, vector_length);
508
+
509
+ // Extract low nibble (odd indices in output)
510
+ vuint8m1_t lo_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
511
+ // Sign extend: (x ^ 8) - 8
512
+ vint8m1_t lo_i8m1 = __riscv_vsub_vx_i8m1(
513
+ __riscv_vxor_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(lo_u8m1), 8, vector_length), 8, vector_length);
514
+
515
+ return __riscv_vcreate_v_i8m1x2(hi_i8m1, lo_i8m1);
516
+ }
517
+
518
+ /**
519
+ * @brief Unpack u4 (m1) nibbles to u8 (m2) register-to-register.
520
+ *
521
+ * Returns a tuple of two m1 vectors (high nibbles, low nibbles) for segment store.
522
+ */
523
+ NK_INTERNAL vuint8m1x2_t nk_u4m1_to_u8m2_rvv_(vuint8m1_t packed_u8m1, nk_size_t vector_length) {
524
+ // Extract high nibble (even indices in output)
525
+ vuint8m1_t hi_u8m1 = __riscv_vsrl_vx_u8m1(packed_u8m1, 4, vector_length);
526
+
527
+ // Extract low nibble (odd indices in output)
528
+ vuint8m1_t lo_u8m1 = __riscv_vand_vx_u8m1(packed_u8m1, 0x0F, vector_length);
529
+
530
+ return __riscv_vcreate_v_u8m1x2(hi_u8m1, lo_u8m1);
531
+ }
532
+
533
+ /**
534
+ * @brief Pack i8 (m2) to i4 (m1) nibbles register-to-register.
535
+ *
536
+ * Takes a tuple of two m1 vectors (high nibbles, low nibbles from segment load).
537
+ * Values are clamped to [-8, 7] before packing.
538
+ */
539
+ NK_INTERNAL vuint8m1_t nk_i8m2_to_i4m1_rvv_(vint8m1_t hi_i8m1, vint8m1_t lo_i8m1, nk_size_t vector_length) {
540
+ // Clamp to [-8, 7]
541
+ hi_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(hi_i8m1, 7, vector_length), -8, vector_length);
542
+ lo_i8m1 = __riscv_vmax_vx_i8m1(__riscv_vmin_vx_i8m1(lo_i8m1, 7, vector_length), -8, vector_length);
543
+
544
+ // Convert to unsigned nibbles: value & 0x0F
545
+ vuint8m1_t hi_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(hi_i8m1), 0x0F, vector_length);
546
+ vuint8m1_t lo_u4m1 = __riscv_vand_vx_u8m1(__riscv_vreinterpret_v_i8m1_u8m1(lo_i8m1), 0x0F, vector_length);
547
+
548
+ // Pack: (hi << 4) | lo
549
+ return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(hi_u4m1, 4, vector_length), lo_u4m1, vector_length);
550
+ }
551
+
552
+ /**
553
+ * @brief Pack u8 (m2) to u4 (m1) nibbles register-to-register.
554
+ *
555
+ * Takes a tuple of two m1 vectors (high nibbles, low nibbles from segment load).
556
+ * Values are clamped to [0, 15] before packing.
557
+ */
558
+ NK_INTERNAL vuint8m1_t nk_u8m2_to_u4m1_rvv_(vuint8m1_t hi_u8m1, vuint8m1_t lo_u8m1, nk_size_t vector_length) {
559
+ // Clamp to [0, 15]
560
+ hi_u8m1 = __riscv_vminu_vx_u8m1(hi_u8m1, 15, vector_length);
561
+ lo_u8m1 = __riscv_vminu_vx_u8m1(lo_u8m1, 15, vector_length);
562
+
563
+ // Pack: (hi << 4) | lo
564
+ return __riscv_vor_vv_u8m1(__riscv_vsll_vx_u8m1(hi_u8m1, 4, vector_length), lo_u8m1, vector_length);
565
+ }
566
+
567
+ /**
568
+ * @brief Convert f32 (m4) to e4m3 (m1) register-to-register.
569
+ *
570
+ * E4M3FN format: S EEEE MMM (1 sign, 4 exponent bits with bias=7, 3 mantissa bits)
571
+ * Handles normal, subnormal, overflow, and NaN. Uses RNE mantissa rounding.
572
+ * E4M3FN quirk: exp=15 with mant=7 is NaN (0x7F), so max finite is 0x7E (exp=15, mant=6).
573
+ */
574
+ NK_INTERNAL vuint8m1_t nk_f32m4_to_e4m3m1_rvv_(vfloat32m4_t f32_f32m4, nk_size_t vector_length) {
575
+ vuint32m4_t bits_u32m4 = __riscv_vreinterpret_v_f32m4_u32m4(f32_f32m4);
576
+ vuint32m4_t sign_u32m4 = __riscv_vsrl_vx_u32m4(bits_u32m4, 31, vector_length);
577
+ vuint32m4_t abs_bits_u32m4 = __riscv_vand_vx_u32m4(bits_u32m4, 0x7FFFFFFF, vector_length);
578
+ vuint32m4_t f32_exp_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(bits_u32m4, 23, vector_length), 0xFF,
579
+ vector_length);
580
+
581
+ // Round mantissa from 23 to 3 bits using RNE (round to nearest, ties to even)
582
+ vuint32m4_t significand_u32m4 = __riscv_vor_vx_u32m4(__riscv_vand_vx_u32m4(bits_u32m4, 0x007FFFFF, vector_length),
583
+ 0x00800000, vector_length);
584
+ vuint32m4_t lsb_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(significand_u32m4, 20, vector_length), 1,
585
+ vector_length);
586
+ vuint32m4_t rounding_bias_u32m4 = __riscv_vadd_vx_u32m4(lsb_u32m4, 0x0007FFFF, vector_length);
587
+ vuint32m4_t rounded_sig_u32m4 = __riscv_vadd_vv_u32m4(significand_u32m4, rounding_bias_u32m4, vector_length);
588
+ vuint32m4_t carry_u32m4 = __riscv_vsrl_vx_u32m4(rounded_sig_u32m4, 24, vector_length);
589
+ vuint32m4_t f32_mantissa_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(rounded_sig_u32m4, 20, vector_length),
590
+ 0x07, vector_length);
591
+ // If carry, mantissa becomes 0 (rounded up to next power of 2)
592
+ vbool8_t has_carry_b8 = __riscv_vmsne_vx_u32m4_b8(carry_u32m4, 0, vector_length);
593
+ f32_mantissa_u32m4 = __riscv_vmerge_vxm_u32m4(f32_mantissa_u32m4, 0, has_carry_b8, vector_length);
594
+
595
+ // e4m3_exp = f32_exp + carry - 120
596
+ vint32m4_t e4m3_exp_i32m4 = __riscv_vsub_vx_i32m4(
597
+ __riscv_vreinterpret_v_u32m4_i32m4(__riscv_vadd_vv_u32m4(f32_exp_u32m4, carry_u32m4, vector_length)), 120,
598
+ vector_length);
599
+
600
+ // Detect subnormal (exp <= 0) and overflow (exp > 15)
601
+ vbool8_t is_subnormal_b8 = __riscv_vmsle_vx_i32m4_b8(e4m3_exp_i32m4, 0, vector_length);
602
+ vbool8_t is_overflow_b8 = __riscv_vmsgt_vx_i32m4_b8(e4m3_exp_i32m4, 15, vector_length);
603
+
604
+ // Normal path: clamp exp to [1,15]
605
+ vint32m4_t clamped_exp_i32m4 = __riscv_vmax_vx_i32m4(e4m3_exp_i32m4, 1, vector_length);
606
+ clamped_exp_i32m4 = __riscv_vmin_vx_i32m4(clamped_exp_i32m4, 15, vector_length);
607
+ // E4M3FN quirk: exp=15 with mant=7 is NaN, so cap mantissa to 6 when exp=15
608
+ vbool8_t is_max_exp_b8 = __riscv_vmseq_vx_i32m4_b8(clamped_exp_i32m4, 15, vector_length);
609
+ vuint32m4_t max_mant_u32m4 = __riscv_vmerge_vxm_u32m4(__riscv_vmv_v_x_u32m4(7, vector_length), 6, is_max_exp_b8,
610
+ vector_length);
611
+ vuint32m4_t normal_mant_u32m4 = __riscv_vminu_vv_u32m4(f32_mantissa_u32m4, max_mant_u32m4, vector_length);
612
+ // On overflow, saturate to max finite (exp=15, mant=6 = 0x7E with sign)
613
+ normal_mant_u32m4 = __riscv_vmerge_vxm_u32m4(normal_mant_u32m4, 0x06, is_overflow_b8, vector_length);
614
+ vuint32m4_t normal_u32m4 = __riscv_vor_vv_u32m4(
615
+ __riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length),
616
+ __riscv_vor_vv_u32m4(
617
+ __riscv_vsll_vx_u32m4(__riscv_vreinterpret_v_i32m4_u32m4(clamped_exp_i32m4), 3, vector_length),
618
+ normal_mant_u32m4, vector_length),
619
+ vector_length);
620
+
621
+ // Subnormal path: mantissa = round(|f32| * 512)
622
+ vfloat32m4_t abs_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(abs_bits_u32m4);
623
+ vfloat32m4_t scaled_f32m4 = __riscv_vfmul_vf_f32m4(abs_f32m4, 512.0f, vector_length);
624
+ vint32m4_t subnorm_mant_i32m4 = __riscv_vfcvt_x_f_v_i32m4(scaled_f32m4, vector_length); // RNE rounding
625
+ // If rounds to 8+, promote to first normal (exp=1, mant=0 = 0x08)
626
+ vbool8_t promotes_b8 = __riscv_vmsgt_vx_i32m4_b8(subnorm_mant_i32m4, 7, vector_length);
627
+ subnorm_mant_i32m4 = __riscv_vmin_vx_i32m4(subnorm_mant_i32m4, 7, vector_length);
628
+ subnorm_mant_i32m4 = __riscv_vmax_vx_i32m4(subnorm_mant_i32m4, 0, vector_length);
629
+ vuint32m4_t subnorm_u32m4 = __riscv_vor_vv_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length),
630
+ __riscv_vreinterpret_v_i32m4_u32m4(subnorm_mant_i32m4),
631
+ vector_length);
632
+ vuint32m4_t first_normal_u32m4 = __riscv_vor_vx_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length), 0x08,
633
+ vector_length);
634
+ subnorm_u32m4 = __riscv_vmerge_vvm_u32m4(subnorm_u32m4, first_normal_u32m4, promotes_b8, vector_length);
635
+
636
+ // Select: subnormal when exp <= 0, else normal
637
+ vuint32m4_t result_u32m4 = __riscv_vmerge_vvm_u32m4(normal_u32m4, subnorm_u32m4, is_subnormal_b8, vector_length);
638
+
639
+ // Handle NaN: f32 NaN (abs_bits > 0x7F800000) → e4m3 NaN (sign | 0x7F)
640
+ vbool8_t is_nan_b8 = __riscv_vmsgtu_vx_u32m4_b8(abs_bits_u32m4, 0x7F800000, vector_length);
641
+ vuint32m4_t nan_u32m4 = __riscv_vor_vx_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length), 0x7F,
642
+ vector_length);
643
+ result_u32m4 = __riscv_vmerge_vvm_u32m4(result_u32m4, nan_u32m4, is_nan_b8, vector_length);
644
+
645
+ // Narrow u32m4 → u16m2 → u8m1
646
+ vuint16m2_t result_u16m2 = __riscv_vncvt_x_x_w_u16m2(result_u32m4, vector_length);
647
+ return __riscv_vncvt_x_x_w_u8m1(result_u16m2, vector_length);
648
+ }
649
+
650
+ /**
651
+ * @brief Convert f32 (m4) to e5m2 (m1) register-to-register.
652
+ *
653
+ * E5M2 format: S EEEEE MM (1 sign, 5 exponent bits with bias=15, 2 mantissa bits)
654
+ * Handles normal, subnormal, overflow (→ infinity), and NaN. Uses RNE mantissa rounding.
655
+ */
656
+ NK_INTERNAL vuint8m1_t nk_f32m4_to_e5m2m1_rvv_(vfloat32m4_t f32_f32m4, nk_size_t vector_length) {
657
+ vuint32m4_t bits_u32m4 = __riscv_vreinterpret_v_f32m4_u32m4(f32_f32m4);
658
+ vuint32m4_t sign_u32m4 = __riscv_vsrl_vx_u32m4(bits_u32m4, 31, vector_length);
659
+ vuint32m4_t abs_bits_u32m4 = __riscv_vand_vx_u32m4(bits_u32m4, 0x7FFFFFFF, vector_length);
660
+ vuint32m4_t f32_exp_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(bits_u32m4, 23, vector_length), 0xFF,
661
+ vector_length);
662
+
663
+ // Round mantissa from 23 to 2 bits using RNE
664
+ vuint32m4_t significand_u32m4 = __riscv_vor_vx_u32m4(__riscv_vand_vx_u32m4(bits_u32m4, 0x007FFFFF, vector_length),
665
+ 0x00800000, vector_length);
666
+ vuint32m4_t lsb_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(significand_u32m4, 21, vector_length), 1,
667
+ vector_length);
668
+ vuint32m4_t rounding_bias_u32m4 = __riscv_vadd_vx_u32m4(lsb_u32m4, 0x000FFFFF, vector_length);
669
+ vuint32m4_t rounded_sig_u32m4 = __riscv_vadd_vv_u32m4(significand_u32m4, rounding_bias_u32m4, vector_length);
670
+ vuint32m4_t carry_u32m4 = __riscv_vsrl_vx_u32m4(rounded_sig_u32m4, 24, vector_length);
671
+ vuint32m4_t f32_mantissa_u32m4 = __riscv_vand_vx_u32m4(__riscv_vsrl_vx_u32m4(rounded_sig_u32m4, 21, vector_length),
672
+ 0x03, vector_length);
673
+ vbool8_t has_carry_b8 = __riscv_vmsne_vx_u32m4_b8(carry_u32m4, 0, vector_length);
674
+ f32_mantissa_u32m4 = __riscv_vmerge_vxm_u32m4(f32_mantissa_u32m4, 0, has_carry_b8, vector_length);
675
+
676
+ // e5m2_exp = f32_exp + carry - 112
677
+ vint32m4_t e5m2_exp_i32m4 = __riscv_vsub_vx_i32m4(
678
+ __riscv_vreinterpret_v_u32m4_i32m4(__riscv_vadd_vv_u32m4(f32_exp_u32m4, carry_u32m4, vector_length)), 112,
679
+ vector_length);
680
+
681
+ // Detect subnormal (exp <= 0) and overflow (exp > 31)
682
+ vbool8_t is_subnormal_b8 = __riscv_vmsle_vx_i32m4_b8(e5m2_exp_i32m4, 0, vector_length);
683
+ vbool8_t is_overflow_b8 = __riscv_vmsgt_vx_i32m4_b8(e5m2_exp_i32m4, 31, vector_length);
684
+
685
+ // Normal path: clamp exp to [1,31], on overflow return infinity (exp=31, mant=0)
686
+ vint32m4_t clamped_exp_i32m4 = __riscv_vmax_vx_i32m4(e5m2_exp_i32m4, 1, vector_length);
687
+ clamped_exp_i32m4 = __riscv_vmin_vx_i32m4(clamped_exp_i32m4, 31, vector_length);
688
+ vuint32m4_t normal_mant_u32m4 = __riscv_vmerge_vxm_u32m4(f32_mantissa_u32m4, 0, is_overflow_b8, vector_length);
689
+ vuint32m4_t normal_u32m4 = __riscv_vor_vv_u32m4(
690
+ __riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length),
691
+ __riscv_vor_vv_u32m4(
692
+ __riscv_vsll_vx_u32m4(__riscv_vreinterpret_v_i32m4_u32m4(clamped_exp_i32m4), 2, vector_length),
693
+ normal_mant_u32m4, vector_length),
694
+ vector_length);
695
+
696
+ // Subnormal path: mantissa = round(|f32| * 65536)
697
+ vfloat32m4_t abs_f32m4 = __riscv_vreinterpret_v_u32m4_f32m4(abs_bits_u32m4);
698
+ vfloat32m4_t scaled_f32m4 = __riscv_vfmul_vf_f32m4(abs_f32m4, 65536.0f, vector_length);
699
+ vint32m4_t subnorm_mant_i32m4 = __riscv_vfcvt_x_f_v_i32m4(scaled_f32m4, vector_length);
700
+ vbool8_t promotes_b8 = __riscv_vmsgt_vx_i32m4_b8(subnorm_mant_i32m4, 3, vector_length);
701
+ subnorm_mant_i32m4 = __riscv_vmin_vx_i32m4(subnorm_mant_i32m4, 3, vector_length);
702
+ subnorm_mant_i32m4 = __riscv_vmax_vx_i32m4(subnorm_mant_i32m4, 0, vector_length);
703
+ vuint32m4_t subnorm_u32m4 = __riscv_vor_vv_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length),
704
+ __riscv_vreinterpret_v_i32m4_u32m4(subnorm_mant_i32m4),
705
+ vector_length);
706
+ vuint32m4_t first_normal_u32m4 = __riscv_vor_vx_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length), 0x04,
707
+ vector_length);
708
+ subnorm_u32m4 = __riscv_vmerge_vvm_u32m4(subnorm_u32m4, first_normal_u32m4, promotes_b8, vector_length);
709
+
710
+ // Select: subnormal when exp <= 0, else normal
711
+ vuint32m4_t result_u32m4 = __riscv_vmerge_vvm_u32m4(normal_u32m4, subnorm_u32m4, is_subnormal_b8, vector_length);
712
+
713
+ // Handle NaN: f32 NaN (abs_bits > 0x7F800000) → e5m2 NaN (sign | 0x7D)
714
+ vbool8_t is_nan_b8 = __riscv_vmsgtu_vx_u32m4_b8(abs_bits_u32m4, 0x7F800000, vector_length);
715
+ vuint32m4_t nan_u32m4 = __riscv_vor_vx_u32m4(__riscv_vsll_vx_u32m4(sign_u32m4, 7, vector_length), 0x7D,
716
+ vector_length);
717
+ result_u32m4 = __riscv_vmerge_vvm_u32m4(result_u32m4, nan_u32m4, is_nan_b8, vector_length);
718
+
719
+ // Narrow u32m4 → u16m2 → u8m1
720
+ vuint16m2_t result_u16m2 = __riscv_vncvt_x_x_w_u16m2(result_u32m4, vector_length);
721
+ return __riscv_vncvt_x_x_w_u8m1(result_u16m2, vector_length);
722
+ }
723
+
724
+ #pragma endregion - Register - to - Register Helpers
725
+
726
+ #pragma region - Unified Cast Dispatcher
727
+
728
+ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t count, void *to, nk_dtype_t to_type) {
729
+ // bf16 → f32
730
+ if (from_type == nk_bf16_k && to_type == nk_f32_k) {
731
+ nk_bf16_t const *source = (nk_bf16_t const *)from;
732
+ nk_f32_t *destination = (nk_f32_t *)to;
733
+ for (nk_size_t vector_length; count > 0;
734
+ count -= vector_length, source += vector_length, destination += vector_length) {
735
+ vector_length = __riscv_vsetvl_e16m1(count);
736
+ vuint16m1_t bf16_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)source, vector_length);
737
+ vfloat32m2_t f32_f32m2 = nk_bf16m1_to_f32m2_rvv_(bf16_u16m1, vector_length);
738
+ __riscv_vse32_v_f32m2(destination, f32_f32m2, vector_length);
739
+ }
740
+ return;
741
+ }
742
+
743
+ // f32 → bf16
744
+ if (from_type == nk_f32_k && to_type == nk_bf16_k) {
745
+ nk_f32_t const *source = (nk_f32_t const *)from;
746
+ nk_bf16_t *destination = (nk_bf16_t *)to;
747
+ for (nk_size_t vector_length; count > 0;
748
+ count -= vector_length, source += vector_length, destination += vector_length) {
749
+ vector_length = __riscv_vsetvl_e32m2(count);
750
+ vfloat32m2_t f32_f32m2 = __riscv_vle32_v_f32m2(source, vector_length);
751
+ vuint16m1_t bf16_u16m1 = nk_f32m2_to_bf16m1_rvv_(f32_f32m2, vector_length);
752
+ __riscv_vse16_v_u16m1((nk_u16_t *)destination, bf16_u16m1, vector_length);
753
+ }
754
+ return;
755
+ }
756
+
757
+ // f16 → f32
758
+ if (from_type == nk_f16_k && to_type == nk_f32_k) {
759
+ nk_f16_t const *source = (nk_f16_t const *)from;
760
+ nk_f32_t *destination = (nk_f32_t *)to;
761
+ for (nk_size_t vector_length; count > 0;
762
+ count -= vector_length, source += vector_length, destination += vector_length) {
763
+ vector_length = __riscv_vsetvl_e16m1(count);
764
+ vuint16m1_t f16_u16m1 = __riscv_vle16_v_u16m1((nk_u16_t const *)source, vector_length);
765
+ vfloat32m2_t f32_f32m2 = nk_f16m1_to_f32m2_rvv_(f16_u16m1, vector_length);
766
+ __riscv_vse32_v_f32m2(destination, f32_f32m2, vector_length);
767
+ }
768
+ return;
769
+ }
770
+
771
+ // f32 → f16
772
+ if (from_type == nk_f32_k && to_type == nk_f16_k) {
773
+ nk_f32_t const *source = (nk_f32_t const *)from;
774
+ nk_f16_t *destination = (nk_f16_t *)to;
775
+ for (nk_size_t vector_length; count > 0;
776
+ count -= vector_length, source += vector_length, destination += vector_length) {
777
+ vector_length = __riscv_vsetvl_e32m2(count);
778
+ vfloat32m2_t f32_f32m2 = __riscv_vle32_v_f32m2(source, vector_length);
779
+ vuint16m1_t f16_u16m1 = nk_f32m2_to_f16m1_rvv_(f32_f32m2, vector_length);
780
+ __riscv_vse16_v_u16m1((nk_u16_t *)destination, f16_u16m1, vector_length);
781
+ }
782
+ return;
783
+ }
784
+
785
+ // e4m3 → f32
786
+ if (from_type == nk_e4m3_k && to_type == nk_f32_k) {
787
+ nk_e4m3_t const *source = (nk_e4m3_t const *)from;
788
+ nk_f32_t *destination = (nk_f32_t *)to;
789
+ for (nk_size_t vector_length; count > 0;
790
+ count -= vector_length, source += vector_length, destination += vector_length) {
791
+ vector_length = __riscv_vsetvl_e8m1(count);
792
+ vuint8m1_t e4m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
793
+ vfloat32m4_t f32_f32m4 = nk_e4m3m1_to_f32m4_rvv_(e4m3_u8m1, vector_length);
794
+ __riscv_vse32_v_f32m4(destination, f32_f32m4, vector_length);
795
+ }
796
+ return;
797
+ }
798
+
799
+ // e5m2 → f32
800
+ if (from_type == nk_e5m2_k && to_type == nk_f32_k) {
801
+ nk_e5m2_t const *source = (nk_e5m2_t const *)from;
802
+ nk_f32_t *destination = (nk_f32_t *)to;
803
+ for (nk_size_t vector_length; count > 0;
804
+ count -= vector_length, source += vector_length, destination += vector_length) {
805
+ vector_length = __riscv_vsetvl_e8m1(count);
806
+ vuint8m1_t e5m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
807
+ vfloat32m4_t f32_f32m4 = nk_e5m2m1_to_f32m4_rvv_(e5m2_u8m1, vector_length);
808
+ __riscv_vse32_v_f32m4(destination, f32_f32m4, vector_length);
809
+ }
810
+ return;
811
+ }
812
+
813
+ // e2m3 → f32
814
+ if (from_type == nk_e2m3_k && to_type == nk_f32_k) {
815
+ nk_e2m3_t const *source = (nk_e2m3_t const *)from;
816
+ nk_f32_t *destination = (nk_f32_t *)to;
817
+ for (nk_size_t vector_length; count > 0;
818
+ count -= vector_length, source += vector_length, destination += vector_length) {
819
+ vector_length = __riscv_vsetvl_e8m1(count);
820
+ vuint8m1_t e2m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
821
+ vfloat32m4_t f32_f32m4 = nk_e2m3m1_to_f32m4_rvv_(e2m3_u8m1, vector_length);
822
+ __riscv_vse32_v_f32m4(destination, f32_f32m4, vector_length);
823
+ }
824
+ return;
825
+ }
826
+
827
+ // e3m2 → f32
828
+ if (from_type == nk_e3m2_k && to_type == nk_f32_k) {
829
+ nk_e3m2_t const *source = (nk_e3m2_t const *)from;
830
+ nk_f32_t *destination = (nk_f32_t *)to;
831
+ for (nk_size_t vector_length; count > 0;
832
+ count -= vector_length, source += vector_length, destination += vector_length) {
833
+ vector_length = __riscv_vsetvl_e8m1(count);
834
+ vuint8m1_t e3m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
835
+ vfloat32m4_t f32_f32m4 = nk_e3m2m1_to_f32m4_rvv_(e3m2_u8m1, vector_length);
836
+ __riscv_vse32_v_f32m4(destination, f32_f32m4, vector_length);
837
+ }
838
+ return;
839
+ }
840
+
841
+ // e4m3 → bf16
842
+ if (from_type == nk_e4m3_k && to_type == nk_bf16_k) {
843
+ nk_e4m3_t const *source = (nk_e4m3_t const *)from;
844
+ nk_bf16_t *destination = (nk_bf16_t *)to;
845
+ for (nk_size_t vector_length; count > 0;
846
+ count -= vector_length, source += vector_length, destination += vector_length) {
847
+ vector_length = __riscv_vsetvl_e8m1(count);
848
+ vuint8m1_t e4m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
849
+ vuint16m2_t bf16_u16m2 = nk_e4m3m1_to_bf16m2_rvv_(e4m3_u8m1, vector_length);
850
+ __riscv_vse16_v_u16m2((nk_u16_t *)destination, bf16_u16m2, vector_length);
851
+ }
852
+ return;
853
+ }
854
+
855
+ // e5m2 → bf16
856
+ if (from_type == nk_e5m2_k && to_type == nk_bf16_k) {
857
+ nk_e5m2_t const *source = (nk_e5m2_t const *)from;
858
+ nk_bf16_t *destination = (nk_bf16_t *)to;
859
+ for (nk_size_t vector_length; count > 0;
860
+ count -= vector_length, source += vector_length, destination += vector_length) {
861
+ vector_length = __riscv_vsetvl_e8m1(count);
862
+ vuint8m1_t e5m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
863
+ vuint16m2_t bf16_u16m2 = nk_e5m2m1_to_bf16m2_rvv_(e5m2_u8m1, vector_length);
864
+ __riscv_vse16_v_u16m2((nk_u16_t *)destination, bf16_u16m2, vector_length);
865
+ }
866
+ return;
867
+ }
868
+
869
+ // e2m3 → bf16
870
+ if (from_type == nk_e2m3_k && to_type == nk_bf16_k) {
871
+ nk_e2m3_t const *source = (nk_e2m3_t const *)from;
872
+ nk_bf16_t *destination = (nk_bf16_t *)to;
873
+ for (nk_size_t vector_length; count > 0;
874
+ count -= vector_length, source += vector_length, destination += vector_length) {
875
+ vector_length = __riscv_vsetvl_e8m1(count);
876
+ vuint8m1_t e2m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
877
+ vuint16m2_t bf16_u16m2 = nk_e2m3m1_to_bf16m2_rvv_(e2m3_u8m1, vector_length);
878
+ __riscv_vse16_v_u16m2((nk_u16_t *)destination, bf16_u16m2, vector_length);
879
+ }
880
+ return;
881
+ }
882
+
883
+ // e3m2 → bf16
884
+ if (from_type == nk_e3m2_k && to_type == nk_bf16_k) {
885
+ nk_e3m2_t const *source = (nk_e3m2_t const *)from;
886
+ nk_bf16_t *destination = (nk_bf16_t *)to;
887
+ for (nk_size_t vector_length; count > 0;
888
+ count -= vector_length, source += vector_length, destination += vector_length) {
889
+ vector_length = __riscv_vsetvl_e8m1(count);
890
+ vuint8m1_t e3m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
891
+ vuint16m2_t bf16_u16m2 = nk_e3m2m1_to_bf16m2_rvv_(e3m2_u8m1, vector_length);
892
+ __riscv_vse16_v_u16m2((nk_u16_t *)destination, bf16_u16m2, vector_length);
893
+ }
894
+ return;
895
+ }
896
+
897
+ // e4m3 → f16
898
+ if (from_type == nk_e4m3_k && to_type == nk_f16_k) {
899
+ nk_e4m3_t const *source = (nk_e4m3_t const *)from;
900
+ nk_f16_t *destination = (nk_f16_t *)to;
901
+ for (nk_size_t vector_length; count > 0;
902
+ count -= vector_length, source += vector_length, destination += vector_length) {
903
+ vector_length = __riscv_vsetvl_e8m1(count);
904
+ vuint8m1_t e4m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
905
+ vuint16m2_t f16_u16m2 = nk_e4m3m1_to_f16m2_rvv_(e4m3_u8m1, vector_length);
906
+ __riscv_vse16_v_u16m2((nk_u16_t *)destination, f16_u16m2, vector_length);
907
+ }
908
+ return;
909
+ }
910
+
911
+ // e2m3 → f16
912
+ if (from_type == nk_e2m3_k && to_type == nk_f16_k) {
913
+ nk_e2m3_t const *source = (nk_e2m3_t const *)from;
914
+ nk_f16_t *destination = (nk_f16_t *)to;
915
+ for (nk_size_t vector_length; count > 0;
916
+ count -= vector_length, source += vector_length, destination += vector_length) {
917
+ vector_length = __riscv_vsetvl_e8m1(count);
918
+ vuint8m1_t e2m3_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
919
+ vuint16m2_t f16_u16m2 = nk_e2m3m1_to_f16m2_rvv_(e2m3_u8m1, vector_length);
920
+ __riscv_vse16_v_u16m2((nk_u16_t *)destination, f16_u16m2, vector_length);
921
+ }
922
+ return;
923
+ }
924
+
925
+ // e3m2 → f16
926
+ if (from_type == nk_e3m2_k && to_type == nk_f16_k) {
927
+ nk_e3m2_t const *source = (nk_e3m2_t const *)from;
928
+ nk_f16_t *destination = (nk_f16_t *)to;
929
+ for (nk_size_t vector_length; count > 0;
930
+ count -= vector_length, source += vector_length, destination += vector_length) {
931
+ vector_length = __riscv_vsetvl_e8m1(count);
932
+ vuint8m1_t e3m2_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
933
+ vuint16m2_t f16_u16m2 = nk_e3m2m1_to_f16m2_rvv_(e3m2_u8m1, vector_length);
934
+ __riscv_vse16_v_u16m2((nk_u16_t *)destination, f16_u16m2, vector_length);
935
+ }
936
+ return;
937
+ }
938
+
939
+ // i4 → i8
940
+ if (from_type == nk_i4_k && to_type == nk_i8_k) {
941
+ nk_i4x2_t const *source = (nk_i4x2_t const *)from;
942
+ nk_i8_t *destination = (nk_i8_t *)to;
943
+ nk_size_t n_bytes = count / 2;
944
+ for (nk_size_t vector_length; n_bytes > 0;
945
+ n_bytes -= vector_length, source += vector_length, destination += vector_length * 2) {
946
+ vector_length = __riscv_vsetvl_e8m1(n_bytes);
947
+ vuint8m1_t packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
948
+ vint8m1x2_t unpacked_i8m1x2 = nk_i4m1_to_i8m2_rvv_(packed_u8m1, vector_length);
949
+ __riscv_vsseg2e8_v_i8m1x2(destination, unpacked_i8m1x2, vector_length);
950
+ }
951
+ return;
952
+ }
953
+
954
+ // u4 → u8
955
+ if (from_type == nk_u4_k && to_type == nk_u8_k) {
956
+ nk_u4x2_t const *source = (nk_u4x2_t const *)from;
957
+ nk_u8_t *destination = (nk_u8_t *)to;
958
+ nk_size_t n_bytes = count / 2;
959
+ for (nk_size_t vector_length; n_bytes > 0;
960
+ n_bytes -= vector_length, source += vector_length, destination += vector_length * 2) {
961
+ vector_length = __riscv_vsetvl_e8m1(n_bytes);
962
+ vuint8m1_t packed_u8m1 = __riscv_vle8_v_u8m1((nk_u8_t const *)source, vector_length);
963
+ vuint8m1x2_t unpacked_u8m1x2 = nk_u4m1_to_u8m2_rvv_(packed_u8m1, vector_length);
964
+ __riscv_vsseg2e8_v_u8m1x2(destination, unpacked_u8m1x2, vector_length);
965
+ }
966
+ return;
967
+ }
968
+
969
+ // i8 → i4
970
+ if (from_type == nk_i8_k && to_type == nk_i4_k) {
971
+ nk_i8_t const *source = (nk_i8_t const *)from;
972
+ nk_i4x2_t *destination = (nk_i4x2_t *)to;
973
+ nk_size_t n_bytes = count / 2;
974
+ for (nk_size_t vector_length; n_bytes > 0;
975
+ n_bytes -= vector_length, source += vector_length * 2, destination += vector_length) {
976
+ vector_length = __riscv_vsetvl_e8m1(n_bytes);
977
+ vint8m1x2_t loaded_i8m1x2 = __riscv_vlseg2e8_v_i8m1x2(source, vector_length);
978
+ vint8m1_t hi_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 0);
979
+ vint8m1_t lo_i8m1 = __riscv_vget_v_i8m1x2_i8m1(loaded_i8m1x2, 1);
980
+ vuint8m1_t packed_u8m1 = nk_i8m2_to_i4m1_rvv_(hi_i8m1, lo_i8m1, vector_length);
981
+ __riscv_vse8_v_u8m1((nk_u8_t *)destination, packed_u8m1, vector_length);
982
+ }
983
+ return;
984
+ }
985
+
986
+ // u8 → u4
987
+ if (from_type == nk_u8_k && to_type == nk_u4_k) {
988
+ nk_u8_t const *source = (nk_u8_t const *)from;
989
+ nk_u4x2_t *destination = (nk_u4x2_t *)to;
990
+ nk_size_t n_bytes = count / 2;
991
+ for (nk_size_t vector_length; n_bytes > 0;
992
+ n_bytes -= vector_length, source += vector_length * 2, destination += vector_length) {
993
+ vector_length = __riscv_vsetvl_e8m1(n_bytes);
994
+ vuint8m1x2_t loaded_u8m1x2 = __riscv_vlseg2e8_v_u8m1x2(source, vector_length);
995
+ vuint8m1_t hi_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 0);
996
+ vuint8m1_t lo_u8m1 = __riscv_vget_v_u8m1x2_u8m1(loaded_u8m1x2, 1);
997
+ vuint8m1_t packed_u8m1 = nk_u8m2_to_u4m1_rvv_(hi_u8m1, lo_u8m1, vector_length);
998
+ __riscv_vse8_v_u8m1((nk_u8_t *)destination, packed_u8m1, vector_length);
999
+ }
1000
+ return;
1001
+ }
1002
+
1003
+ // Fallback to serial for unimplemented conversions
1004
+ nk_cast_serial(from, from_type, count, to, to_type);
1005
+ }
1006
+
1007
+ #pragma endregion - Unified Cast Dispatcher
1008
+
1009
+ #if defined(__cplusplus)
1010
+ } // extern "C"
1011
+ #endif
1012
+
1013
+ #if defined(__clang__)
1014
+ #pragma clang attribute pop
1015
+ #elif defined(__GNUC__)
1016
+ #pragma GCC pop_options
1017
+ #endif
1018
+
1019
+ #endif // NK_TARGET_RVV
1020
+ #endif // NK_TARGET_RISCV_
1021
+ #endif // NK_CAST_RVV_H