numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -0,0 +1,483 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for LoongArch LASX (256-bit).
3
+ * @file include/numkong/spatial/loongsonasx.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_loongsonasx_instructions Key LASX Spatial Instructions
10
+ *
11
+ * LASX provides 256-bit SIMD operations using __m256i as the universal vector type.
12
+ * All intrinsics are prefixed with __lasx_. Float operations reinterpret __m256i as
13
+ * f32x8 or f64x4. Integer widening multiply-accumulate chains handle i8/u8 distances.
14
+ *
15
+ * For F32 spatial distances, upcasting to F64 and downcasting back is faster than stable
16
+ * summation algorithms. For F64 angular we use the Dot2 algorithm (Ogita-Rump-Oishi, 2005)
17
+ * for the cross-product accumulation, while self-products use simple FMA since all terms
18
+ * are non-negative and don't suffer from cancellation.
19
+ */
20
+ #ifndef NK_SPATIAL_LOONGSONASX_H
21
+ #define NK_SPATIAL_LOONGSONASX_H
22
+
23
+ #if NK_TARGET_LOONGARCH_
24
+ #if NK_TARGET_LOONGSONASX
25
+
26
+ #include "numkong/types.h"
27
+ #include "numkong/spatial/serial.h"
28
+ #include "numkong/dot/loongsonasx.h" //
29
+ #include "numkong/cast/loongsonasx.h" // `nk_bf16x8_to_f32x8_loongsonasx_`
30
+ #include "numkong/scalar/loongsonasx.h" // `nk_f32_sqrt_loongsonasx`, `nk_f64_sqrt_loongsonasx`
31
+
32
+ #if defined(__cplusplus)
33
+ extern "C" {
34
+ #endif
35
+
36
+ #pragma region Angular Normalize Helpers
37
+
38
+ NK_INTERNAL nk_f64_t nk_angular_normalize_f64_loongsonasx_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {
39
+ if (a2 == 0 && b2 == 0) return 0;
40
+ else if (ab == 0) return 1;
41
+ nk_f64_t result = 1 - ab / (nk_f64_sqrt_loongsonasx(a2) * nk_f64_sqrt_loongsonasx(b2));
42
+ return result > 0 ? result : 0;
43
+ }
44
+
45
+ NK_INTERNAL nk_f32_t nk_angular_normalize_i32_loongsonasx_(nk_i32_t ab, nk_i32_t a2, nk_i32_t b2) {
46
+ if (a2 == 0 && b2 == 0) return 0;
47
+ else if (ab == 0) return 1;
48
+ nk_f32_t result = 1.0f -
49
+ (nk_f32_t)ab * nk_f32_rsqrt_loongsonasx((nk_f32_t)a2) * nk_f32_rsqrt_loongsonasx((nk_f32_t)b2);
50
+ return result > 0 ? result : 0;
51
+ }
52
+
53
+ #pragma endregion Angular Normalize Helpers
54
+
55
+ #pragma region I8 and U8 Integers
56
+
57
+ NK_PUBLIC void nk_sqeuclidean_i8_loongsonasx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
58
+ __m256i sum_i32x8 = __lasx_xvreplgr2vr_w(0);
59
+ nk_size_t i = 0;
60
+ for (; i + 32 <= n; i += 32) {
61
+ __m256i a_i8x32 = __lasx_xvld(a + i, 0);
62
+ __m256i b_i8x32 = __lasx_xvld(b + i, 0);
63
+ __m256i diff_i8x32 = __lasx_xvsub_b(a_i8x32, b_i8x32);
64
+ __m256i sq_i16x16 = __lasx_xvreplgr2vr_h(0);
65
+ sq_i16x16 = __lasx_xvmaddwev_h_b(sq_i16x16, diff_i8x32, diff_i8x32);
66
+ sq_i16x16 = __lasx_xvmaddwod_h_b(sq_i16x16, diff_i8x32, diff_i8x32);
67
+ sum_i32x8 = __lasx_xvadd_w(sum_i32x8, __lasx_xvhaddw_w_h(sq_i16x16, sq_i16x16));
68
+ }
69
+ nk_i32_t sum = nk_reduce_add_i32x8_loongsonasx_(sum_i32x8);
70
+ for (; i < n; ++i) {
71
+ nk_i32_t diff = (nk_i32_t)a[i] - b[i];
72
+ sum += diff * diff;
73
+ }
74
+ *result = (nk_u32_t)sum;
75
+ }
76
+
77
+ NK_PUBLIC void nk_euclidean_i8_loongsonasx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
78
+ nk_u32_t distance_sq_u32;
79
+ nk_sqeuclidean_i8_loongsonasx(a, b, n, &distance_sq_u32);
80
+ *result = nk_f32_sqrt_loongsonasx((nk_f32_t)distance_sq_u32);
81
+ }
82
+
83
+ NK_PUBLIC void nk_angular_i8_loongsonasx(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
84
+ __m256i dot_i32x8 = __lasx_xvreplgr2vr_w(0);
85
+ __m256i a_sq_i32x8 = __lasx_xvreplgr2vr_w(0);
86
+ __m256i b_sq_i32x8 = __lasx_xvreplgr2vr_w(0);
87
+ nk_size_t i = 0;
88
+ for (; i + 32 <= n; i += 32) {
89
+ __m256i a_i8x32 = __lasx_xvld(a + i, 0);
90
+ __m256i b_i8x32 = __lasx_xvld(b + i, 0);
91
+ // dot(a, b)
92
+ __m256i ab_i16x16 = __lasx_xvreplgr2vr_h(0);
93
+ ab_i16x16 = __lasx_xvmaddwev_h_b(ab_i16x16, a_i8x32, b_i8x32);
94
+ ab_i16x16 = __lasx_xvmaddwod_h_b(ab_i16x16, a_i8x32, b_i8x32);
95
+ dot_i32x8 = __lasx_xvadd_w(dot_i32x8, __lasx_xvhaddw_w_h(ab_i16x16, ab_i16x16));
96
+ // norm_sq(a)
97
+ __m256i aa_i16x16 = __lasx_xvreplgr2vr_h(0);
98
+ aa_i16x16 = __lasx_xvmaddwev_h_b(aa_i16x16, a_i8x32, a_i8x32);
99
+ aa_i16x16 = __lasx_xvmaddwod_h_b(aa_i16x16, a_i8x32, a_i8x32);
100
+ a_sq_i32x8 = __lasx_xvadd_w(a_sq_i32x8, __lasx_xvhaddw_w_h(aa_i16x16, aa_i16x16));
101
+ // norm_sq(b)
102
+ __m256i bb_i16x16 = __lasx_xvreplgr2vr_h(0);
103
+ bb_i16x16 = __lasx_xvmaddwev_h_b(bb_i16x16, b_i8x32, b_i8x32);
104
+ bb_i16x16 = __lasx_xvmaddwod_h_b(bb_i16x16, b_i8x32, b_i8x32);
105
+ b_sq_i32x8 = __lasx_xvadd_w(b_sq_i32x8, __lasx_xvhaddw_w_h(bb_i16x16, bb_i16x16));
106
+ }
107
+ nk_i32_t dot = nk_reduce_add_i32x8_loongsonasx_(dot_i32x8);
108
+ nk_i32_t a_sq = nk_reduce_add_i32x8_loongsonasx_(a_sq_i32x8);
109
+ nk_i32_t b_sq = nk_reduce_add_i32x8_loongsonasx_(b_sq_i32x8);
110
+ for (; i < n; ++i) {
111
+ nk_i32_t a_val = a[i], b_val = b[i];
112
+ dot += a_val * b_val;
113
+ a_sq += a_val * a_val;
114
+ b_sq += b_val * b_val;
115
+ }
116
+ *result = nk_angular_normalize_i32_loongsonasx_(dot, a_sq, b_sq);
117
+ }
118
+
119
+ NK_PUBLIC void nk_sqeuclidean_u8_loongsonasx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
120
+ __m256i sum_i32x8 = __lasx_xvreplgr2vr_w(0);
121
+ __m256i zeros_i8x32 = __lasx_xvreplgr2vr_b(0);
122
+ nk_size_t i = 0;
123
+ for (; i + 32 <= n; i += 32) {
124
+ __m256i a_u8x32 = __lasx_xvld(a + i, 0);
125
+ __m256i b_u8x32 = __lasx_xvld(b + i, 0);
126
+ __m256i a_low_u16x16 = __lasx_xvilvl_b(zeros_i8x32, a_u8x32);
127
+ __m256i a_high_u16x16 = __lasx_xvilvh_b(zeros_i8x32, a_u8x32);
128
+ __m256i b_low_u16x16 = __lasx_xvilvl_b(zeros_i8x32, b_u8x32);
129
+ __m256i b_high_u16x16 = __lasx_xvilvh_b(zeros_i8x32, b_u8x32);
130
+ __m256i diff_low_i16x16 = __lasx_xvsub_h(a_low_u16x16, b_low_u16x16);
131
+ __m256i diff_high_i16x16 = __lasx_xvsub_h(a_high_u16x16, b_high_u16x16);
132
+ __m256i sq_ev_low_i32x8 = __lasx_xvmulwev_w_h(diff_low_i16x16, diff_low_i16x16);
133
+ __m256i sq_od_low_i32x8 = __lasx_xvmulwod_w_h(diff_low_i16x16, diff_low_i16x16);
134
+ __m256i sq_ev_high_i32x8 = __lasx_xvmulwev_w_h(diff_high_i16x16, diff_high_i16x16);
135
+ __m256i sq_od_high_i32x8 = __lasx_xvmulwod_w_h(diff_high_i16x16, diff_high_i16x16);
136
+ sum_i32x8 = __lasx_xvadd_w(sum_i32x8, __lasx_xvadd_w(sq_ev_low_i32x8, sq_od_low_i32x8));
137
+ sum_i32x8 = __lasx_xvadd_w(sum_i32x8, __lasx_xvadd_w(sq_ev_high_i32x8, sq_od_high_i32x8));
138
+ }
139
+ nk_i32_t sum = nk_reduce_add_i32x8_loongsonasx_(sum_i32x8);
140
+ for (; i < n; ++i) {
141
+ nk_i32_t diff = (nk_i32_t)a[i] - b[i];
142
+ sum += diff * diff;
143
+ }
144
+ *result = (nk_u32_t)sum;
145
+ }
146
+
147
+ NK_PUBLIC void nk_euclidean_u8_loongsonasx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
148
+ nk_u32_t distance_sq_u32;
149
+ nk_sqeuclidean_u8_loongsonasx(a, b, n, &distance_sq_u32);
150
+ *result = nk_f32_sqrt_loongsonasx((nk_f32_t)distance_sq_u32);
151
+ }
152
+
153
+ NK_PUBLIC void nk_angular_u8_loongsonasx(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
154
+ __m256i dot_i32x8 = __lasx_xvreplgr2vr_w(0);
155
+ __m256i a_sq_i32x8 = __lasx_xvreplgr2vr_w(0);
156
+ __m256i b_sq_i32x8 = __lasx_xvreplgr2vr_w(0);
157
+ __m256i zeros_i8x32 = __lasx_xvreplgr2vr_b(0);
158
+ nk_size_t i = 0;
159
+ for (; i + 32 <= n; i += 32) {
160
+ __m256i a_u8x32 = __lasx_xvld(a + i, 0);
161
+ __m256i b_u8x32 = __lasx_xvld(b + i, 0);
162
+ __m256i a_low_u16x16 = __lasx_xvilvl_b(zeros_i8x32, a_u8x32);
163
+ __m256i a_high_u16x16 = __lasx_xvilvh_b(zeros_i8x32, a_u8x32);
164
+ __m256i b_low_u16x16 = __lasx_xvilvl_b(zeros_i8x32, b_u8x32);
165
+ __m256i b_high_u16x16 = __lasx_xvilvh_b(zeros_i8x32, b_u8x32);
166
+ // dot(a, b)
167
+ __m256i ab_ev_low_i32x8 = __lasx_xvmulwev_w_h(a_low_u16x16, b_low_u16x16);
168
+ __m256i ab_od_low_i32x8 = __lasx_xvmulwod_w_h(a_low_u16x16, b_low_u16x16);
169
+ __m256i ab_ev_high_i32x8 = __lasx_xvmulwev_w_h(a_high_u16x16, b_high_u16x16);
170
+ __m256i ab_od_high_i32x8 = __lasx_xvmulwod_w_h(a_high_u16x16, b_high_u16x16);
171
+ dot_i32x8 = __lasx_xvadd_w(dot_i32x8, __lasx_xvadd_w(ab_ev_low_i32x8, ab_od_low_i32x8));
172
+ dot_i32x8 = __lasx_xvadd_w(dot_i32x8, __lasx_xvadd_w(ab_ev_high_i32x8, ab_od_high_i32x8));
173
+ // norm_sq(a)
174
+ __m256i aa_ev_low_i32x8 = __lasx_xvmulwev_w_h(a_low_u16x16, a_low_u16x16);
175
+ __m256i aa_od_low_i32x8 = __lasx_xvmulwod_w_h(a_low_u16x16, a_low_u16x16);
176
+ __m256i aa_ev_high_i32x8 = __lasx_xvmulwev_w_h(a_high_u16x16, a_high_u16x16);
177
+ __m256i aa_od_high_i32x8 = __lasx_xvmulwod_w_h(a_high_u16x16, a_high_u16x16);
178
+ a_sq_i32x8 = __lasx_xvadd_w(a_sq_i32x8, __lasx_xvadd_w(aa_ev_low_i32x8, aa_od_low_i32x8));
179
+ a_sq_i32x8 = __lasx_xvadd_w(a_sq_i32x8, __lasx_xvadd_w(aa_ev_high_i32x8, aa_od_high_i32x8));
180
+ // norm_sq(b)
181
+ __m256i bb_ev_low_i32x8 = __lasx_xvmulwev_w_h(b_low_u16x16, b_low_u16x16);
182
+ __m256i bb_od_low_i32x8 = __lasx_xvmulwod_w_h(b_low_u16x16, b_low_u16x16);
183
+ __m256i bb_ev_high_i32x8 = __lasx_xvmulwev_w_h(b_high_u16x16, b_high_u16x16);
184
+ __m256i bb_od_high_i32x8 = __lasx_xvmulwod_w_h(b_high_u16x16, b_high_u16x16);
185
+ b_sq_i32x8 = __lasx_xvadd_w(b_sq_i32x8, __lasx_xvadd_w(bb_ev_low_i32x8, bb_od_low_i32x8));
186
+ b_sq_i32x8 = __lasx_xvadd_w(b_sq_i32x8, __lasx_xvadd_w(bb_ev_high_i32x8, bb_od_high_i32x8));
187
+ }
188
+ nk_i32_t dot = nk_reduce_add_i32x8_loongsonasx_(dot_i32x8);
189
+ nk_i32_t a_sq = nk_reduce_add_i32x8_loongsonasx_(a_sq_i32x8);
190
+ nk_i32_t b_sq = nk_reduce_add_i32x8_loongsonasx_(b_sq_i32x8);
191
+ for (; i < n; ++i) {
192
+ nk_i32_t a_val = a[i], b_val = b[i];
193
+ dot += a_val * b_val;
194
+ a_sq += a_val * a_val;
195
+ b_sq += b_val * b_val;
196
+ }
197
+ *result = nk_angular_normalize_i32_loongsonasx_(dot, a_sq, b_sq);
198
+ }
199
+
200
+ #pragma endregion I8 and U8 Integers
201
+
202
+ #pragma region F32 and F64 Floats
203
+
204
+ NK_PUBLIC void nk_sqeuclidean_f32_loongsonasx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
205
+ __m256d sum_f64x4_low = (__m256d)__lasx_xvreplgr2vr_d(0);
206
+ __m256d sum_f64x4_high = (__m256d)__lasx_xvreplgr2vr_d(0);
207
+ nk_size_t i = 0;
208
+ for (; i + 8 <= n; i += 8) {
209
+ __m256i a_f32x8 = __lasx_xvld(a + i, 0);
210
+ __m256i b_f32x8 = __lasx_xvld(b + i, 0);
211
+ __m256d a_low_f64x4 = __lasx_xvfcvtl_d_s((__m256)a_f32x8);
212
+ __m256d b_low_f64x4 = __lasx_xvfcvtl_d_s((__m256)b_f32x8);
213
+ __m256d a_high_f64x4 = __lasx_xvfcvth_d_s((__m256)a_f32x8);
214
+ __m256d b_high_f64x4 = __lasx_xvfcvth_d_s((__m256)b_f32x8);
215
+ __m256d diff_low_f64x4 = __lasx_xvfsub_d(a_low_f64x4, b_low_f64x4);
216
+ __m256d diff_high_f64x4 = __lasx_xvfsub_d(a_high_f64x4, b_high_f64x4);
217
+ sum_f64x4_low = __lasx_xvfmadd_d(diff_low_f64x4, diff_low_f64x4, sum_f64x4_low);
218
+ sum_f64x4_high = __lasx_xvfmadd_d(diff_high_f64x4, diff_high_f64x4, sum_f64x4_high);
219
+ }
220
+ __m256d combined_f64x4 = __lasx_xvfadd_d(sum_f64x4_low, sum_f64x4_high);
221
+ nk_f64_t sum = nk_reduce_add_f64x4_loongsonasx_(combined_f64x4);
222
+ for (; i < n; ++i) {
223
+ nk_f64_t diff = (nk_f64_t)a[i] - b[i];
224
+ sum += diff * diff;
225
+ }
226
+ *result = sum;
227
+ }
228
+
229
+ NK_PUBLIC void nk_euclidean_f32_loongsonasx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
230
+ nk_sqeuclidean_f32_loongsonasx(a, b, n, result);
231
+ *result = nk_f64_sqrt_loongsonasx(*result);
232
+ }
233
+
234
+ NK_PUBLIC void nk_angular_f32_loongsonasx(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
235
+ __m256d dot_f64x4_low = (__m256d)__lasx_xvreplgr2vr_d(0);
236
+ __m256d dot_f64x4_high = (__m256d)__lasx_xvreplgr2vr_d(0);
237
+ __m256d a_sq_f64x4_low = (__m256d)__lasx_xvreplgr2vr_d(0);
238
+ __m256d a_sq_f64x4_high = (__m256d)__lasx_xvreplgr2vr_d(0);
239
+ __m256d b_sq_f64x4_low = (__m256d)__lasx_xvreplgr2vr_d(0);
240
+ __m256d b_sq_f64x4_high = (__m256d)__lasx_xvreplgr2vr_d(0);
241
+ nk_size_t i = 0;
242
+ for (; i + 8 <= n; i += 8) {
243
+ __m256i a_f32x8 = __lasx_xvld(a + i, 0);
244
+ __m256i b_f32x8 = __lasx_xvld(b + i, 0);
245
+ __m256d a_low_f64x4 = __lasx_xvfcvtl_d_s((__m256)a_f32x8);
246
+ __m256d b_low_f64x4 = __lasx_xvfcvtl_d_s((__m256)b_f32x8);
247
+ __m256d a_high_f64x4 = __lasx_xvfcvth_d_s((__m256)a_f32x8);
248
+ __m256d b_high_f64x4 = __lasx_xvfcvth_d_s((__m256)b_f32x8);
249
+ dot_f64x4_low = __lasx_xvfmadd_d(a_low_f64x4, b_low_f64x4, dot_f64x4_low);
250
+ dot_f64x4_high = __lasx_xvfmadd_d(a_high_f64x4, b_high_f64x4, dot_f64x4_high);
251
+ a_sq_f64x4_low = __lasx_xvfmadd_d(a_low_f64x4, a_low_f64x4, a_sq_f64x4_low);
252
+ a_sq_f64x4_high = __lasx_xvfmadd_d(a_high_f64x4, a_high_f64x4, a_sq_f64x4_high);
253
+ b_sq_f64x4_low = __lasx_xvfmadd_d(b_low_f64x4, b_low_f64x4, b_sq_f64x4_low);
254
+ b_sq_f64x4_high = __lasx_xvfmadd_d(b_high_f64x4, b_high_f64x4, b_sq_f64x4_high);
255
+ }
256
+ nk_f64_t dot = nk_reduce_add_f64x4_loongsonasx_(__lasx_xvfadd_d(dot_f64x4_low, dot_f64x4_high));
257
+ nk_f64_t a_sq = nk_reduce_add_f64x4_loongsonasx_(__lasx_xvfadd_d(a_sq_f64x4_low, a_sq_f64x4_high));
258
+ nk_f64_t b_sq = nk_reduce_add_f64x4_loongsonasx_(__lasx_xvfadd_d(b_sq_f64x4_low, b_sq_f64x4_high));
259
+ for (; i < n; ++i) {
260
+ nk_f64_t a_val = a[i], b_val = b[i];
261
+ dot += a_val * b_val;
262
+ a_sq += a_val * a_val;
263
+ b_sq += b_val * b_val;
264
+ }
265
+ *result = nk_angular_normalize_f64_loongsonasx_(dot, a_sq, b_sq);
266
+ }
267
+
268
+ NK_PUBLIC void nk_sqeuclidean_f64_loongsonasx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
269
+ __m256d sum_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
270
+ nk_size_t i = 0;
271
+ for (; i + 4 <= n; i += 4) {
272
+ __m256d a_f64x4 = (__m256d)__lasx_xvld(a + i, 0);
273
+ __m256d b_f64x4 = (__m256d)__lasx_xvld(b + i, 0);
274
+ __m256d diff_f64x4 = __lasx_xvfsub_d(a_f64x4, b_f64x4);
275
+ sum_f64x4 = __lasx_xvfmadd_d(diff_f64x4, diff_f64x4, sum_f64x4);
276
+ }
277
+ nk_f64_t sum = nk_reduce_add_f64x4_loongsonasx_(sum_f64x4);
278
+ for (; i < n; ++i) {
279
+ nk_f64_t diff = a[i] - b[i];
280
+ sum += diff * diff;
281
+ }
282
+ *result = sum;
283
+ }
284
+
285
+ NK_PUBLIC void nk_euclidean_f64_loongsonasx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
286
+ nk_sqeuclidean_f64_loongsonasx(a, b, n, result);
287
+ *result = nk_f64_sqrt_loongsonasx(*result);
288
+ }
289
+
290
+ NK_PUBLIC void nk_angular_f64_loongsonasx(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
291
+ __m256d dot_sum_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
292
+ __m256d dot_compensation_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
293
+ __m256d a_norm_sq_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
294
+ __m256d b_norm_sq_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
295
+ nk_size_t i = 0;
296
+ for (; i + 4 <= n; i += 4) {
297
+ __m256d a_f64x4 = (__m256d)__lasx_xvld(a + i, 0);
298
+ __m256d b_f64x4 = (__m256d)__lasx_xvld(b + i, 0);
299
+
300
+ __m256d product_f64x4 = __lasx_xvfmul_d(a_f64x4, b_f64x4);
301
+ __m256d product_error_f64x4 = __lasx_xvfmsub_d(a_f64x4, b_f64x4, product_f64x4);
302
+
303
+ __m256d tentative_sum_f64x4 = __lasx_xvfadd_d(dot_sum_f64x4, product_f64x4);
304
+ __m256d virtual_addend_f64x4 = __lasx_xvfsub_d(tentative_sum_f64x4, dot_sum_f64x4);
305
+ __m256d sum_error_f64x4 = __lasx_xvfadd_d(
306
+ __lasx_xvfsub_d(dot_sum_f64x4, __lasx_xvfsub_d(tentative_sum_f64x4, virtual_addend_f64x4)),
307
+ __lasx_xvfsub_d(product_f64x4, virtual_addend_f64x4));
308
+
309
+ dot_sum_f64x4 = tentative_sum_f64x4;
310
+ dot_compensation_f64x4 = __lasx_xvfadd_d(dot_compensation_f64x4,
311
+ __lasx_xvfadd_d(sum_error_f64x4, product_error_f64x4));
312
+
313
+ a_norm_sq_f64x4 = __lasx_xvfmadd_d(a_f64x4, a_f64x4, a_norm_sq_f64x4);
314
+ b_norm_sq_f64x4 = __lasx_xvfmadd_d(b_f64x4, b_f64x4, b_norm_sq_f64x4);
315
+ }
316
+
317
+ nk_f64_t dot = nk_dot_stable_sum_f64x4_loongsonasx_(dot_sum_f64x4, dot_compensation_f64x4);
318
+ nk_f64_t a_sq = nk_reduce_add_f64x4_loongsonasx_(a_norm_sq_f64x4);
319
+ nk_f64_t b_sq = nk_reduce_add_f64x4_loongsonasx_(b_norm_sq_f64x4);
320
+ for (; i < n; ++i) {
321
+ nk_f64_t a_val = a[i], b_val = b[i];
322
+ dot += a_val * b_val;
323
+ a_sq += a_val * a_val;
324
+ b_sq += b_val * b_val;
325
+ }
326
+ *result = nk_angular_normalize_f64_loongsonasx_(dot, a_sq, b_sq);
327
+ }
328
+
329
+ #pragma endregion F32 and F64 Floats
330
+
331
+ #pragma region F16 and BF16 Floats
332
+
333
+ NK_INTERNAL nk_f32_t nk_angular_normalize_f32_loongsonasx_(nk_f32_t ab, nk_f32_t a2, nk_f32_t b2) {
334
+ if (a2 == 0.0f && b2 == 0.0f) return 0.0f;
335
+ else if (ab == 0.0f) return 1.0f;
336
+ nk_f32_t result = 1.0f - ab * nk_f32_rsqrt_loongsonasx(a2) * nk_f32_rsqrt_loongsonasx(b2);
337
+ return result > 0.0f ? result : 0.0f;
338
+ }
339
+
340
+ /** @brief Horizontal sum of 8 × f32 lanes in a 256-bit LASX register. */
341
+ NK_INTERNAL nk_f32_t nk_reduce_add_f32x8_loongsonasx_(__m256 sum_f32x8) {
342
+ // Add high 128-bit lane to low 128-bit lane
343
+ __m256 high_f32x4 = (__m256)__lasx_xvpermi_q((__m256i)sum_f32x8, (__m256i)sum_f32x8, 0x11);
344
+ __m256 sum_f32x4 = __lasx_xvfadd_s(sum_f32x8, high_f32x4);
345
+ __m256 swapped_f32x4 = (__m256)__lasx_xvshuf4i_w((__m256i)sum_f32x4, 0b01001110);
346
+ __m256 reduced_f32x4 = __lasx_xvfadd_s(sum_f32x4, swapped_f32x4);
347
+ __m256 swapped_f32x2 = (__m256)__lasx_xvshuf4i_w((__m256i)reduced_f32x4, 0b10110001);
348
+ __m256 reduced_f32x2 = __lasx_xvfadd_s(reduced_f32x4, swapped_f32x2);
349
+ nk_fui32_t c;
350
+ c.u = (nk_u32_t)__lasx_xvpickve2gr_w((__m256i)reduced_f32x2, 0);
351
+ return c.f;
352
+ }
353
+
354
+ NK_PUBLIC void nk_sqeuclidean_bf16_loongsonasx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
355
+ __m256 sum_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
356
+ __m256i mask_high_u32x8 = __lasx_xvreplgr2vr_w((int)0xFFFF0000);
357
+ nk_size_t i = 0;
358
+ for (; i + 16 <= n; i += 16) {
359
+ __m256i a_bf16x16 = __lasx_xvld(a + i, 0);
360
+ __m256i b_bf16x16 = __lasx_xvld(b + i, 0);
361
+ __m256 a_even_f32x8 = (__m256)__lasx_xvslli_w(a_bf16x16, 16);
362
+ __m256 b_even_f32x8 = (__m256)__lasx_xvslli_w(b_bf16x16, 16);
363
+ __m256 diff_even_f32x8 = __lasx_xvfsub_s(a_even_f32x8, b_even_f32x8);
364
+ sum_f32x8 = __lasx_xvfmadd_s(diff_even_f32x8, diff_even_f32x8, sum_f32x8);
365
+ __m256 a_odd_f32x8 = (__m256)__lasx_xvand_v(a_bf16x16, mask_high_u32x8);
366
+ __m256 b_odd_f32x8 = (__m256)__lasx_xvand_v(b_bf16x16, mask_high_u32x8);
367
+ __m256 diff_odd_f32x8 = __lasx_xvfsub_s(a_odd_f32x8, b_odd_f32x8);
368
+ sum_f32x8 = __lasx_xvfmadd_s(diff_odd_f32x8, diff_odd_f32x8, sum_f32x8);
369
+ }
370
+ nk_f32_t sum = nk_reduce_add_f32x8_loongsonasx_(sum_f32x8);
371
+ for (; i < n; ++i) {
372
+ nk_f32_t a_val, b_val;
373
+ nk_bf16_to_f32_serial(&a[i], &a_val);
374
+ nk_bf16_to_f32_serial(&b[i], &b_val);
375
+ nk_f32_t diff = a_val - b_val;
376
+ sum += diff * diff;
377
+ }
378
+ *result = sum;
379
+ }
380
+
381
+ NK_PUBLIC void nk_euclidean_bf16_loongsonasx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
382
+ nk_sqeuclidean_bf16_loongsonasx(a, b, n, result);
383
+ *result = nk_f32_sqrt_loongsonasx(*result);
384
+ }
385
+
386
+ NK_PUBLIC void nk_angular_bf16_loongsonasx(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
387
+ __m256 dot_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
388
+ __m256 a_sq_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
389
+ __m256 b_sq_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
390
+ __m256i mask_high_u32x8 = __lasx_xvreplgr2vr_w((int)0xFFFF0000);
391
+ nk_size_t i = 0;
392
+ for (; i + 16 <= n; i += 16) {
393
+ __m256i a_bf16x16 = __lasx_xvld(a + i, 0);
394
+ __m256i b_bf16x16 = __lasx_xvld(b + i, 0);
395
+ __m256 a_even_f32x8 = (__m256)__lasx_xvslli_w(a_bf16x16, 16);
396
+ __m256 b_even_f32x8 = (__m256)__lasx_xvslli_w(b_bf16x16, 16);
397
+ dot_f32x8 = __lasx_xvfmadd_s(a_even_f32x8, b_even_f32x8, dot_f32x8);
398
+ a_sq_f32x8 = __lasx_xvfmadd_s(a_even_f32x8, a_even_f32x8, a_sq_f32x8);
399
+ b_sq_f32x8 = __lasx_xvfmadd_s(b_even_f32x8, b_even_f32x8, b_sq_f32x8);
400
+ __m256 a_odd_f32x8 = (__m256)__lasx_xvand_v(a_bf16x16, mask_high_u32x8);
401
+ __m256 b_odd_f32x8 = (__m256)__lasx_xvand_v(b_bf16x16, mask_high_u32x8);
402
+ dot_f32x8 = __lasx_xvfmadd_s(a_odd_f32x8, b_odd_f32x8, dot_f32x8);
403
+ a_sq_f32x8 = __lasx_xvfmadd_s(a_odd_f32x8, a_odd_f32x8, a_sq_f32x8);
404
+ b_sq_f32x8 = __lasx_xvfmadd_s(b_odd_f32x8, b_odd_f32x8, b_sq_f32x8);
405
+ }
406
+ nk_f32_t dot = nk_reduce_add_f32x8_loongsonasx_(dot_f32x8);
407
+ nk_f32_t a_sq = nk_reduce_add_f32x8_loongsonasx_(a_sq_f32x8);
408
+ nk_f32_t b_sq = nk_reduce_add_f32x8_loongsonasx_(b_sq_f32x8);
409
+ for (; i < n; ++i) {
410
+ nk_f32_t a_val, b_val;
411
+ nk_bf16_to_f32_serial(&a[i], &a_val);
412
+ nk_bf16_to_f32_serial(&b[i], &b_val);
413
+ dot += a_val * b_val;
414
+ a_sq += a_val * a_val;
415
+ b_sq += b_val * b_val;
416
+ }
417
+ *result = nk_angular_normalize_f32_loongsonasx_(dot, a_sq, b_sq);
418
+ }
419
+
420
+ NK_PUBLIC void nk_sqeuclidean_f16_loongsonasx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
421
+ __m256 sum_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
422
+ nk_size_t i = 0;
423
+ for (; i + 8 <= n; i += 8) {
424
+ __m128i a_f16x8 = __lsx_vld(a + i, 0);
425
+ __m128i b_f16x8 = __lsx_vld(b + i, 0);
426
+ __m256 a_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(a_f16x8);
427
+ __m256 b_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(b_f16x8);
428
+ __m256 diff_f32x8 = __lasx_xvfsub_s(a_f32x8, b_f32x8);
429
+ sum_f32x8 = __lasx_xvfmadd_s(diff_f32x8, diff_f32x8, sum_f32x8);
430
+ }
431
+ nk_f32_t sum = nk_reduce_add_f32x8_loongsonasx_(sum_f32x8);
432
+ for (; i < n; ++i) {
433
+ nk_f32_t a_val, b_val;
434
+ nk_f16_to_f32_serial(&a[i], &a_val);
435
+ nk_f16_to_f32_serial(&b[i], &b_val);
436
+ nk_f32_t diff = a_val - b_val;
437
+ sum += diff * diff;
438
+ }
439
+ *result = sum;
440
+ }
441
+
442
+ NK_PUBLIC void nk_euclidean_f16_loongsonasx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
443
+ nk_sqeuclidean_f16_loongsonasx(a, b, n, result);
444
+ *result = nk_f32_sqrt_loongsonasx(*result);
445
+ }
446
+
447
+ NK_PUBLIC void nk_angular_f16_loongsonasx(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
448
+ __m256 dot_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
449
+ __m256 a_sq_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
450
+ __m256 b_sq_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0);
451
+ nk_size_t i = 0;
452
+ for (; i + 8 <= n; i += 8) {
453
+ __m128i a_f16x8 = __lsx_vld(a + i, 0);
454
+ __m128i b_f16x8 = __lsx_vld(b + i, 0);
455
+ __m256 a_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(a_f16x8);
456
+ __m256 b_f32x8 = (__m256)nk_f16x8_to_f32x8_loongsonasx_(b_f16x8);
457
+ dot_f32x8 = __lasx_xvfmadd_s(a_f32x8, b_f32x8, dot_f32x8);
458
+ a_sq_f32x8 = __lasx_xvfmadd_s(a_f32x8, a_f32x8, a_sq_f32x8);
459
+ b_sq_f32x8 = __lasx_xvfmadd_s(b_f32x8, b_f32x8, b_sq_f32x8);
460
+ }
461
+ nk_f32_t dot = nk_reduce_add_f32x8_loongsonasx_(dot_f32x8);
462
+ nk_f32_t a_sq = nk_reduce_add_f32x8_loongsonasx_(a_sq_f32x8);
463
+ nk_f32_t b_sq = nk_reduce_add_f32x8_loongsonasx_(b_sq_f32x8);
464
+ for (; i < n; ++i) {
465
+ nk_f32_t a_val, b_val;
466
+ nk_f16_to_f32_serial(&a[i], &a_val);
467
+ nk_f16_to_f32_serial(&b[i], &b_val);
468
+ dot += a_val * b_val;
469
+ a_sq += a_val * a_val;
470
+ b_sq += b_val * b_val;
471
+ }
472
+ *result = nk_angular_normalize_f32_loongsonasx_(dot, a_sq, b_sq);
473
+ }
474
+
475
+ #pragma endregion F16 and BF16 Floats
476
+
477
+ #if defined(__cplusplus)
478
+ } // extern "C"
479
+ #endif
480
+
481
+ #endif // NK_TARGET_LOONGSONASX
482
+ #endif // NK_TARGET_LOONGARCH_
483
+ #endif // NK_SPATIAL_LOONGSONASX_H