numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -150,6 +150,20 @@ NK_PUBLIC void nk_f32_to_f16_sapphire(nk_f32_t const *src, nk_f16_t *dest);
150
150
  NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type);
151
151
  #endif // NK_TARGET_RVV
152
152
 
153
+ #if NK_TARGET_POWERVSX
154
+ /** @copydoc nk_cast */
155
+ NK_PUBLIC void nk_cast_powervsx(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type);
156
+ /** @copydoc nk_f16_to_f32 */
157
+ NK_PUBLIC void nk_f16_to_f32_powervsx(nk_f16_t const *src, nk_f32_t *dest);
158
+ /** @copydoc nk_f32_to_f16 */
159
+ NK_PUBLIC void nk_f32_to_f16_powervsx(nk_f32_t const *src, nk_f16_t *dest);
160
+ #endif // NK_TARGET_POWERVSX
161
+
162
+ #if NK_TARGET_V128RELAXED
163
+ /** @copydoc nk_cast */
164
+ NK_PUBLIC void nk_cast_v128relaxed(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type);
165
+ #endif // NK_TARGET_V128RELAXED
166
+
153
167
  #if defined(__cplusplus)
154
168
  } // extern "C"
155
169
  #endif
@@ -161,6 +175,8 @@ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t n,
161
175
  #include "numkong/cast/icelake.h"
162
176
  #include "numkong/cast/sapphire.h"
163
177
  #include "numkong/cast/rvv.h"
178
+ #include "numkong/cast/powervsx.h"
179
+ #include "numkong/cast/loongsonasx.h"
164
180
 
165
181
  #if defined(__cplusplus)
