numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -145,7 +145,7 @@ NK_INTERNAL vfloat64m4_t nk_log2_f64m4_rvv_(vfloat64m4_t x, nk_size_t vector_len
145
145
  return __riscv_vfadd_vv_f64m4(exp_f, log2_m, vector_length);
146
146
  }
147
147
 
148
- #pragma region - Kullback-Leibler Divergence
148
+ #pragma region Kullback Leibler Divergence
149
149
 
150
150
  NK_PUBLIC void nk_kld_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
151
151
  nk_size_t vector_length_max = __riscv_vsetvlmax_e64m4();
@@ -172,8 +172,8 @@ NK_PUBLIC void nk_kld_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
172
172
  }
173
173
 
174
174
  NK_PUBLIC void nk_kld_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
175
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
176
- vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
175
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
176
+ vfloat64m4_t sum_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
177
177
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
178
178
  vector_length = __riscv_vsetvl_e64m4(n);
179
179
  vfloat64m4_t a_f64m4 = __riscv_vle64_v_f64m4(a, vector_length);
@@ -192,13 +192,13 @@ NK_PUBLIC void nk_kld_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
192
192
  // Single horizontal reduction after loop
193
193
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
194
194
  // Convert from log2 to ln by multiplying by ln(2)
195
- *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, vlmax)) *
195
+ *result = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(sum_f64m4, zero_f64m1, max_vector_length)) *
196
196
  0.6931471805599453;
197
197
  }
198
198
 
199
199
  NK_PUBLIC void nk_kld_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
200
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
201
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
200
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
201
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
202
202
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
203
203
  vector_length = __riscv_vsetvl_e16m1(n);
204
204
  // Load f16 as raw u16 bits
@@ -220,12 +220,13 @@ NK_PUBLIC void nk_kld_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
220
220
  }
221
221
  // Single horizontal reduction after loop
222
222
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
223
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax)) * 0.693147181f;
223
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length)) *
224
+ 0.693147181f;
224
225
  }
225
226
 
226
227
  NK_PUBLIC void nk_kld_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
227
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
228
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
228
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
229
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
229
230
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
230
231
  vector_length = __riscv_vsetvl_e16m1(n);
231
232
  // Load bf16 as raw u16 bits
@@ -247,12 +248,13 @@ NK_PUBLIC void nk_kld_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t
247
248
  }
248
249
  // Single horizontal reduction after loop
249
250
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
250
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax)) * 0.693147181f;
251
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length)) *
252
+ 0.693147181f;
251
253
  }
252
254
 
253
- #pragma endregion - Kullback - Leibler Divergence
255
+ #pragma endregion Kullback Leibler Divergence
254
256
 
255
- #pragma region - Jensen-Shannon Divergence
257
+ #pragma region Jensen Shannon Divergence
256
258
 
257
259
  NK_PUBLIC void nk_jsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
258
260
  nk_size_t vector_length_max = __riscv_vsetvlmax_e64m4();
@@ -288,9 +290,9 @@ NK_PUBLIC void nk_jsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
288
290
  }
289
291
 
290
292
  NK_PUBLIC void nk_jsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
291
- nk_size_t vlmax = __riscv_vsetvlmax_e64m4();
292
- vfloat64m4_t sum_a_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
293
- vfloat64m4_t sum_b_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, vlmax);
293
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
294
+ vfloat64m4_t sum_a_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
295
+ vfloat64m4_t sum_b_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
294
296
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
295
297
  vector_length = __riscv_vsetvl_e64m4(n);
296
298
  vfloat64m4_t va = __riscv_vle64_v_f64m4(a, vector_length);
@@ -315,14 +317,15 @@ NK_PUBLIC void nk_jsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
315
317
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
316
318
  // JSD = sqrt((sum_a + sum_b) * ln(2) / 2)
