numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -6,11 +6,11 @@
6
6
  *
7
7
  * @section ice_cast_instructions AVX-512 VBMI2 Instructions
8
8
  *
9
- * Intrinsic Instruction Ice Genoa
10
- * _mm512_permutex2var_epi16 VPERMI2W (ZMM, ZMM, ZMM) 3cy @ p5 2cy @ p12
11
- * _mm512_test_epi16_mask VPTESTMW (k, ZMM, ZMM) 3cy @ p5 2cy @ p01
12
- * _mm512_mask_mov_epi16 VMOVDQU16 (ZMM{k}, ZMM) 1cy @ p05 1cy @ p05
13
- * _mm512_cvtepi16_epi8 VPMOVWB (YMM, ZMM) 3cy @ p5 2cy @ p12
9
+ * Intrinsic Instruction Icelake Genoa
10
+ * _mm512_permutex2var_epi16 VPERMI2W (ZMM, ZMM, ZMM) 3cy @ p5 2cy @ p12
11
+ * _mm512_test_epi16_mask VPTESTMW (k, ZMM, ZMM) 3cy @ p5 2cy @ p01
12
+ * _mm512_mask_mov_epi16 VMOVDQU16 (ZMM{k}, ZMM) 1cy @ p05 1cy @ p05
13
+ * _mm512_cvtepi16_epi8 VPMOVWB (YMM, ZMM) 3cy @ p5 2cy @ p12
14
14
  *
15
15
  * Ice Lake's AVX-512 VBMI2 enables efficient 128-entry LUT lookups via dual VPERMI2W operations.
16
16
  * FP8-to-BF16/F16 conversions use 4 ZMM LUT registers with VPTESTMW for range selection, achieving