166
182
  extern "C" {
@@ -177,10 +193,14 @@ NK_PUBLIC void nk_cast(void const *from, nk_dtype_t from_type, nk_size_t n, void
177
193
  nk_cast_skylake(from, from_type, n, to, to_type);
178
194
  #elif NK_TARGET_HASWELL
179
195
  nk_cast_haswell(from, from_type, n, to, to_type);
196
+ #elif NK_TARGET_POWERVSX
197
+ nk_cast_powervsx(from, from_type, n, to, to_type);
180
198
  #elif NK_TARGET_RVV
181
199
  nk_cast_rvv(from, from_type, n, to, to_type);
182
200
  #elif NK_TARGET_NEON
183
201
  nk_cast_neon(from, from_type, n, to, to_type);
202
+ #elif NK_TARGET_V128RELAXED
203
+ nk_cast_v128relaxed(from, from_type, n, to, to_type);
184
204
  #else
185
205
  nk_cast_serial(from, from_type, n, to, to_type);
186
206
  #endif
@@ -191,6 +211,8 @@ NK_PUBLIC void nk_f16_to_f32(nk_f16_t const *src, nk_f32_t *dest) {
191
211
  nk_f16_to_f32_sapphire(src, dest);
192
212
  #elif NK_TARGET_HASWELL
193
213
  nk_f16_to_f32_haswell(src, dest);
214
+ #elif NK_TARGET_POWERVSX
215
+ nk_f16_to_f32_powervsx(src, dest);
194
216
  #elif NK_TARGET_NEON
195
217
  nk_f16_to_f32_neon(src, dest);
196
218
  #else
@@ -203,6 +225,8 @@ NK_PUBLIC void nk_f32_to_f16(nk_f32_t const *src, nk_f16_t *dest) {
203
225
  nk_f32_to_f16_sapphire(src, dest);
204
226
  #elif NK_TARGET_HASWELL
205
227
  nk_f32_to_f16_haswell(src, dest);
228
+ #elif NK_TARGET_POWERVSX
229
+ nk_f32_to_f16_powervsx(src, dest);
206
230
  #elif NK_TARGET_NEON
207
231
  nk_f32_to_f16_neon(src, dest);
208
232
  #else
@@ -0,0 +1,44 @@
1
+ /**
2
+ * @brief C++ wrappers for SIMD-accelerated type casting.
3
+ * @file include/numkong/cast.hpp
4
+ * @author Ash Vardanian
5
+ * @date March 20, 2026
6
+ */
7
+ #ifndef NK_CAST_HPP
8
+ #define NK_CAST_HPP
9
+
10
+ #include <cstddef> // `std::size_t`
11
+
12
+ #include "numkong/cast.h"
13
+
14
+ #include "numkong/types.hpp"
15
+ #include "numkong/vector.hpp"
16
+
17
+ namespace ashvardanian::numkong {
18
+
19
+ /**
20
+ * @brief Elementwise type-cast from one numeric type to another.
21
+ * @param[in] from Input array of `n` elements.
22
+ * @param[in] n Number of elements.
23
+ * @param[out] to Output array of `n` elements.
24
+ *
25
+ * @tparam from_type_ Source element type.
26
+ * @tparam to_type_ Destination element type.
27
+ * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`.
28
+ */
29
+ template <numeric_dtype from_type_, numeric_dtype to_type_, allow_simd_t allow_simd_ = prefer_simd_k>
30
+ void cast(from_type_ const *from, std::size_t n, to_type_ *to) noexcept {
31
+ if constexpr (allow_simd_ == prefer_simd_k) nk_cast(from, from_type_::dtype(), n, to, to_type_::dtype());
32
+ else nk_cast_serial(from, from_type_::dtype(), n, to, to_type_::dtype());
33
+ }
34
+
35
+ /** @brief Elementwise type-cast between vector views. Sizes must match. */
36
+ template <numeric_dtype from_type_, numeric_dtype to_type_, allow_simd_t allow_simd_ = prefer_simd_k>
37
+ void cast(vector_view<from_type_> from, vector_span<to_type_> to) noexcept {
38
+ std::size_t n = from.size() < to.size() ? from.size() : to.size();
39
+ cast<from_type_, to_type_, allow_simd_>(from.data(), n, to.data());
40
+ }
41
+
42
+ } // namespace ashvardanian::numkong
43
+
44
+ #endif // NK_CAST_HPP
@@ -6,21 +6,21 @@ These operations are central to Gaussian process inference, metric learning, and
6
6
 
7
7
  The bilinear form for real vectors is:
8
8
 
9
- ```math
9
+ $$
10
10
  \text{bilinear}(a, b, C) = a^T C b = \sum_{i=0}^{n-1} \sum_{j=0}^{n-1} a_i \cdot c_{ij} \cdot b_j
11
- ```
11
+ $$
12
12
 
13
13
  The Mahalanobis distance is:
14
14
 
15
- ```math
15
+ $$
16
16
  \text{mahalanobis}(a, b, C) = \sqrt{(a - b)^T C (a - b)}
17
- ```
17
+ $$
18
18
 
19
19
  For complex vectors, the bilinear form uses the conjugate transpose:
20
20
 
21
- ```math
21
+ $$
22
22
  \text{bilinear}(a, b, C) = a^H C b = \sum_{i=0}^{n-1} \sum_{j=0}^{n-1} \bar{a_i} \cdot c_{ij} \cdot b_j
23
- ```
23
+ $$
24
24
 
25
25
  Reformulating as Python pseudocode:
26
26
 
@@ -72,8 +72,8 @@ This nested structure gives $O(n)$ cache-friendly sequential access to the $n \t
72
72
 
73
73
  `nk_bilinear_f32_smef64`, `nk_bilinear_f64_smef64`, `nk_bilinear_f32c_smef64`, `nk_bilinear_f64c_smef64`, `nk_mahalanobis_f32_smef64`, `nk_mahalanobis_f64_smef64` use the Scalable Matrix Extension to compute the bilinear form as an outer-product accumulation.
74
74
  Each `FMOPA` instruction performs a rank-1 update $a_i \cdot b^T$ into the SME ZA tile array, and the matrix $C$ is streamed row-by-row and multiplied into the accumulator.
75
- This is fundamentally different from the row-major dot approach — it reformulates $a^T C b$ as a matrix-multiply problem where SME's 2D tile registers can exploit the matrix engine's throughput.
76
- For dimensions that align to the tile size, this approach achieves near-peak throughput; dimensions that do not align fall back to NEON for cleanup of the residual elements.
75
+ This differs from the row-major dot approach — it reformulates $a^T C b$ as a matrix-multiply problem so that SME's 2D tile registers can take advantage of the matrix engine's throughput.
76
+ For dimensions that align to the tile size, this approach has high throughput; dimensions that do not align fall back to NEON for cleanup of the residual elements.
77
77
 
78
78
  ### Complex Bilinear Decomposition
79
79
 
@@ -201,23 +201,23 @@ Measured with Wasmtime v42 (Cranelift backend).
201
201
 
202
202
  #### WASM
203
203
 
204
- Measured with Wasmtime v42 (Cranelift backend).
204
+ Measured with Wasmtime v43 (Cranelift backend).
205
205
 
206
206
  | Kernel | 256² | 1024² | 4096² |
207
207
  | :------------------------- | -----------------------: | -----------------------: | -----------------------: |
208
208
  | __f64c__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
209
- | `nk_bilinear_f64c_serial` | ? gso/s, ? ulp | ? gso/s, ? ulp | ? gso/s, ? ulp |
209
+ | `nk_bilinear_f64c_serial` | 0.445 gso/s, ? ulp | 0.445 gso/s, ? ulp | 0.445 gso/s, ? ulp |
210
210
  | __f32c__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
211
- | `nk_bilinear_f32c_serial` | ? gso/s, ? ulp | ? gso/s, ? ulp | ? gso/s, ? ulp |
211
+ | `nk_bilinear_f32c_serial` | 2.83 gso/s, ? ulp | 2.83 gso/s, ? ulp | 2.84 gso/s, ? ulp |
212
212
  | __bf16c__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
213
- | `nk_bilinear_bf16c_serial` | ? gso/s, ? ulp | ? gso/s, ? ulp | ? gso/s, ? ulp |
213
+ | `nk_bilinear_bf16c_serial` | 3.05 gso/s, ? ulp | 3.02 gso/s, ? ulp | 3.03 gso/s, ? ulp |
214
214
  | __f16c__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
215
- | `nk_bilinear_f16c_serial` | ? gso/s, ? ulp | ? gso/s, ? ulp | ? gso/s, ? ulp |
215
+ | `nk_bilinear_f16c_serial` | 0.984 gso/s, ? ulp | 0.992 gso/s, ? ulp | 0.995 gso/s, ? ulp |
216
216
  | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
217
- | `nk_bilinear_f64_serial` | ? gso/s, ? ulp | ? gso/s, ? ulp | ? gso/s, ? ulp |
217
+ | `nk_bilinear_f64_serial` | 0.998 gso/s, ? ulp | 0.999 gso/s, ? ulp | 0.999 gso/s, ? ulp |
218
218
  | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
219
- | `nk_bilinear_f32_serial` | ? gso/s, ? ulp | ? gso/s, ? ulp | ? gso/s, ? ulp |
219
+ | `nk_bilinear_f32_serial` | 5.00 gso/s, ? ulp | 3.73 gso/s, ? ulp | 3.49 gso/s, ? ulp |
220
220
  | __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
221
- | `nk_bilinear_bf16_serial` | ? gso/s, ? ulp | ? gso/s, ? ulp | ? gso/s, ? ulp |
221
+ | `nk_bilinear_bf16_serial` | 4.84 gso/s, ? ulp | 3.83 gso/s, ? ulp | 3.60 gso/s, ? ulp |
222
222
  | __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
223
- | `nk_bilinear_f16_serial` | ? gso/s, ? ulp | ? gso/s, ? ulp | ? gso/s, ? ulp |
223
+ | `nk_bilinear_f16_serial` | 1.90 gso/s, ? ulp | 1.75 gso/s, ? ulp | 1.93 gso/s, ? ulp |
@@ -11,13 +11,12 @@
11
11
  *
12
12
  * @section neon_curved_instructions Key NEON Instructions
13
13
  *
14
- * Intrinsic Instruction Latency Throughput
15
- * A76 M4+/V1+/Oryon
16
- * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy 2/cy 4/cy
17
- * vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy 2/cy 2/cy
18
- * vaddvq_f64 FADDP (V.2D to scalar) 3cy 1/cy 1/cy
19
- * vld1_f32 LD1 ({Vt.2S}, [Xn]) 4cy 2/cy 2/cy
20
- * vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy 1/cy 1/cy
14
+ * Intrinsic Instruction A76 M5
15
+ * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy @ 2p 3cy @ 4p
16
+ * vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy @ 2p 3cy @ 4p
17
+ * vaddvq_f64 FADDP (V.2D to scalar) 3cy @ 1p 3cy @ 2p
18
+ * vld1_f32 LD1 ({Vt.2S}, [Xn]) 4cy @ 2p 4cy @ 3p
19
+ * vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy @ 1p 4cy @ 1p
21
20
  *
22
21
  * For f32 bilinear and Mahalanobis, we upcast to f64 for accumulation to preserve
23
22
  * precision and avoid catastrophic cancellation in large-magnitude sums.
@@ -190,6 +189,131 @@ NK_PUBLIC void nk_bilinear_f32c_neon(nk_f32c_t const *a_pairs, nk_f32c_t const *
190
189
  results->imag = outer_sum_imag_f64;
191
190
  }
192
191
 
192
+ NK_PUBLIC void nk_bilinear_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
193
+ nk_f32_t *result) {
194
+ nk_f32_t outer_sum = 0;
195
+ for (nk_size_t row = 0; row != n; ++row) {
196
+ nk_f16_t const *c_row = c + row * n;
197
+ nk_f32_t a_row;
198
+ nk_f16_to_f32_serial(a + row, &a_row);
199
+ float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
200
+ nk_size_t column = 0;
201
+ for (; column + 8 <= n; column += 8) {
202
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(b + column)));
203
+ float16x8_t c_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(c_row + column)));
204
+ float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
205
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
206
+ float32x4_t c_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_f16x8));
207
+ float32x4_t c_high_f32x4 = vcvt_high_f32_f16(c_f16x8);
208
+ inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_low_f32x4, b_low_f32x4);
209
+ inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_high_f32x4, b_high_f32x4);
210
+ }
211
+ nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
212
+ for (; column < n; ++column) {
213
+ nk_f32_t b_val, c_val;
214
+ nk_f16_to_f32_serial(b + column, &b_val);
215
+ nk_f16_to_f32_serial(c_row + column, &c_val);
216
+ inner_sum += c_val * b_val;
217
+ }
218
+ outer_sum += a_row * inner_sum;
219
+ }
220
+ *result = outer_sum;
221
+ }
222
+
223
+ NK_PUBLIC void nk_mahalanobis_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
224
+ nk_f32_t *result) {
225
+ nk_f32_t outer_sum = 0;
226
+ for (nk_size_t row = 0; row != n; ++row) {
227
+ nk_f16_t const *c_row = c + row * n;
228
+ nk_f32_t a_row, b_row;
229
+ nk_f16_to_f32_serial(a + row, &a_row);
230
+ nk_f16_to_f32_serial(b + row, &b_row);
231
+ nk_f32_t diff_row = a_row - b_row;
232
+ float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
233
+ nk_size_t column = 0;
234
+ for (; column + 8 <= n; column += 8) {
235
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(a + column)));
236
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(b + column)));
237
+ float16x8_t c_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(c_row + column)));
238
+ float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
239
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
240
+ float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
241
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
242
+ float32x4_t c_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_f16x8));
243
+ float32x4_t c_high_f32x4 = vcvt_high_f32_f16(c_f16x8);
244
+ float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
245
+ float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
246
+ inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_low_f32x4, diff_low_f32x4);
247
+ inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_high_f32x4, diff_high_f32x4);
248
+ }
249
+ nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
250
+ for (; column < n; ++column) {
251
+ nk_f32_t a_val, b_val, c_val;
252
+ nk_f16_to_f32_serial(a + column, &a_val);
253
+ nk_f16_to_f32_serial(b + column, &b_val);
254
+ nk_f16_to_f32_serial(c_row + column, &c_val);
255
+ inner_sum += c_val * (a_val - b_val);
256
+ }
257
+ outer_sum += diff_row * inner_sum;
258
+ }
259
+ nk_f32_t quadratic = outer_sum;
260
+ *result = nk_f32_sqrt_neon(quadratic > 0 ? quadratic : 0);
261
+ }
262
+
263
+ NK_PUBLIC void nk_bilinear_f16c_neon(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_f16c_t const *c_pairs,
264
+ nk_size_t n, nk_f32c_t *results) {
265
+ nk_f32_t outer_sum_real = 0;
266
+ nk_f32_t outer_sum_imag = 0;
267
+ for (nk_size_t row = 0; row != n; ++row) {
268
+ nk_f16c_t const *c_row = c_pairs + row * n;
269
+ nk_f32_t a_real, a_imag;
270
+ nk_f16_to_f32_serial(&(a_pairs + row)->real, &a_real);
271
+ nk_f16_to_f32_serial(&(a_pairs + row)->imag, &a_imag);
272
+ float32x4_t inner_sum_real_f32x4 = vdupq_n_f32(0);
273
+ float32x4_t inner_sum_imag_f32x4 = vdupq_n_f32(0);
274
+ nk_size_t column = 0;
275
+ for (; column + 8 <= n; column += 8) {
276
+ int16x8x2_t b_i16x8x2 = vld2q_s16((short const *)(b_pairs + column));
277
+ int16x8x2_t c_i16x8x2 = vld2q_s16((short const *)(c_row + column));
278
+ float16x8_t b_real_f16x8 = vreinterpretq_f16_s16(b_i16x8x2.val[0]);
279
+ float16x8_t b_imag_f16x8 = vreinterpretq_f16_s16(b_i16x8x2.val[1]);
280
+ float16x8_t c_real_f16x8 = vreinterpretq_f16_s16(c_i16x8x2.val[0]);
281
+ float16x8_t c_imag_f16x8 = vreinterpretq_f16_s16(c_i16x8x2.val[1]);
282
+ float32x4_t b_real_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_real_f16x8));
283
+ float32x4_t b_real_high_f32x4 = vcvt_high_f32_f16(b_real_f16x8);
284
+ float32x4_t b_imag_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_imag_f16x8));
285
+ float32x4_t b_imag_high_f32x4 = vcvt_high_f32_f16(b_imag_f16x8);
286
+ float32x4_t c_real_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_real_f16x8));
287
+ float32x4_t c_real_high_f32x4 = vcvt_high_f32_f16(c_real_f16x8);
288
+ float32x4_t c_imag_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_imag_f16x8));
289
+ float32x4_t c_imag_high_f32x4 = vcvt_high_f32_f16(c_imag_f16x8);
290
+ inner_sum_real_f32x4 = vfmaq_f32(inner_sum_real_f32x4, c_real_low_f32x4, b_real_low_f32x4);
291
+ inner_sum_real_f32x4 = vfmsq_f32(inner_sum_real_f32x4, c_imag_low_f32x4, b_imag_low_f32x4);
292
+ inner_sum_real_f32x4 = vfmaq_f32(inner_sum_real_f32x4, c_real_high_f32x4, b_real_high_f32x4);
293
+ inner_sum_real_f32x4 = vfmsq_f32(inner_sum_real_f32x4, c_imag_high_f32x4, b_imag_high_f32x4);
294
+ inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_real_low_f32x4, b_imag_low_f32x4);
295
+ inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_imag_low_f32x4, b_real_low_f32x4);
296
+ inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_real_high_f32x4, b_imag_high_f32x4);
297
+ inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_imag_high_f32x4, b_real_high_f32x4);
298
+ }
299
+ nk_f32_t inner_sum_real = vaddvq_f32(inner_sum_real_f32x4);
300
+ nk_f32_t inner_sum_imag = vaddvq_f32(inner_sum_imag_f32x4);
301
+ for (; column < n; ++column) {
302
+ nk_f32_t b_real, b_imag, c_real, c_imag;
303
+ nk_f16_to_f32_serial(&(b_pairs + column)->real, &b_real);
304
+ nk_f16_to_f32_serial(&(b_pairs + column)->imag, &b_imag);
305
+ nk_f16_to_f32_serial(&(c_row + column)->real, &c_real);
306
+ nk_f16_to_f32_serial(&(c_row + column)->imag, &c_imag);
307
+ inner_sum_real += c_real * b_real - c_imag * b_imag;
308
+ inner_sum_imag += c_real * b_imag + c_imag * b_real;
309
+ }
310
+ outer_sum_real += a_real * inner_sum_real - a_imag * inner_sum_imag;
311
+ outer_sum_imag += a_real * inner_sum_imag + a_imag * inner_sum_real;
312
+ }
313
+ results->real = outer_sum_real;
314
+ results->imag = outer_sum_imag;
315
+ }
316
+
193
317
  #if defined(__clang__)