317
319
  nk_f64_t sum = __riscv_vfmv_f_s_f64m1_f64(__riscv_vfredusum_vs_f64m4_f64m1(
318
- __riscv_vfadd_vv_f64m4(sum_a_f64m4, sum_b_f64m4, vlmax), zero_f64m1, vlmax)) *
320
+ __riscv_vfadd_vv_f64m4(sum_a_f64m4, sum_b_f64m4, max_vector_length), zero_f64m1,
321
+ max_vector_length)) *
319
322
  0.6931471805599453 / 2;
320
323
  *result = sum > 0 ? nk_f64_sqrt_rvv(sum) : 0;
321
324
  }
322
325
 
323
326
  NK_PUBLIC void nk_jsd_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
324
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
325
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
327
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
328
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
326
329
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
327
330
  vector_length = __riscv_vsetvl_e16m1(n);
328
331
  // Load f16 as raw u16 bits
@@ -351,14 +354,15 @@ NK_PUBLIC void nk_jsd_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
351
354
  }
352
355
  // Single horizontal reduction after loop
353
356
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
354
- nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax)) *
357
+ nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(
358
+ __riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length)) *
355
359
  0.693147181f / 2;
356
360
  *result = sum > 0 ? nk_f32_sqrt_rvv(sum) : 0;
357
361
  }
358
362
 
359
363
  NK_PUBLIC void nk_jsd_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
360
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
361
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
364
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
365
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
362
366
  for (nk_size_t vector_length; n > 0; n -= vector_length, a += vector_length, b += vector_length) {
363
367
  vector_length = __riscv_vsetvl_e16m1(n);
364
368
  // Load bf16 as raw u16 bits
@@ -387,12 +391,13 @@ NK_PUBLIC void nk_jsd_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t
387
391
  }
388
392
  // Single horizontal reduction after loop
389
393
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
390
- nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax)) *
394
+ nk_f32_t sum = __riscv_vfmv_f_s_f32m1_f32(
395
+ __riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length)) *
391
396
  0.693147181f / 2;
392
397
  *result = sum > 0 ? nk_f32_sqrt_rvv(sum) : 0;
393
398
  }
394
399
 
395
- #pragma endregion - Jensen - Shannon Divergence
400
+ #pragma endregion Jensen Shannon Divergence
396
401
 
397
402
  #if defined(__cplusplus)
398
403
  } // extern "C"
@@ -17,32 +17,35 @@
17
17
  extern "C" {
18
18
  #endif
19
19
 
20
- #define nk_define_kld_(input_type, accumulator_type, output_type, load_and_convert, epsilon, compute_log) \
20
+ #define nk_define_kld_(input_type, unpacked_type, accumulator_type, output_type, load_and_convert, epsilon, \
21
+ compute_log) \
21
22
  NK_PUBLIC void nk_kld_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
22
23
  nk_size_t n, output_type *result) { \
23
- nk_##accumulator_type##_t d = 0, ai, bi; \
24
+ nk_##accumulator_type##_t sum = 0; \
25
+ nk_##unpacked_type##_t a_value, b_value; \
24
26
  for (nk_size_t i = 0; i != n; ++i) { \
25
- load_and_convert(a + i, &ai); \
26
- load_and_convert(b + i, &bi); \
27
- d += ai * compute_log((ai + epsilon) / (bi + epsilon)); \
27
+ load_and_convert(a + i, &a_value); \
28
+ load_and_convert(b + i, &b_value); \
29
+ sum += a_value * compute_log((a_value + epsilon) / (b_value + epsilon)); \
28
30
  } \
29
- *result = (output_type)d; \
31
+ *result = (output_type)sum; \
30
32
  }
31
33
 
32
- #define nk_define_jsd_(input_type, accumulator_type, output_type, load_and_convert, epsilon, compute_log, \
33
- compute_sqrt) \
34
+ #define nk_define_jsd_(input_type, unpacked_type, accumulator_type, output_type, load_and_convert, epsilon, \
35
+ compute_log, compute_sqrt) \
34
36
  NK_PUBLIC void nk_jsd_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