@@ -37,7 +37,7 @@ extern "C" {
37
37
  #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "f16c", "fma", "bmi", "bmi2")
38
38
  #endif
39
39
 
40
- #pragma region - Vectorized Conversions
40
+ #pragma region Vectorized Conversions
41
41
 
42
42
  /** @brief Convert 32x e4m3 → 32x bf16 via arithmetic + 8-entry subnormal LUT (AVX-512BW).
43
43
  * E4M3 format: S EEEE MMM (bias=7). BF16: S EEEEEEEE MMMMMMM (bias=127).
@@ -72,7 +72,12 @@ NK_INTERNAL __m512i nk_e4m3x32_to_bf16x32_icelake_(__m256i e4m3x32) {
72
72
 
73
73
  // Apply sign: shift E4M3 bit 7 to BF16 bit 15
74
74
  sign_i16x32 = _mm512_slli_epi16(sign_i16x32, 8);
75
- return _mm512_or_si512(result_abs_i16x32, sign_i16x32);
75
+ __m512i result_i16x32 = _mm512_or_si512(result_abs_i16x32, sign_i16x32);
76
+
77
+ // NaN: E4M3FN has NaN only at magnitude 0x7F → BF16 quiet NaN (0x7FC0)
78
+ __mmask32 is_nan = _mm512_cmpeq_epi16_mask(lower7_i16x32, _mm512_set1_epi16(0x7F));
79
+ __m512i nan_i16x32 = _mm512_or_si512(sign_i16x32, _mm512_set1_epi16(0x7FC0));
80
+ return _mm512_mask_blend_epi16(is_nan, result_i16x32, nan_i16x32);
76
81
  }
77
82
 
78
83
  /** @brief Convert 32x e5m2 → 32x bf16 via arithmetic + 4-entry subnormal LUT (AVX-512BW).
@@ -268,14 +273,14 @@ NK_INTERNAL __m256i nk_bf16x32_to_e4m3x32_icelake_(__m512i bf16x32) {
268
273
  // bf16 to f32 is just left shift by 16
269
274
  __m512i bf16_low_i32x16 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(bf16x32));
270
275
  __m512i bf16_high_i32x16 = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(bf16x32, 1));
271
- __m512 f32_low = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
272
- __m512 f32_high = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
273
- __m512 abs_f32_low = _mm512_and_ps(f32_low, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
274
- __m512 abs_f32_high = _mm512_and_ps(f32_high, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
275
- __m512 scaled_low = _mm512_mul_ps(abs_f32_low, _mm512_set1_ps(512.0f));
276
- __m512 scaled_high = _mm512_mul_ps(abs_f32_high, _mm512_set1_ps(512.0f));
277
- __m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low);
278
- __m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high);
276
+ __m512 f32_low_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
277
+ __m512 f32_high_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
278
+ __m512 abs_f32_low_f32x16 = _mm512_and_ps(f32_low_f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
279
+ __m512 abs_f32_high_f32x16 = _mm512_and_ps(f32_high_f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
280
+ __m512 scaled_low_f32x16 = _mm512_mul_ps(abs_f32_low_f32x16, _mm512_set1_ps(512.0f));
281
+ __m512 scaled_high_f32x16 = _mm512_mul_ps(abs_f32_high_f32x16, _mm512_set1_ps(512.0f));
282
+ __m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low_f32x16);
283
+ __m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high_f32x16);
279
284
  __m256i subnorm_mant_low_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_low_i32x16);
280
285
  __m256i subnorm_mant_high_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_high_i32x16);
281
286
  __m512i subnorm_mantissa_i16x32 = _mm512_inserti64x4(_mm512_castsi256_si512(subnorm_mant_low_i16x16),
@@ -328,14 +333,14 @@ NK_INTERNAL __m256i nk_bf16x32_to_e5m2x32_icelake_(__m512i bf16x32) {
328
333
  // Subnormal path: compute via f32 to get correct rounding
329
334
  __m512i bf16_low_i32x16 = _mm512_cvtepu16_epi32(_mm512_castsi512_si256(bf16x32));
330
335
  __m512i bf16_high_i32x16 = _mm512_cvtepu16_epi32(_mm512_extracti32x8_epi32(bf16x32, 1));
331
- __m512 f32_low = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
332
- __m512 f32_high = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
333
- __m512 abs_f32_low = _mm512_and_ps(f32_low, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
334
- __m512 abs_f32_high = _mm512_and_ps(f32_high, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
335
- __m512 scaled_low = _mm512_mul_ps(abs_f32_low, _mm512_set1_ps(65536.0f));
336
- __m512 scaled_high = _mm512_mul_ps(abs_f32_high, _mm512_set1_ps(65536.0f));
337
- __m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low);
338
- __m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high);
336
+ __m512 f32_low_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_low_i32x16, 16));
337
+ __m512 f32_high_f32x16 = _mm512_castsi512_ps(_mm512_slli_epi32(bf16_high_i32x16, 16));
338
+ __m512 abs_f32_low_f32x16 = _mm512_and_ps(f32_low_f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
339
+ __m512 abs_f32_high_f32x16 = _mm512_and_ps(f32_high_f32x16, _mm512_castsi512_ps(_mm512_set1_epi32(0x7FFFFFFF)));
340
+ __m512 scaled_low_f32x16 = _mm512_mul_ps(abs_f32_low_f32x16, _mm512_set1_ps(65536.0f));
341
+ __m512 scaled_high_f32x16 = _mm512_mul_ps(abs_f32_high_f32x16, _mm512_set1_ps(65536.0f));
342
+ __m512i subnorm_mant_low_i32x16 = _mm512_cvtps_epi32(scaled_low_f32x16);
343
+ __m512i subnorm_mant_high_i32x16 = _mm512_cvtps_epi32(scaled_high_f32x16);
339
344
  __m256i subnorm_mant_low_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_low_i32x16);
340
345
  __m256i subnorm_mant_high_i16x16 = _mm512_cvtepi32_epi16(subnorm_mant_high_i32x16);
341
346
  __m512i subnorm_mantissa_i16x32 = _mm512_inserti64x4(_mm512_castsi256_si512(subnorm_mant_low_i16x16),
@@ -362,8 +367,8 @@ NK_INTERNAL void nk_load_e4m3x32_to_bf16x32_icelake_(void const *src, nk_b512_ve
362
367
  /** @brief Partial load n e4m3 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
363
368
  NK_INTERNAL void nk_partial_load_e4m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
364
369
  __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
365
- __m256i e4m3_partial = _mm256_maskz_loadu_epi8(mask, src);
366
- dst->zmm = nk_e4m3x32_to_bf16x32_icelake_(e4m3_partial);
370
+ __m256i e4m3_partial_i8x32 = _mm256_maskz_loadu_epi8(mask, src);
371
+ dst->zmm = nk_e4m3x32_to_bf16x32_icelake_(e4m3_partial_i8x32);
367
372
  }
368
373
 
369
374
  /** @brief Load 32x e5m2 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
@@ -374,8 +379,8 @@ NK_INTERNAL void nk_load_e5m2x32_to_bf16x32_icelake_(void const *src, nk_b512_ve
374
379
  /** @brief Partial load n e5m2 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
375
380
  NK_INTERNAL void nk_partial_load_e5m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
376
381
  __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
377
- __m256i e5m2_partial = _mm256_maskz_loadu_epi8(mask, src);
378
- dst->zmm = nk_e5m2x32_to_bf16x32_icelake_(e5m2_partial);
382
+ __m256i e5m2_partial_i8x32 = _mm256_maskz_loadu_epi8(mask, src);
383
+ dst->zmm = nk_e5m2x32_to_bf16x32_icelake_(e5m2_partial_i8x32);
379
384
  }
380
385
 
381
386
  /** @brief Load 32x e2m3 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
@@ -386,8 +391,8 @@ NK_INTERNAL void nk_load_e2m3x32_to_bf16x32_icelake_(void const *src, nk_b512_ve
386
391
  /** @brief Partial load n e2m3 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
387
392
  NK_INTERNAL void nk_partial_load_e2m3x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
388
393
  __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
389
- __m256i e2m3_partial = _mm256_maskz_loadu_epi8(mask, src);
390
- dst->zmm = nk_e2m3x32_to_bf16x32_icelake_(e2m3_partial);
394
+ __m256i e2m3_partial_i8x32 = _mm256_maskz_loadu_epi8(mask, src);
395
+ dst->zmm = nk_e2m3x32_to_bf16x32_icelake_(e2m3_partial_i8x32);
391
396
  }
392
397
 
393
398
  /** @brief Load 32x e3m2 from memory and convert to 32x bf16 (Ice Lake AVX-512BW). */
@@ -398,13 +403,13 @@ NK_INTERNAL void nk_load_e3m2x32_to_bf16x32_icelake_(void const *src, nk_b512_ve
398
403
  /** @brief Partial load n e3m2 elements from memory and convert to bf16 (Ice Lake AVX-512BW). */
399
404
  NK_INTERNAL void nk_partial_load_e3m2x32_to_bf16x32_icelake_(void const *src, nk_b512_vec_t *dst, nk_size_t n) {
400
405
  __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
401
- __m256i e3m2_partial = _mm256_maskz_loadu_epi8(mask, src);
402
- dst->zmm = nk_e3m2x32_to_bf16x32_icelake_(e3m2_partial);
406
+ __m256i e3m2_partial_i8x32 = _mm256_maskz_loadu_epi8(mask, src);
407
+ dst->zmm = nk_e3m2x32_to_bf16x32_icelake_(e3m2_partial_i8x32);
403
408
  }
404
409
 
405
- #pragma endregion - Vectorized Conversions
410
+ #pragma endregion Vectorized Conversions
406
411
 
407
- #pragma region - Public API
412
+ #pragma region Public API
408
413
 
409
414
  NK_PUBLIC void nk_cast_icelake(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
410
415
  // Group 1: Conversions to bf16 (e4m3 → bf16, e5m2 → bf16)
@@ -428,9 +433,9 @@ NK_PUBLIC void nk_cast_icelake(void const *from, nk_dtype_t from_type, nk_size_t
428
433
  for (nk_size_t idx = 0; idx < n; idx += 32) {
429
434
  nk_size_t remaining = n - idx;
430
435
  __mmask32 mask = (remaining >= 32) ? 0xFFFFFFFF : _bzhi_u32(0xFFFFFFFF, (unsigned)remaining);
431
- __m512i in_bf16x32 = _mm512_maskz_loadu_epi16(mask, from_ptr + idx);
432
- __m256i out_f8x32 = (to_type == nk_e4m3_k) ? nk_bf16x32_to_e4m3x32_icelake_(in_bf16x32)
433
- : nk_bf16x32_to_e5m2x32_icelake_(in_bf16x32);
436
+ __m512i in_bf16x32_i16x32 = _mm512_maskz_loadu_epi16(mask, from_ptr + idx);
437
+ __m256i out_f8x32 = (to_type == nk_e4m3_k) ? nk_bf16x32_to_e4m3x32_icelake_(in_bf16x32_i16x32)
438
+ : nk_bf16x32_to_e5m2x32_icelake_(in_bf16x32_i16x32);
434
439
  _mm256_mask_storeu_epi8(to_ptr + idx, mask, out_f8x32);
435
440
  }
436
441
  }
@@ -453,7 +458,7 @@ NK_PUBLIC void nk_cast_icelake(void const *from, nk_dtype_t from_type, nk_size_t
453
458
  else nk_cast_skylake(from, from_type, n, to, to_type);
454
459
  }
455
460
 
456
- #pragma endregion - Public API
461
+ #pragma endregion Public API
457
462
 
458
463
  #if defined(__clang__)
459
464
  #pragma clang attribute pop
@@ -0,0 +1,252 @@
1
+ /**
2
+ * @brief SIMD-accelerated Type Conversions and Load/Store Helpers for LoongArch LASX (256-bit).
3
+ * @file include/numkong/cast/loongsonasx.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/cast.h
8
+ *
9
+ * @section loongsonasx_cast_instructions Key LASX Load/Store Instructions
10
+ *
11
+ * Intrinsic Instruction Description
12
+ * __lasx_xvld(ptr, 0) XVLD 256-bit aligned/unaligned load
13
+ * __lasx_xvst(v, ptr, 0) XVST 256-bit aligned/unaligned store
14
+ * __lasx_xvreplgr2vr_w(bits) XVREPLGR2VR.W Broadcast i32 to 8 lanes
15
+ * __lasx_xvreplgr2vr_d(bits) XVREPLGR2VR.D Broadcast i64 to 4 lanes
16
+ * __lasx_xvffint_s_w(v) XVFFINT.S.W 4x i32 -> f32 (per 128-bit lane)
17
+ * __lasx_xvfrsqrt_s(v) XVFRSQRT.S f32 full-precision reciprocal sqrt
18
+ * __lasx_xvfsqrt_s(v) XVFSQRT.S f32 full-precision sqrt
19
+ * __lasx_xvfsqrt_d(v) XVFSQRT.D f64 full-precision sqrt
20
+ *
21
+ * LASX is a 256-bit extension; its vector registers are 256-bit `__m256i`. For 128-bit
22
+ * `nk_b128_vec_t` operations this header uses the LSX subset instead: `__lsx_vld` loads and
23
+ * `__lsx_vst` stores exactly 16 bytes, so nothing beyond the intended range is read or
24
+ * written. Partial loads/stores delegate to serial helpers since LASX lacks masked
25
+ * load/store instructions.
26
+ */
27
+ #ifndef NK_CAST_LOONGSONASX_H
28
+ #define NK_CAST_LOONGSONASX_H
29
+
30
+ #if NK_TARGET_LOONGARCH_
31
+ #if NK_TARGET_LOONGSONASX
32
+
33
+ #include "numkong/types.h"
34
+ #include "numkong/cast/serial.h" // `nk_partial_load_b16x8_serial_`
35
+ #include "numkong/scalar/loongsonasx.h" // `nk_xvreplgr2vr_s_128_`, `nk_xvfreplgr2vr_d_`
36
+
37
+ #if defined(__cplusplus)
38
+ extern "C" {
39
+ #endif
40
+
41
+ #pragma region Type Punned Loads and Stores
42
+
43
+ /**
44
+ * LSX and LASX share the same physical register file, so widening __m128i → __m256i and
45
+ * extracting __m256i → __m128i are no-ops on hardware. Empty inline asm with "f" constraints
46
+ * avoids the stack round-trip that union punning causes on GCC 14; the asm body emits no instructions.
47
+ * Named after x86 `_mm256_castsi128_si256` / `_mm256_castsi256_si128` / `_mm256_castps256_ps128`.
48
+ */
49
+ NK_INTERNAL __m256i nk_lasx_castsi128_si256_(__m128i low_i64x2) {
50
+ __m256i wide_i64x4; // never assigned in C; it is only written as the asm output operand
51
+ __asm__("" : "=f"(wide_i64x4) : "f"(low_i64x2)); // NOTE(review): input and output are not tied — confirm they always share a register
52
+ return wide_i64x4; // presumably only the low 128 bits carry defined data — TODO confirm
53
+ }
54
+ NK_INTERNAL __m128i nk_lasx_castsi256_si128_(__m256i wide_i64x4) { // 256-bit → 128-bit integer view of the same register
55
+ __m128i low_i64x2; // only written as the asm output operand
56
+ __asm__("" : "=f"(low_i64x2) : "f"(wide_i64x4)); // empty asm body; NOTE(review): operands are not tied — confirm register sharing
57
+ return low_i64x2; // intended to be the low 128 bits of `wide_i64x4`
58
+ }
59
+ NK_INTERNAL __m128 nk_lasx_castps256_ps128_(__m256 wide_f32x8) { // 256-bit → 128-bit float view of the same register
60
+ __m128 low_f32x4; // only written as the asm output operand
61
+ __asm__("" : "=f"(low_f32x4) : "f"(wide_f32x8)); // empty asm body; NOTE(review): operands are not tied — confirm register sharing
62
+ return low_f32x4; // intended to be the low 4 floats of `wide_f32x8`
63
+ }
64
+
65
+ /** @brief Type-agnostic 256-bit full load (LASX): reads 32 bytes at `src` into `dst->ymm` with a single `__lasx_xvld` at offset 0. */
66
+ NK_INTERNAL void nk_load_b256_loongsonasx_(void const *src, nk_b256_vec_t *dst) { dst->ymm = __lasx_xvld(src, 0); }
67
+
68
+ /** @brief Type-agnostic 256-bit full store (LASX): writes the 32 bytes of `src->ymm` to `dst` with a single `__lasx_xvst` at offset 0. */
69
+ NK_INTERNAL void nk_store_b256_loongsonasx_(nk_b256_vec_t const *src, void *dst) { __lasx_xvst(src->ymm, dst, 0); }
70
+
71
+ /** @brief Type-agnostic 128-bit full load (LSX subset of LASX): reads 16 bytes at `src` into `dst->xmm` with `__lsx_vld` at offset 0. */
72
+ NK_INTERNAL void nk_load_b128_loongsonasx_(void const *src, nk_b128_vec_t *dst) { dst->xmm = __lsx_vld(src, 0); }
73
+
74
+ /** @brief Type-agnostic 128-bit full store (LSX subset of LASX): writes the 16 bytes of `src->xmm` to `dst` with `__lsx_vst` at offset 0. */
75
+ NK_INTERNAL void nk_store_b128_loongsonasx_(nk_b128_vec_t const *src, void *dst) { __lsx_vst(src->xmm, dst, 0); }
76
+
77
+ /** @brief Convert 8 × bf16 → 8 × f32 by interleaving with zero so bf16 lands in upper 16 bits (LASX).
78
+ * Takes the 8 bf16 bit patterns in `bf16_i16x8` and returns 8 f32 bit patterns as a `__m256i`.
79
+ * Duplicates the 128-bit input into both lanes, then `xvilvl_h(bf16, zero)` places each bf16
80
+ * value in the high 16 bits of a 32-bit slot — which is valid f32 with no shift needed.
81
+ * `xvpermi_q` combines the low-element and high-element halves into a single register.
82
+ */
83
+ NK_INTERNAL __m256i nk_bf16x8_to_f32x8_loongsonasx_(__m128i bf16_i16x8) {
84
+ __m256i duped_bf16x16 = __lasx_xvpermi_q(nk_lasx_castsi128_si256_(bf16_i16x8), nk_lasx_castsi128_si256_(bf16_i16x8),
85
+ 0x00); // selector 0x00: the same 128-bit half feeds both output lanes
86
+ __m256i zero_i16x16 = __lasx_xvreplgr2vr_h(0);
87
+ __m256i low_f32x8 = __lasx_xvilvl_h(duped_bf16x16, zero_i16x16); // interleaves the low halves of each 128-bit lane
88
+ __m256i high_f32x8 = __lasx_xvilvh_h(duped_bf16x16, zero_i16x16); // interleaves the high halves of each 128-bit lane
89
+ return __lasx_xvpermi_q(high_f32x8, low_f32x8, 0x20); // selector 0x20: one 128-bit half taken from each operand
90
+ }
91
+
92
+ /** @brief Load 8 × bf16 (16 bytes, via `__lsx_vld`) from `src`, convert to 8 × f32, store in `dst->ymm` (LASX). */
93
+ NK_INTERNAL void nk_load_bf16x8_to_f32x8_loongsonasx_(void const *src, nk_b256_vec_t *dst) {
94
+ dst->ymm = nk_bf16x8_to_f32x8_loongsonasx_(__lsx_vld(src, 0)); // always reads a full 16 bytes
95
+ }
96
+
97
+ /** @brief Partial load for bf16 elements (up to 8) with conversion to f32 (LASX): stages `n` elements through the serial helper, then converts all 8 lanes into `dst->ymm`. */
98
+ NK_INTERNAL void nk_partial_load_bf16x8_to_f32x8_loongsonasx_(nk_bf16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
99
+ nk_b128_vec_t vec;
100
+ nk_partial_load_b16x8_serial_(src, &vec, n); // presumably zero-fills lanes past `n` — TODO confirm against the serial helper
101
+ dst->ymm = nk_bf16x8_to_f32x8_loongsonasx_(vec.xmm);
102
+ }
103
+
104
+ /** @brief Convert 8 × f16 → 8 × f32 via native LASX hardware conversion: `xvfcvtl`/`xvfcvth` on a lane-duplicated input, merged by `xvpermi_q`. */
105
+ NK_INTERNAL __m256i nk_f16x8_to_f32x8_loongsonasx_(__m128i f16_i16x8) {
106
+ __m256i duped_f16x16 = __lasx_xvpermi_q(nk_lasx_castsi128_si256_(f16_i16x8), nk_lasx_castsi128_si256_(f16_i16x8),
107
+ 0x00); // selector 0x00: the same 128-bit half feeds both output lanes
108
+ __m256i low_f32x8 = (__m256i)__lasx_xvfcvtl_s_h(duped_f16x16); // widens the low f16 half of each 128-bit lane
109
+ __m256i high_f32x8 = (__m256i)__lasx_xvfcvth_s_h(duped_f16x16); // widens the high f16 half of each 128-bit lane
110
+ return __lasx_xvpermi_q(high_f32x8, low_f32x8, 0x20); // selector 0x20: one 128-bit half taken from each operand
111
+ }
112
+
113
+ /** @brief Load 8 × f16 (16 bytes, via `__lsx_vld`) from `src`, convert to 8 × f32 via native LASX conversion, store in `dst->ymm`. */
114
+ NK_INTERNAL void nk_load_f16x8_to_f32x8_loongsonasx_(void const *src, nk_b256_vec_t *dst) {
115
+ dst->ymm = nk_f16x8_to_f32x8_loongsonasx_(__lsx_vld(src, 0)); // always reads a full 16 bytes
116
+ }
117
+
118
+ /** @brief Partial load for f16 elements (up to 8) with conversion to f32 (LASX): stages `n` elements through the serial helper, then converts all 8 lanes into `dst->ymm`. */
119
+ NK_INTERNAL void nk_partial_load_f16x8_to_f32x8_loongsonasx_(nk_f16_t const *src, nk_b256_vec_t *dst, nk_size_t n) {
120
+ nk_b128_vec_t vec;
121
+ nk_partial_load_b16x8_serial_(src, &vec, n); // presumably zero-fills lanes past `n` — TODO confirm against the serial helper
122
+ dst->ymm = nk_f16x8_to_f32x8_loongsonasx_(vec.xmm);
123
+ }
124
+
125
+ #pragma endregion Type Punned Loads and Stores
126
+
127
+ #pragma region Vectorized From Dot Helpers
128
+
129
+ /** @brief Safe square root of 8 floats with zero-clamping for numerical stability (LASX 256-bit): returns √max(x, 0) per lane. */
130
+ NK_INTERNAL __m256 nk_sqrt_f32x8_loongsonasx_(__m256 x_f32x8) {
131
+ __m256 zero_f32x8 = (__m256)__lasx_xvreplgr2vr_w(0); // all-zero bit pattern reinterpreted as +0.0f in every lane
132
+ return __lasx_xvfsqrt_s(__lasx_xvfmax_s(x_f32x8, zero_f32x8)); // clamp first so negative inputs never reach the sqrt
133
+ }
134
+
135
+ /** @brief Safe square root of 4 floats with zero-clamping for numerical stability (LSX 128-bit): returns √max(x, 0) per lane. */
136
+ NK_INTERNAL __m128 nk_sqrt_f32x4_loongsonasx_(__m128 x_f32x4) {
137
+ __m128 zero_f32x4 = (__m128)__lsx_vreplgr2vr_w(0); // all-zero bit pattern reinterpreted as +0.0f in every lane
138
+ return __lsx_vfsqrt_s(__lsx_vfmax_s(x_f32x4, zero_f32x4)); // clamp first so negative inputs never reach the sqrt
139
+ }
140
+
141
+ /** @brief Angular from_dot: computes max(0, 1 − dot × rsqrt(query_sumsq × target_sumsq)) for 4 pairs (LSX 128-bit f32); `query_sumsq` is broadcast to all lanes. */
142
+ NK_INTERNAL void nk_angular_through_f32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
143
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
144
+ __m128 dots_f32x4 = dots.xmm_ps;
145
+ __m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_(query_sumsq);
146
+ __m128 products_f32x4 = __lsx_vfmul_s(query_sumsq_f32x4, target_sumsqs.xmm_ps);
147
+ __m128 rsqrt_f32x4 = __lsx_vfrsqrt_s(products_f32x4); // single reciprocal-sqrt instruction, no refinement step follows
148
+ __m128 normalized_f32x4 = __lsx_vfmul_s(dots_f32x4, rsqrt_f32x4);
149
+ __m128 one_f32x4 = nk_xvreplgr2vr_s_128_(1.0f);
150
+ __m128 angular_f32x4 = __lsx_vfsub_s(one_f32x4, normalized_f32x4);
151
+ __m128 zero_f32x4 = (__m128)__lsx_vreplgr2vr_w(0);
152
+ results->xmm_ps = __lsx_vfmax_s(angular_f32x4, zero_f32x4); // clamp slightly negative results to zero
153
+ }
154
+
155
+ /** @brief Euclidean from_dot: computes √max(0, query_sumsq + target_sumsq − 2 × dot) for 4 pairs (LSX 128-bit f32); `query_sumsq` is broadcast to all lanes. */
156
+ NK_INTERNAL void nk_euclidean_through_f32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
157
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
158
+ __m128 dots_f32x4 = dots.xmm_ps;
159
+ __m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_(query_sumsq);
160
+ __m128 sum_sq_f32x4 = __lsx_vfadd_s(query_sumsq_f32x4, target_sumsqs.xmm_ps);
161
+ __m128 two_f32x4 = nk_xvreplgr2vr_s_128_(2.0f);
162
+ // dist_sq = sum_sq − 2 × dots = -(2 × dots − sum_sq), i.e. one fused negate-multiply-subtract
163
+ __m128 dist_sq_f32x4 = __lsx_vfnmsub_s(two_f32x4, dots_f32x4, sum_sq_f32x4);
164
+ results->xmm_ps = nk_sqrt_f32x4_loongsonasx_(dist_sq_f32x4); // helper clamps at zero before the sqrt
165
+ }
166
+
167
+ /** @brief Angular from_dot for native f64: max(0, 1 − dot / √(query_sumsq × target_sumsq)) for 4 pairs (LASX 256-bit); uses sqrt + divide rather than rsqrt. */
168
+ NK_INTERNAL void nk_angular_through_f64_from_dot_loongsonasx_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
169
+ nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
170
+ __m256d dots_f64x4 = dots.ymm_pd;
171
+ __m256d query_sumsq_f64x4 = nk_xvfreplgr2vr_d_(query_sumsq);
172
+ __m256d products_f64x4 = __lasx_xvfmul_d(query_sumsq_f64x4, target_sumsqs.ymm_pd);
173
+ __m256d sqrt_products_f64x4 = __lasx_xvfsqrt_d(products_f64x4);
174
+ __m256d normalized_f64x4 = __lasx_xvfdiv_d(dots_f64x4, sqrt_products_f64x4);
175
+ __m256d one_f64x4 = nk_xvfreplgr2vr_d_(1.0);
176
+ __m256d angular_f64x4 = __lasx_xvfsub_d(one_f64x4, normalized_f64x4);
177
+ __m256d zero_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
178
+ results->ymm_pd = __lasx_xvfmax_d(angular_f64x4, zero_f64x4); // clamp slightly negative results to zero
179
+ }
180
+
181
+ /** @brief Euclidean from_dot for native f64: √max(0, query_sumsq + target_sumsq − 2 × dot) for 4 pairs (LASX 256-bit); `query_sumsq` is broadcast to all lanes. */
182
+ NK_INTERNAL void nk_euclidean_through_f64_from_dot_loongsonasx_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
183
+ nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
184
+ __m256d dots_f64x4 = dots.ymm_pd;
185
+ __m256d query_sumsq_f64x4 = nk_xvfreplgr2vr_d_(query_sumsq);
186
+ __m256d sum_sq_f64x4 = __lasx_xvfadd_d(query_sumsq_f64x4, target_sumsqs.ymm_pd);
187
+ __m256d two_f64x4 = nk_xvfreplgr2vr_d_(2.0);
188
+ // dist_sq = sum_sq − 2 × dots = -(2 × dots − sum_sq), i.e. one fused negate-multiply-subtract
189
+ __m256d dist_sq_f64x4 = __lasx_xvfnmsub_d(two_f64x4, dots_f64x4, sum_sq_f64x4);
190
+ __m256d zero_f64x4 = (__m256d)__lasx_xvreplgr2vr_d(0);
191
+ results->ymm_pd = __lasx_xvfsqrt_d(__lasx_xvfmax_d(dist_sq_f64x4, zero_f64x4)); // clamp at zero before the sqrt
192
+ }
193
+
194
+ /** @brief Angular from_dot for i32 accumulators: cast i32 → f32, multiply by hardware rsqrt (no refinement step), clamp at zero. 4 pairs (LSX 128-bit). */
195
+ NK_INTERNAL void nk_angular_through_i32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
196
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
197
+ __m128 dots_f32x4 = __lsx_vffint_s_w(dots.xmm); // signed i32 → f32 per lane
198
+ __m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_((nk_f32_t)query_sumsq);
199
+ __m128 products_f32x4 = __lsx_vfmul_s(query_sumsq_f32x4, __lsx_vffint_s_w(target_sumsqs.xmm));
200
+ __m128 rsqrt_f32x4 = __lsx_vfrsqrt_s(products_f32x4);
201
+ __m128 normalized_f32x4 = __lsx_vfmul_s(dots_f32x4, rsqrt_f32x4);
202
+ __m128 one_f32x4 = nk_xvreplgr2vr_s_128_(1.0f);
203
+ __m128 angular_f32x4 = __lsx_vfsub_s(one_f32x4, normalized_f32x4);
204
+ __m128 zero_f32x4 = (__m128)__lsx_vreplgr2vr_w(0);
205
+ results->xmm_ps = __lsx_vfmax_s(angular_f32x4, zero_f32x4); // clamp slightly negative results to zero
206
+ }
207
+
208
+ /** @brief Euclidean from_dot for i32 accumulators: cast i32 → f32, then √max(0, a² + b² − 2ab). 4 pairs (LSX 128-bit); results are f32. */
209
+ NK_INTERNAL void nk_euclidean_through_i32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
210
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
211
+ __m128 dots_f32x4 = __lsx_vffint_s_w(dots.xmm); // signed i32 → f32 per lane
212
+ __m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_((nk_f32_t)query_sumsq);
213
+ __m128 sum_sq_f32x4 = __lsx_vfadd_s(query_sumsq_f32x4, __lsx_vffint_s_w(target_sumsqs.xmm));
214
+ __m128 two_f32x4 = nk_xvreplgr2vr_s_128_(2.0f);
215
+ __m128 dist_sq_f32x4 = __lsx_vfnmsub_s(two_f32x4, dots_f32x4, sum_sq_f32x4); // sum_sq − 2 × dots in one fused op
216
+ results->xmm_ps = nk_sqrt_f32x4_loongsonasx_(dist_sq_f32x4); // helper clamps at zero before the sqrt
217
+ }
218
+
219
+ /** @brief Angular from_dot for u32 accumulators: unsigned cast u32 → f32, multiply by hardware rsqrt (no refinement step), clamp at zero. 4 pairs (LSX 128-bit). */
220
+ NK_INTERNAL void nk_angular_through_u32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
221
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
222
+ __m128 dots_f32x4 = __lsx_vffint_s_wu(dots.xmm); // unsigned conversion: lanes ≥ 2^31 must not turn negative
223
+ __m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_((nk_f32_t)query_sumsq);
224
+ __m128 products_f32x4 = __lsx_vfmul_s(query_sumsq_f32x4, __lsx_vffint_s_wu(target_sumsqs.xmm));
225
+ __m128 rsqrt_f32x4 = __lsx_vfrsqrt_s(products_f32x4);
226
+ __m128 normalized_f32x4 = __lsx_vfmul_s(dots_f32x4, rsqrt_f32x4);
227
+ __m128 one_f32x4 = nk_xvreplgr2vr_s_128_(1.0f);
228
+ __m128 angular_f32x4 = __lsx_vfsub_s(one_f32x4, normalized_f32x4);
229
+ __m128 zero_f32x4 = (__m128)__lsx_vreplgr2vr_w(0);
230
+ results->xmm_ps = __lsx_vfmax_s(angular_f32x4, zero_f32x4); // clamp slightly negative results to zero
231
+ }
232
+
233
+ /** @brief Euclidean from_dot for u32 accumulators: unsigned cast u32 → f32, then √max(0, a² + b² − 2ab). 4 pairs (LSX 128-bit); results are f32. */
234
+ NK_INTERNAL void nk_euclidean_through_u32_from_dot_loongsonasx_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
235
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
236
+ __m128 dots_f32x4 = __lsx_vffint_s_wu(dots.xmm); // unsigned conversion: lanes ≥ 2^31 must not turn negative
237
+ __m128 query_sumsq_f32x4 = nk_xvreplgr2vr_s_128_((nk_f32_t)query_sumsq);
238
+ __m128 sum_sq_f32x4 = __lsx_vfadd_s(query_sumsq_f32x4, __lsx_vffint_s_wu(target_sumsqs.xmm));
239
+ __m128 two_f32x4 = nk_xvreplgr2vr_s_128_(2.0f);
240
+ __m128 dist_sq_f32x4 = __lsx_vfnmsub_s(two_f32x4, dots_f32x4, sum_sq_f32x4); // sum_sq − 2 × dots in one fused op
241
+ results->xmm_ps = nk_sqrt_f32x4_loongsonasx_(dist_sq_f32x4); // helper clamps at zero before the sqrt
242
+ }
243
+
244
+ #pragma endregion Vectorized From Dot Helpers
245
+
246
+ #if defined(__cplusplus)
247
+ } // extern "C"
248
+ #endif
249
+
250
+ #endif // NK_TARGET_LOONGSONASX
251
+ #endif // NK_TARGET_LOONGARCH_
252
+ #endif // NK_CAST_LOONGSONASX_H