194
318
  #pragma clang attribute pop
195
319
  #elif defined(__GNUC__)
@@ -10,13 +10,12 @@
10
10
  *
11
11
  * @section curved_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
12
12
  *
13
- * Intrinsic Instruction Latency Throughput
14
- * A76 M4+/V1+/Oryon
15
- * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy 2/cy 4/cy
16
- * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy 2/cy 4/cy
17
- * vld1q_bf16 LD1 (V.8H) 4cy 2/cy 3/cy
18
- * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
19
- * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
13
+ * Intrinsic Instruction A76 M5
14
+ * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy @ 2p 2cy @ 1p
15
+ * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
16
+ * vld1q_bf16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
17
+ * vaddvq_f32 FADDP+FADDP (V.4S) 5cy @ 1p 8cy @ 1p
18
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
20
19
  *
21
20
  * For bilinear forms, BFDOT enables efficient inner-product computation by processing 8 bf16
22
21
  * pairs into 4 f32 results per instruction. For Mahalanobis distance, bf16 inputs are converted
@@ -36,10 +36,10 @@ extern "C" {
36
36
 
37
37
  NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
38
38
  nk_f64_t *result) {
39
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
39
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
40
40
  nk_f64_t outer_sum = 0;
41
41
  for (nk_size_t i = 0; i < n; ++i) {
42
- vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
42
+ vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
43
43
  nk_f32_t const *c_row = c + i * n;
44
44
  nk_size_t remaining = n;
45
45
  for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
@@ -50,7 +50,7 @@ NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_
50
50
  }
51
51
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
52
52
  nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
53
- __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
53
+ __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, max_vector_length));
54
54
  outer_sum += (nk_f64_t)a[i] * inner_val;