35
37
  nk_size_t n, output_type *result) { \
36
- nk_##accumulator_type##_t d = 0, ai, bi; \
38
+ nk_##accumulator_type##_t sum = 0; \
39
+ nk_##unpacked_type##_t a_value, b_value; \
37
40
  for (nk_size_t i = 0; i != n; ++i) { \
38
- load_and_convert(a + i, &ai); \
39
- load_and_convert(b + i, &bi); \
40
- nk_##accumulator_type##_t mi = (ai + bi) / 2; \
41
- d += ai * compute_log((ai + epsilon) / (mi + epsilon)); \
42
- d += bi * compute_log((bi + epsilon) / (mi + epsilon)); \
41
+ load_and_convert(a + i, &a_value); \
42
+ load_and_convert(b + i, &b_value); \
43
+ nk_##unpacked_type##_t midpoint_value = (a_value + b_value) / 2; \
44
+ sum += a_value * compute_log((a_value + epsilon) / (midpoint_value + epsilon)); \
45
+ sum += b_value * compute_log((b_value + epsilon) / (midpoint_value + epsilon)); \
43
46
  } \
44
- output_type d_half = ((output_type)d / 2); \
45
- *result = d_half > 0 ? compute_sqrt(d_half) : 0; \
47
+ output_type sum_half = ((output_type)sum / 2); \
48
+ *result = sum_half > 0 ? compute_sqrt(sum_half) : 0; \
46
49
  }
47
50
 
48
51
  /**
@@ -121,45 +124,54 @@ NK_INTERNAL nk_f64_t nk_f64_log_serial_(nk_f64_t x) {
121
124
  return (nk_f64_t)exp * 0.6931471805599453 + 2.0 * u * poly;
122
125
  }
123
126
 
124
- nk_define_kld_(f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
125
- nk_define_jsd_(f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_, nk_f64_sqrt_serial)
127
+ nk_define_kld_(f32, f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
128
+ nk_define_jsd_(f32, f32, f64, nk_f64_t, nk_assign_from_to_, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
129
+ nk_f64_sqrt_serial)
126
130
 
127
- nk_define_kld_(f16, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
128
- nk_define_jsd_(f16, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
131
+ nk_define_kld_(f16, f32, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
132
+ nk_define_jsd_(f16, f32, f32, nk_f32_t, nk_f16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
129
133
  nk_f32_sqrt_serial)
130
134
 
131
- nk_define_kld_(bf16, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
132
- nk_define_jsd_(bf16, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
135
+ nk_define_kld_(bf16, f32, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_)
136
+ nk_define_jsd_(bf16, f32, f32, nk_f32_t, nk_bf16_to_f32_serial, NK_F32_DIVISION_EPSILON, nk_f32_log_serial_,
133
137
  nk_f32_sqrt_serial)
134
138
 
135
139
  NK_PUBLIC void nk_kld_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
140
+ // Use Kahan summation for higher numerical stability in long distributions
136
141
  nk_f64_t sum = 0, compensation = 0;
137
142
  for (nk_size_t i = 0; i != n; ++i) {
138
- nk_f64_t ai = a[i], bi = b[i];
139
- nk_f64_t term = ai * nk_f64_log_serial_((ai + NK_F64_DIVISION_EPSILON) / (bi + NK_F64_DIVISION_EPSILON));
140
- nk_f64_t t = sum + term;
141
- compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term)) ? ((sum - t) + term) : ((term - t) + sum);
142
- sum = t;
143
+ nk_f64_t a_value = a[i], b_value = b[i];
144
+ nk_f64_t term = a_value *
145
+ nk_f64_log_serial_((a_value + NK_F64_DIVISION_EPSILON) / (b_value + NK_F64_DIVISION_EPSILON));
146
+ nk_f64_t provisional_sum = sum + term;
147
+ compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term)) ? ((sum - provisional_sum) + term)
148
+ : ((term - provisional_sum) + sum);
149
+ sum = provisional_sum;
143
150
  }
144
151
  *result = sum + compensation;
145
152
  }
146
153
 
147
154
  NK_PUBLIC void nk_jsd_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
155
+ // Use Kahan summation for higher numerical stability in long distributions
148
156
  nk_f64_t sum = 0, compensation = 0;
149
157
  for (nk_size_t i = 0; i != n; ++i) {
150
- nk_f64_t ai = a[i], bi = b[i];
151
- nk_f64_t mi = (ai + bi) / 2;
152
- nk_f64_t term_a = ai * nk_f64_log_serial_((ai + NK_F64_DIVISION_EPSILON) / (mi + NK_F64_DIVISION_EPSILON));
153
- nk_f64_t t = sum + term_a;
154
- compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term_a)) ? ((sum - t) + term_a) : ((term_a - t) + sum);
155
- sum = t;
156
- nk_f64_t term_b = bi * nk_f64_log_serial_((bi + NK_F64_DIVISION_EPSILON) / (mi + NK_F64_DIVISION_EPSILON));
157
- t = sum + term_b;
158
- compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term_b)) ? ((sum - t) + term_b) : ((term_b - t) + sum);
159
- sum = t;
158
+ nk_f64_t a_value = a[i], b_value = b[i];
159
+ nk_f64_t mi = (a_value + b_value) / 2;
160
+ nk_f64_t term_a = a_value *
161
+ nk_f64_log_serial_((a_value + NK_F64_DIVISION_EPSILON) / (mi + NK_F64_DIVISION_EPSILON));
162
+ nk_f64_t provisional_sum = sum + term_a;
163
+ compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term_a)) ? ((sum - provisional_sum) + term_a)
164
+ : ((term_a - provisional_sum) + sum);
165
+ sum = provisional_sum;
166
+ nk_f64_t term_b = b_value *
167
+ nk_f64_log_serial_((b_value + NK_F64_DIVISION_EPSILON) / (mi + NK_F64_DIVISION_EPSILON));
168
+ provisional_sum = sum + term_b;
169
+ compensation += (nk_f64_abs_(sum) >= nk_f64_abs_(term_b)) ? ((sum - provisional_sum) + term_b)
170
+ : ((term_b - provisional_sum) + sum);
171
+ sum = provisional_sum;
160
172
  }
161
- nk_f64_t d_half = (sum + compensation) / 2;
162
- *result = d_half > 0 ? nk_f64_sqrt_serial(d_half) : 0;
173
+ nk_f64_t sum_half = (sum + compensation) / 2;
174
+ *result = sum_half > 0 ? nk_f64_sqrt_serial(sum_half) : 0;
163
175
  }
164
176
 
165
177
  #if defined(__cplusplus)
@@ -38,14 +38,14 @@
38
38
  * calls. Division (for p/q ratio) uses either VDIVPS directly or VRCP14PS with Newton-Raphson
39
39
  * refinement when ~14-bit precision suffices. Genoa's VGETEXP/VGETMANT are 25% faster than Ice.
40
40
  *
41
- * Intrinsic Instruction Ice Genoa
42
- * _mm512_getexp_ps VGETEXPPS (ZMM, ZMM) 4c @ p0 3c @ p23
43
- * _mm512_getexp_pd VGETEXPPD (ZMM, ZMM) 4c @ p0 3c @ p23
44
- * _mm512_getmant_ps VGETMANTPS (ZMM, ZMM, I8) 4c @ p0 3c @ p23
45
- * _mm512_getmant_pd VGETMANTPD (ZMM, ZMM, I8) 4c @ p0 3c @ p23
46
- * _mm512_rcp14_ps VRCP14PS (ZMM, ZMM) 7c @ p05 5c @ p01
47
- * _mm512_div_ps VDIVPS (ZMM, ZMM, ZMM) 17c @ p05 11c @ p01
48
- * _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM) 4c @ p0 4c @ p01
41
+ * Intrinsic Instruction Icelake Genoa
42
+ * _mm512_getexp_ps VGETEXPPS (ZMM, ZMM) 4cy @ p0 3cy @ p23
43
+ * _mm512_getexp_pd VGETEXPPD (ZMM, ZMM) 4cy @ p0 3cy @ p23
44
+ * _mm512_getmant_ps VGETMANTPS (ZMM, ZMM, I8) 4cy @ p0 3cy @ p23
45
+ * _mm512_getmant_pd VGETMANTPD (ZMM, ZMM, I8) 4cy @ p0 3cy @ p23
46
+ * _mm512_rcp14_ps VRCP14PS (ZMM, ZMM) 7cy @ p0+p0+p05 5cy @ p01
47
+ * _mm512_div_ps VDIVPS (ZMM, ZMM, ZMM) 17cy @ p0+p0+p05 11cy @ p01
48
+ * _mm512_fmadd_ps VFMADD231PS (ZMM, ZMM, ZMM) 4cy @ p0 4cy @ p01
49
49
  *
50
50
  * @section arm_instructions Relevant ARM NEON/SVE Instructions
51
51
  *
@@ -53,14 +53,14 @@
53
53
  * float bits followed by polynomial refinement. FRECPE provides ~8-bit reciprocal approximation
54
54
  * for division, refined with FRECPS Newton-Raphson steps to ~22-bit precision.
55
55
  *
56
- * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
57
- * vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
58
- * vrecpeq_f32 FRECPE.S 3c @ V02 3c @ V02 3c @ V02
59
- * vrecpsq_f32 FRECPS.S 4c @ V0123 4c @ V0123 4c @ V0123
56
+ * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
57
+ * vfmaq_f32 FMLA.S (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
58
+ * vrecpeq_f32 FRECPE.S 3cy @ V02 3cy @ V02 3cy @ V02
59
+ * vrecpsq_f32 FRECPS.S 4cy @ V0123 4cy @ V0123 4cy @ V0123
60
60
  *
61
61
  * @section references References
62
62
  *
63
- * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
63
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
64
64
  * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
65
65
  *
66
66
  */