55
55
  }
56
56
  *result = outer_sum;
@@ -58,12 +58,12 @@ NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_
58
58
 
59
59
  NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
60
60
  nk_f64_t *result) {
61
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
61
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
62
62
  vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
63
63
  nk_f64_t outer_compensation = 0;
64
64
  for (nk_size_t i = 0; i < n; ++i) {
65
- vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
66
- vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
65
+ vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
66
+ vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
67
67
  nk_f64_t const *c_row = c + i * n;
68
68
  nk_size_t remaining = n;
69
69
  for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
@@ -82,7 +82,7 @@ NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_
82
82
  }
83
83
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
84
84
  nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
85
- __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
85
+ __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, max_vector_length));
86
86
  nk_f64_t product_outer = a[i] * inner_val;
87
87
  nk_f64_t old_sum = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1);
88
88
  nk_f64_t new_sum = old_sum + product_outer;
@@ -96,14 +96,14 @@ NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_
96
96
 
97
97
  NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
98
98
  nk_f32_t *result) {
99
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
99
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
100
100
  vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
101
101
  for (nk_size_t i = 0; i < n; ++i) {
102
102
  // Convert a[i] from f16 to f32
103
103
  nk_f32_t a_i;
104
104
  nk_f16_to_f32_serial(a + i, &a_i);
105
105
 
106
- vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
106
+ vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
107
107
  nk_f16_t const *c_row = c + i * n;
108
108
  nk_size_t remaining = n;
109
109
  for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
@@ -117,7 +117,7 @@ NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_
117
117
  }
118
118
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
119
119
  nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
120
- __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
120
+ __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, max_vector_length));
121
121
  sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + a_i * inner_val, 1);
122
122
  }
123
123
  *result = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
@@ -125,14 +125,14 @@ NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_
125
125
 
126
126
  NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
127
127
  nk_f32_t *result) {
128
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
128
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
129
129
  vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
130
130
  for (nk_size_t i = 0; i < n; ++i) {
131
131
  // Convert a[i] from bf16 to f32
132
132
  nk_f32_t a_i;
133
133
  nk_bf16_to_f32_serial(a + i, &a_i);
134
134
 
135
- vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
135
+ vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
136
136
  nk_bf16_t const *c_row = c + i * n;
137
137
  nk_size_t remaining = n;
138
138
  for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
@@ -146,7 +146,7 @@ NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_b
146
146
  }
147
147
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
148
148
  nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
149
- __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
149
+ __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, max_vector_length));
150
150
  sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + a_i * inner_val, 1);
151
151
  }
152
152
  *result = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
@@ -154,11 +154,11 @@ NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_b
154
154
 
155
155
  NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
156
156
  nk_f64_t *result) {
157
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
157
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
158
158
  nk_f64_t outer_sum = 0;
159
159
  for (nk_size_t i = 0; i < n; ++i) {
160
160
  nk_f64_t diff_i = (nk_f64_t)a[i] - (nk_f64_t)b[i];
161
- vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
161
+ vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
162
162
  nk_f32_t const *c_row = c + i * n;
163
163
  nk_size_t remaining = n;
164
164
  for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
@@ -173,7 +173,7 @@ NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f
173
173
  }
174
174
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
175
175
  nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
176
- __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
176
+ __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, max_vector_length));
177
177
  outer_sum += diff_i * inner_val;