@@ -201,14 +201,11 @@ NK_PUBLIC void nk_jsd_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_siz
201
201
  NK_PUBLIC void nk_kld_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
202
202
  /** @copydoc nk_jsd_f32 */
203
203
  NK_PUBLIC void nk_jsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
204
- #endif // NK_TARGET_NEON
205
-
206
- #if NK_TARGET_NEONHALF
207
204
  /** @copydoc nk_kld_f16 */
208
- NK_PUBLIC void nk_kld_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
205
+ NK_PUBLIC void nk_kld_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
209
206
  /** @copydoc nk_jsd_f16 */
210
- NK_PUBLIC void nk_jsd_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
211
- #endif // NK_TARGET_NEONHALF
207
+ NK_PUBLIC void nk_jsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
208
+ #endif // NK_TARGET_NEON
212
209
 
213
210
  #if NK_TARGET_HASWELL
214
211
  /** @copydoc nk_kld_f64 */
@@ -283,8 +280,8 @@ extern "C" {
283
280
  #if !NK_DYNAMIC_DISPATCH
284
281
 
285
282
  NK_PUBLIC void nk_kld_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
286
- #if NK_TARGET_NEONHALF
287
- nk_kld_f16_neonhalf(a, b, n, result);
283
+ #if NK_TARGET_NEON
284
+ nk_kld_f16_neon(a, b, n, result);
288
285
  #elif NK_TARGET_SKYLAKE
289
286
  nk_kld_f16_skylake(a, b, n, result);
290
287
  #elif NK_TARGET_HASWELL
@@ -329,8 +326,8 @@ NK_PUBLIC void nk_kld_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_
329
326
  }
330
327
 
331
328
  NK_PUBLIC void nk_jsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
332
- #if NK_TARGET_NEONHALF
333
- nk_jsd_f16_neonhalf(a, b, n, result);
329
+ #if NK_TARGET_NEON
330
+ nk_jsd_f16_neon(a, b, n, result);
334
331
  #elif NK_TARGET_SKYLAKE
335
332
  nk_jsd_f16_skylake(a, b, n, result);
336
333
  #elif NK_TARGET_HASWELL
@@ -29,7 +29,7 @@
29
29
  *
30
30
  * @section references References
31
31
  *
32
- * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
32
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
33
33
  * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
34
34
  *
35
35
  */