178
178
  }
179
179
  *result = nk_f64_sqrt_rvv(outer_sum > 0 ? outer_sum : 0);
@@ -181,13 +181,13 @@ NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f
181
181
 
182
182
  NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
183
183
  nk_f64_t *result) {
184
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
184
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
185
185
  vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
186
186
  nk_f64_t outer_compensation = 0;
187
187
  for (nk_size_t i = 0; i < n; ++i) {
188
188
  nk_f64_t diff_i = a[i] - b[i];
189
- vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
190
- vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
189
+ vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
190
+ vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
191
191
  nk_f64_t const *c_row = c + i * n;
192
192
  nk_size_t remaining = n;
193
193
  for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
@@ -209,7 +209,7 @@ NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f
209
209
  }
210
210
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
211
211
  nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
212
- __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, vlmax));
212
+ __riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, max_vector_length));
213
213
  nk_f64_t product_outer = diff_i * inner_val;
214
214
  nk_f64_t old_sum = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1);
215
215
  nk_f64_t new_sum = old_sum + product_outer;
@@ -224,7 +224,7 @@ NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f
224
224
 
225
225
  NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
226
226
  nk_f32_t *result) {
227
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
227
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
228
228
  vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
229
229
  for (nk_size_t i = 0; i < n; ++i) {
230
230
  nk_f32_t a_i, b_i;
@@ -232,7 +232,7 @@ NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f
232
232
  nk_f16_to_f32_serial(b + i, &b_i);
233
233
  nk_f32_t diff_i = a_i - b_i;
234
234
 
235
- vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
235
+ vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
236
236
  nk_f16_t const *c_row = c + i * n;
237
237
  nk_size_t remaining = n;
238
238
  for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
@@ -249,7 +249,7 @@ NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f
249
249
  }
250
250
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
251
251
  nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
252
- __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
252
+ __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, max_vector_length));
253
253
  sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + diff_i * inner_val, 1);
254
254
  }
255
255
  nk_f32_t quadratic_f16 = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
@@ -258,7 +258,7 @@ NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f
258
258
 
259
259
  NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
260
260
  nk_f32_t *result) {
261
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
261
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
262
262
  vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
263
263
  for (nk_size_t i = 0; i < n; ++i) {
264
264
  nk_f32_t a_i, b_i;
@@ -266,7 +266,7 @@ NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, n
266
266
  nk_bf16_to_f32_serial(b + i, &b_i);
267
267
  nk_f32_t diff_i = a_i - b_i;
268
268
 
269
- vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
269
+ vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
270
270
  nk_bf16_t const *c_row = c + i * n;
271
271
  nk_size_t remaining = n;
272
272
  for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
@@ -283,7 +283,7 @@ NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, n
283
283
  }
284
284
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
285
285
  nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
286
- __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, vlmax));
286
+ __riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, max_vector_length));
287
287
  sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + diff_i * inner_val, 1);
288
288
  }
289
289
  nk_f32_t quadratic_bf16 = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);