numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,15 +8,15 @@
8
8
  *
9
9
  * @section spatial_svehalf_instructions ARM SVE+FP16 Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
13
- * svsub_f16_x FSUB (Z.H, P/M, Z.H, Z.H) 3cy 2/cy
14
- * svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy 2/cy
15
- * svaddv_f16 FADDV (H, P, Z.H) 6cy 1/cy
16
- * svdupq_n_f16 DUP (Z.H, #imm) 1cy 2/cy
17
- * svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
18
- * svptrue_b16 PTRUE (P.H, pattern) 1cy 2/cy
19
- * svcnth CNTH (Xd) 1cy 2/cy
11
+ * Intrinsic Instruction V1
12
+ * svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
13
+ * svsub_f16_x FSUB (Z.H, P/M, Z.H, Z.H) 3cy @ 2p
14
+ * svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy @ 2p
15
+ * svaddv_f16 FADDV (H, P, Z.H) 6cy @ 1p
16
+ * svdupq_n_f16 DUP (Z.H, #imm) 1cy @ 2p
17
+ * svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy @ 1p
18
+ * svptrue_b16 PTRUE (P.H, pattern) 1cy @ 2p
19
+ * svcnth CNTH (Xd) 1cy @ 2p
20
20
  *
21
21
  * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
22
22
  * and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
@@ -52,14 +52,27 @@ NK_PUBLIC void nk_sqeuclidean_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const
52
52
  nk_f16_for_arm_simd_t const *a = (nk_f16_for_arm_simd_t const *)(a_enum);
53
53
  nk_f16_for_arm_simd_t const *b = (nk_f16_for_arm_simd_t const *)(b_enum);
54
54
  do {
55
- svbool_t predicate_f32x = svwhilelt_b32_u64(i, n);
56
- svfloat16_t a_f16x = svld1_f16(svwhilelt_b16_u64(i, n), a + i);
57
- svfloat16_t b_f16x = svld1_f16(svwhilelt_b16_u64(i, n), b + i);
58
- svfloat32_t a_f32x = svcvt_f32_f16_x(predicate_f32x, a_f16x);
59
- svfloat32_t b_f32x = svcvt_f32_f16_x(predicate_f32x, b_f16x);
60
- svfloat32_t diff_f32x = svsub_f32_x(predicate_f32x, a_f32x, b_f32x);
61
- d2_f32x = svmla_f32_x(predicate_f32x, d2_f32x, diff_f32x, diff_f32x);
62
- i += svcntw();
55
+ svbool_t predicate_b16x = svwhilelt_b16_u64(i, n);
56
+ svfloat16_t a_f16x = svld1_f16(predicate_b16x, a + i);
57
+ svfloat16_t b_f16x = svld1_f16(predicate_b16x, b + i);
58
+ nk_size_t remaining = n - i < svcnth() ? n - i : svcnth();
59
+
60
+ // SVE `svcvt_f32_f16_x` widens only even-indexed f16 elements (0, 2, 4, ...),
61
+ // so we need two passes: one on the original vector (even elements) and one on
62
+ // a vector shifted by one position via `svext` (odd elements become even).
63
+ svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
64
+ svfloat32_t a_even_f32x = svcvt_f32_f16_x(pred_even_b32x, a_f16x);
65
+ svfloat32_t b_even_f32x = svcvt_f32_f16_x(pred_even_b32x, b_f16x);
66
+ svfloat32_t diff_even_f32x = svsub_f32_x(pred_even_b32x, a_even_f32x, b_even_f32x);
67
+ d2_f32x = svmla_f32_m(pred_even_b32x, d2_f32x, diff_even_f32x, diff_even_f32x);
68
+
69
+ svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
70
+ svfloat32_t a_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(a_f16x, a_f16x, 1));
71
+ svfloat32_t b_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(b_f16x, b_f16x, 1));
72
+ svfloat32_t diff_odd_f32x = svsub_f32_x(pred_odd_b32x, a_odd_f32x, b_odd_f32x);
73
+ d2_f32x = svmla_f32_m(pred_odd_b32x, d2_f32x, diff_odd_f32x, diff_odd_f32x);
74
+
75
+ i += svcnth();
63
76
  } while (i < n);
64
77
  *result = svaddv_f32(svptrue_b32(), d2_f32x);
65
78
  }
@@ -77,15 +90,28 @@ NK_PUBLIC void nk_angular_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const *b_
77
90
  nk_f16_for_arm_simd_t const *a = (nk_f16_for_arm_simd_t const *)(a_enum);
78
91
  nk_f16_for_arm_simd_t const *b = (nk_f16_for_arm_simd_t const *)(b_enum);
79
92
  do {
80
- svbool_t predicate_f32x = svwhilelt_b32_u64(i, n);
81
- svfloat16_t a_f16x = svld1_f16(svwhilelt_b16_u64(i, n), a + i);
82
- svfloat16_t b_f16x = svld1_f16(svwhilelt_b16_u64(i, n), b + i);
83
- svfloat32_t a_f32x = svcvt_f32_f16_x(predicate_f32x, a_f16x);
84
- svfloat32_t b_f32x = svcvt_f32_f16_x(predicate_f32x, b_f16x);
85
- ab_f32x = svmla_f32_x(predicate_f32x, ab_f32x, a_f32x, b_f32x);
86
- a2_f32x = svmla_f32_x(predicate_f32x, a2_f32x, a_f32x, a_f32x);
87
- b2_f32x = svmla_f32_x(predicate_f32x, b2_f32x, b_f32x, b_f32x);
88
- i += svcntw();
93
+ svbool_t predicate_b16x = svwhilelt_b16_u64(i, n);
94
+ svfloat16_t a_f16x = svld1_f16(predicate_b16x, a + i);
95
+ svfloat16_t b_f16x = svld1_f16(predicate_b16x, b + i);
96
+ nk_size_t remaining = n - i < svcnth() ? n - i : svcnth();
97
+
98
+ // Even-indexed f16 elements (0, 2, 4, ...)
99
+ svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
100
+ svfloat32_t a_even_f32x = svcvt_f32_f16_x(pred_even_b32x, a_f16x);
101
+ svfloat32_t b_even_f32x = svcvt_f32_f16_x(pred_even_b32x, b_f16x);
102
+ ab_f32x = svmla_f32_m(pred_even_b32x, ab_f32x, a_even_f32x, b_even_f32x);
103
+ a2_f32x = svmla_f32_m(pred_even_b32x, a2_f32x, a_even_f32x, a_even_f32x);
104
+ b2_f32x = svmla_f32_m(pred_even_b32x, b2_f32x, b_even_f32x, b_even_f32x);
105
+
106
+ // Odd-indexed f16 elements (1, 3, 5, ...) via svext shift-by-1
107
+ svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
108
+ svfloat32_t a_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(a_f16x, a_f16x, 1));
109
+ svfloat32_t b_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(b_f16x, b_f16x, 1));
110
+ ab_f32x = svmla_f32_m(pred_odd_b32x, ab_f32x, a_odd_f32x, b_odd_f32x);
111
+ a2_f32x = svmla_f32_m(pred_odd_b32x, a2_f32x, a_odd_f32x, a_odd_f32x);
112
+ b2_f32x = svmla_f32_m(pred_odd_b32x, b2_f32x, b_odd_f32x, b_odd_f32x);
113
+
114
+ i += svcnth();
89
115
  } while (i < n);
90
116
 
91
117
  nk_f32_t ab_f32 = svaddv_f32(svptrue_b32(), ab_f32x);
@@ -0,0 +1,142 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for SVE SDOT.
3
+ * @file include/numkong/spatial/svesdot.h
4
+ * @author Ash Vardanian
5
+ * @date April 3, 2026
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_svesdot_instructions ARM SVE+DotProd Instructions
10
+ *
11
+ * Intrinsic Instruction V1
12
+ * svld1_s8 LD1B (Z.B, P/Z, [Xn]) 4-6cy @ 2p
13
+ * svld1_u8 LD1B (Z.B, P/Z, [Xn]) 4-6cy @ 2p
14
+ * svdot_s32 SDOT (Z.S, Z.B, Z.B) 3cy @ 2p
15
+ * svdot_u32 UDOT (Z.S, Z.B, Z.B) 3cy @ 2p
16
+ * svabd_s8_x SABD (Z.B, P/M, Z.B, Z.B) 3cy @ 2p
17
+ * svabd_u8_x UABD (Z.B, P/M, Z.B, Z.B) 3cy @ 2p
18
+ * svaddv_s32 SADDV (D, P, Z.S) 6cy @ 1p
19
+ * svaddv_u32 UADDV (D, P, Z.S) 6cy @ 1p
20
+ * svwhilelt_b8 WHILELT (P.B, Xn, Xm) 2cy @ 1p
21
+ * svcntb CNTB (Xd) 1cy @ 2p
22
+ *
23
+ * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
24
+ * and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
25
+ * process more elements per iteration with identical latencies.
26
+ *
27
+ * For L2 distance, SABD/UABD computes |a-b| per byte, then UDOT squares and accumulates.
28
+ * Angular distance uses SDOT/UDOT directly for dot product and norm computations.
29
+ */
30
+ #ifndef NK_SPATIAL_SVESDOT_H
31
+ #define NK_SPATIAL_SVESDOT_H
32
+
33
+ #if NK_TARGET_ARM_
34
+ #if NK_TARGET_SVESDOT
35
+
36
+ #include "numkong/types.h"
37
+ #include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
38
+
39
+ #if defined(__cplusplus)
40
+ extern "C" {
41
+ #endif
42
+
43
+ #if defined(__clang__)
44
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+dotprod"))), apply_to = function)
45
+ #elif defined(__GNUC__)
46
+ #pragma GCC push_options
47
+ #pragma GCC target("arch=armv8.2-a+sve+dotprod")
48
+ #endif
49
+
50
+ NK_PUBLIC void nk_sqeuclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
51
+ nk_size_t i = 0;
52
+ svuint32_t distance_sq_u32x = svdup_u32(0);
53
+ do {
54
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, n);
55
+ svint8_t a_i8x = svld1_s8(predicate_b8x, a + i);
56
+ svint8_t b_i8x = svld1_s8(predicate_b8x, b + i);
57
+ svuint8_t diff_u8x = svreinterpret_u8_s8(svabd_s8_x(predicate_b8x, a_i8x, b_i8x));
58
+ distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
59
+ i += svcntb();
60
+ } while (i < n);
61
+ *result = (nk_u32_t)svaddv_u32(svptrue_b32(), distance_sq_u32x);
62
+ }
63
+ NK_PUBLIC void nk_euclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
64
+ nk_u32_t distance_sq_u32;
65
+ nk_sqeuclidean_i8_svesdot(a, b, n, &distance_sq_u32);
66
+ *result = nk_f32_sqrt_neon((nk_f32_t)distance_sq_u32);
67
+ }
68
+
69
+ NK_PUBLIC void nk_angular_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
70
+ nk_size_t i = 0;
71
+ svint32_t ab_i32x = svdup_s32(0);
72
+ svint32_t a2_i32x = svdup_s32(0);
73
+ svint32_t b2_i32x = svdup_s32(0);
74
+ do {
75
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, n);
76
+ svint8_t a_i8x = svld1_s8(predicate_b8x, a + i);
77
+ svint8_t b_i8x = svld1_s8(predicate_b8x, b + i);
78
+ ab_i32x = svdot_s32(ab_i32x, a_i8x, b_i8x);
79
+ a2_i32x = svdot_s32(a2_i32x, a_i8x, a_i8x);
80
+ b2_i32x = svdot_s32(b2_i32x, b_i8x, b_i8x);
81
+ i += svcntb();
82
+ } while (i < n);
83
+
84
+ nk_i32_t ab = (nk_i32_t)svaddv_s32(svptrue_b32(), ab_i32x);
85
+ nk_i32_t a2 = (nk_i32_t)svaddv_s32(svptrue_b32(), a2_i32x);
86
+ nk_i32_t b2 = (nk_i32_t)svaddv_s32(svptrue_b32(), b2_i32x);
87
+ *result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
88
+ }
89
+
90
+ NK_PUBLIC void nk_sqeuclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
91
+ nk_size_t i = 0;
92
+ svuint32_t distance_sq_u32x = svdup_u32(0);
93
+ do {
94
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, n);
95
+ svuint8_t a_u8x = svld1_u8(predicate_b8x, a + i);
96
+ svuint8_t b_u8x = svld1_u8(predicate_b8x, b + i);
97
+ svuint8_t diff_u8x = svabd_u8_x(predicate_b8x, a_u8x, b_u8x);
98
+ distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
99
+ i += svcntb();
100
+ } while (i < n);
101
+ *result = (nk_u32_t)svaddv_u32(svptrue_b32(), distance_sq_u32x);
102
+ }
103
+ NK_PUBLIC void nk_euclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
104
+ nk_u32_t distance_sq_u32;
105
+ nk_sqeuclidean_u8_svesdot(a, b, n, &distance_sq_u32);
106
+ *result = nk_f32_sqrt_neon((nk_f32_t)distance_sq_u32);
107
+ }
108
+
109
+ NK_PUBLIC void nk_angular_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
110
+ nk_size_t i = 0;
111
+ svuint32_t ab_u32x = svdup_u32(0);
112
+ svuint32_t a2_u32x = svdup_u32(0);
113
+ svuint32_t b2_u32x = svdup_u32(0);
114
+ do {
115
+ svbool_t predicate_b8x = svwhilelt_b8_u64(i, n);
116
+ svuint8_t a_u8x = svld1_u8(predicate_b8x, a + i);
117
+ svuint8_t b_u8x = svld1_u8(predicate_b8x, b + i);
118
+ ab_u32x = svdot_u32(ab_u32x, a_u8x, b_u8x);
119
+ a2_u32x = svdot_u32(a2_u32x, a_u8x, a_u8x);
120
+ b2_u32x = svdot_u32(b2_u32x, b_u8x, b_u8x);
121
+ i += svcntb();
122
+ } while (i < n);
123
+
124
+ nk_u32_t ab = (nk_u32_t)svaddv_u32(svptrue_b32(), ab_u32x);
125
+ nk_u32_t a2 = (nk_u32_t)svaddv_u32(svptrue_b32(), a2_u32x);
126
+ nk_u32_t b2 = (nk_u32_t)svaddv_u32(svptrue_b32(), b2_u32x);
127
+ *result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
128
+ }
129
+
130
+ #if defined(__clang__)
131
+ #pragma clang attribute pop
132
+ #elif defined(__GNUC__)
133
+ #pragma GCC pop_options
134
+ #endif
135
+
136
+ #if defined(__cplusplus)
137
+ } // extern "C"
138
+ #endif
139
+
140
+ #endif // NK_TARGET_SVESDOT
141
+ #endif // NK_TARGET_ARM_
142
+ #endif // NK_SPATIAL_SVESDOT_H
@@ -64,7 +64,7 @@ NK_INTERNAL nk_f64_t nk_angular_normalize_f64_v128relaxed_(nk_f64_t ab, nk_f64_t
64
64
  return result > 0.0 ? result : 0.0;
65
65
  }
66
66
 
67
- #pragma region - Traditional Floats
67
+ #pragma region F32 and F64 Floats
68
68
 
69
69
  NK_PUBLIC void nk_sqeuclidean_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
70
70
  v128_t sum_f64x2 = wasm_f64x2_splat(0.0);
@@ -83,8 +83,8 @@ nk_sqeuclidean_f32_v128relaxed_cycle:
83
83
  nk_load_b64_serial_(b_scalars, &b_f32_vec);
84
84
  a_scalars += 2, b_scalars += 2, count_scalars -= 2;
85
85
  }
86
- v128_t a_f32x2 = wasm_v128_load64_zero(&a_f32_vec.u64);
87
- v128_t b_f32x2 = wasm_v128_load64_zero(&b_f32_vec.u64);
86
+ v128_t a_f32x2 = wasm_i64x2_splat(a_f32_vec.u64);
87
+ v128_t b_f32x2 = wasm_i64x2_splat(b_f32_vec.u64);
88
88
  v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
89
89
  v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
90
90
  v128_t diff_f64x2 = wasm_f64x2_sub(a_f64x2, b_f64x2);
@@ -152,8 +152,8 @@ nk_angular_f32_v128relaxed_cycle:
152
152
  }
153
153
 
154
154
  // Upcast F32x2 → F64x2 for high-precision accumulation
155
- v128_t a_f32x2 = wasm_v128_load64_zero(&a_f32_vec.u64);
156
- v128_t b_f32x2 = wasm_v128_load64_zero(&b_f32_vec.u64);
155
+ v128_t a_f32x2 = wasm_i64x2_splat(a_f32_vec.u64);
156
+ v128_t b_f32x2 = wasm_i64x2_splat(b_f32_vec.u64);
157
157
  v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
158
158
  v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
159
159
 
@@ -203,8 +203,8 @@ nk_angular_f64_v128relaxed_cycle:
203
203
  *result = nk_angular_normalize_f64_v128relaxed_(ab, a2, b2);
204
204
  }
205
205
 
206
- #pragma endregion - Traditional Floats
207
- #pragma region - Smaller Floats
206
+ #pragma endregion F32 and F64 Floats
207
+ #pragma region F16 and BF16 Floats
208
208
 
209
209
  NK_PUBLIC void nk_sqeuclidean_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
210
210
  v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
@@ -286,31 +286,30 @@ nk_angular_f16_v128relaxed_cycle:
286
286
 
287
287
  NK_PUBLIC void nk_sqeuclidean_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
288
288
  v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
289
+ v128_t mask_high_u32x4 = wasm_i32x4_splat((int)0xFFFF0000);
289
290
  nk_bf16_t const *a_scalars = a, *b_scalars = b;
290
291
  nk_size_t count_scalars = n;
291
- nk_b64_vec_t a_bf16_vec, b_bf16_vec;
292
+ nk_b128_vec_t a_bf16_vec, b_bf16_vec;
292
293
 
293
294
  nk_sqeuclidean_bf16_v128relaxed_cycle:
294
- // Tail or full load
295
- if (count_scalars < 4) {
296
- nk_partial_load_b16x4_serial_(a_scalars, &a_bf16_vec, count_scalars);
297
- nk_partial_load_b16x4_serial_(b_scalars, &b_bf16_vec, count_scalars);
295
+ if (count_scalars < 8) {
296
+ nk_partial_load_b16x8_serial_(a_scalars, &a_bf16_vec, count_scalars);
297
+ nk_partial_load_b16x8_serial_(b_scalars, &b_bf16_vec, count_scalars);
298
298
  count_scalars = 0;
299
299
  }
300
300
  else {
301
- nk_load_b64_serial_(a_scalars, &a_bf16_vec);
302
- nk_load_b64_serial_(b_scalars, &b_bf16_vec);
303
- a_scalars += 4, b_scalars += 4, count_scalars -= 4;
301
+ nk_load_b128_v128relaxed_(a_scalars, &a_bf16_vec);
302
+ nk_load_b128_v128relaxed_(b_scalars, &b_bf16_vec);
303
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
304
304
  }
305
-
306
- // Convert bf16 → f32 (4 elements)
307
- nk_b128_vec_t a_f32_vec = nk_bf16x4_to_f32x4_v128relaxed_(a_bf16_vec);
308
- nk_b128_vec_t b_f32_vec = nk_bf16x4_to_f32x4_v128relaxed_(b_bf16_vec);
309
-
310
- // Accumulate (a - b)²
311
- v128_t diff_f32x4 = wasm_f32x4_sub(a_f32_vec.v128, b_f32_vec.v128);
312
- sum_f32x4 = wasm_f32x4_relaxed_madd(diff_f32x4, diff_f32x4, sum_f32x4);
313
-
305
+ v128_t a_even_f32x4 = wasm_i32x4_shl(a_bf16_vec.v128, 16);
306
+ v128_t b_even_f32x4 = wasm_i32x4_shl(b_bf16_vec.v128, 16);
307
+ v128_t diff_even_f32x4 = wasm_f32x4_sub(a_even_f32x4, b_even_f32x4);
308
+ sum_f32x4 = wasm_f32x4_relaxed_madd(diff_even_f32x4, diff_even_f32x4, sum_f32x4);
309
+ v128_t a_odd_f32x4 = wasm_v128_and(a_bf16_vec.v128, mask_high_u32x4);
310
+ v128_t b_odd_f32x4 = wasm_v128_and(b_bf16_vec.v128, mask_high_u32x4);
311
+ v128_t diff_odd_f32x4 = wasm_f32x4_sub(a_odd_f32x4, b_odd_f32x4);
312
+ sum_f32x4 = wasm_f32x4_relaxed_madd(diff_odd_f32x4, diff_odd_f32x4, sum_f32x4);
314
313
  if (count_scalars) goto nk_sqeuclidean_bf16_v128relaxed_cycle;
315
314
 
316
315
  *result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
@@ -326,44 +325,297 @@ NK_PUBLIC void nk_angular_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *
326
325
  v128_t ab_f32x4 = wasm_f32x4_splat(0.0f);
327
326
  v128_t a2_f32x4 = wasm_f32x4_splat(0.0f);
328
327
  v128_t b2_f32x4 = wasm_f32x4_splat(0.0f);
328
+ v128_t mask_high_u32x4 = wasm_i32x4_splat((int)0xFFFF0000);
329
329
  nk_bf16_t const *a_scalars = a, *b_scalars = b;
330
330
  nk_size_t count_scalars = n;
331
- nk_b64_vec_t a_bf16_vec, b_bf16_vec;
331
+ nk_b128_vec_t a_bf16_vec, b_bf16_vec;
332
332
 
333
333
  nk_angular_bf16_v128relaxed_cycle:
334
+ if (count_scalars < 8) {
335
+ nk_partial_load_b16x8_serial_(a_scalars, &a_bf16_vec, count_scalars);
336
+ nk_partial_load_b16x8_serial_(b_scalars, &b_bf16_vec, count_scalars);
337
+ count_scalars = 0;
338
+ }
339
+ else {
340
+ nk_load_b128_v128relaxed_(a_scalars, &a_bf16_vec);
341
+ nk_load_b128_v128relaxed_(b_scalars, &b_bf16_vec);
342
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
343
+ }
344
+ v128_t a_even_f32x4 = wasm_i32x4_shl(a_bf16_vec.v128, 16);
345
+ v128_t b_even_f32x4 = wasm_i32x4_shl(b_bf16_vec.v128, 16);
346
+ ab_f32x4 = wasm_f32x4_relaxed_madd(a_even_f32x4, b_even_f32x4, ab_f32x4);
347
+ a2_f32x4 = wasm_f32x4_relaxed_madd(a_even_f32x4, a_even_f32x4, a2_f32x4);
348
+ b2_f32x4 = wasm_f32x4_relaxed_madd(b_even_f32x4, b_even_f32x4, b2_f32x4);
349
+ v128_t a_odd_f32x4 = wasm_v128_and(a_bf16_vec.v128, mask_high_u32x4);
350
+ v128_t b_odd_f32x4 = wasm_v128_and(b_bf16_vec.v128, mask_high_u32x4);
351
+ ab_f32x4 = wasm_f32x4_relaxed_madd(a_odd_f32x4, b_odd_f32x4, ab_f32x4);
352
+ a2_f32x4 = wasm_f32x4_relaxed_madd(a_odd_f32x4, a_odd_f32x4, a2_f32x4);
353
+ b2_f32x4 = wasm_f32x4_relaxed_madd(b_odd_f32x4, b_odd_f32x4, b2_f32x4);
354
+ if (count_scalars) goto nk_angular_bf16_v128relaxed_cycle;
355
+
356
+ nk_f32_t ab = nk_reduce_add_f32x4_v128relaxed_(ab_f32x4);
357
+ nk_f32_t a2 = nk_reduce_add_f32x4_v128relaxed_(a2_f32x4);
358
+ nk_f32_t b2 = nk_reduce_add_f32x4_v128relaxed_(b2_f32x4);
359
+ *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)ab, (nk_f64_t)a2, (nk_f64_t)b2);
360
+ }
361
+
362
+ #pragma endregion F16 and BF16 Floats
363
+ #pragma region FP8 Floats
364
+
365
+ NK_PUBLIC void nk_sqeuclidean_e4m3_v128relaxed(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
366
+ v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
367
+ nk_e4m3_t const *a_scalars = a, *b_scalars = b;
368
+ nk_size_t count_scalars = n;
369
+ nk_b32_vec_t a_raw, b_raw;
370
+
371
+ nk_sqeuclidean_e4m3_v128relaxed_cycle:
334
372
  if (count_scalars < 4) {
335
- nk_partial_load_b16x4_serial_(a_scalars, &a_bf16_vec, count_scalars);
336
- nk_partial_load_b16x4_serial_(b_scalars, &b_bf16_vec, count_scalars);
373
+ a_raw = nk_partial_load_b8x4_serial_(a_scalars, count_scalars);
374
+ b_raw = nk_partial_load_b8x4_serial_(b_scalars, count_scalars);
337
375
  count_scalars = 0;
338
376
  }
339
377
  else {
340
- nk_load_b64_serial_(a_scalars, &a_bf16_vec);
341
- nk_load_b64_serial_(b_scalars, &b_bf16_vec);
378
+ nk_load_b32_serial_(a_scalars, &a_raw);
379
+ nk_load_b32_serial_(b_scalars, &b_raw);
342
380
  a_scalars += 4, b_scalars += 4, count_scalars -= 4;
343
381
  }
382
+ nk_b128_vec_t a_f32_vec = nk_e4m3x4_to_f32x4_v128relaxed_(a_raw);
383
+ nk_b128_vec_t b_f32_vec = nk_e4m3x4_to_f32x4_v128relaxed_(b_raw);
384
+ v128_t diff_f32x4 = wasm_f32x4_sub(a_f32_vec.v128, b_f32_vec.v128);
385
+ sum_f32x4 = wasm_f32x4_relaxed_madd(diff_f32x4, diff_f32x4, sum_f32x4);
386
+ if (count_scalars) goto nk_sqeuclidean_e4m3_v128relaxed_cycle;
344
387
 
345
- // Convert bf16 → f32
346
- nk_b128_vec_t a_f32_vec = nk_bf16x4_to_f32x4_v128relaxed_(a_bf16_vec);
347
- nk_b128_vec_t b_f32_vec = nk_bf16x4_to_f32x4_v128relaxed_(b_bf16_vec);
388
+ *result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
389
+ }
348
390
 
349
- // Triple accumulation: ab, a², b²
391
+ NK_PUBLIC void nk_euclidean_e4m3_v128relaxed(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
392
+ nk_sqeuclidean_e4m3_v128relaxed(a, b, n, result);
393
+ *result = nk_f32_sqrt_v128relaxed(*result);
394
+ }
395
+
396
+ NK_PUBLIC void nk_angular_e4m3_v128relaxed(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
397
+ v128_t ab_f32x4 = wasm_f32x4_splat(0.0f);
398
+ v128_t a2_f32x4 = wasm_f32x4_splat(0.0f);
399
+ v128_t b2_f32x4 = wasm_f32x4_splat(0.0f);
400
+ nk_e4m3_t const *a_scalars = a, *b_scalars = b;
401
+ nk_size_t count_scalars = n;
402
+ nk_b32_vec_t a_raw, b_raw;
403
+
404
+ nk_angular_e4m3_v128relaxed_cycle:
405
+ if (count_scalars < 4) {
406
+ a_raw = nk_partial_load_b8x4_serial_(a_scalars, count_scalars);
407
+ b_raw = nk_partial_load_b8x4_serial_(b_scalars, count_scalars);
408
+ count_scalars = 0;
409
+ }
410
+ else {
411
+ nk_load_b32_serial_(a_scalars, &a_raw);
412
+ nk_load_b32_serial_(b_scalars, &b_raw);
413
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
414
+ }
415
+ nk_b128_vec_t a_f32_vec = nk_e4m3x4_to_f32x4_v128relaxed_(a_raw);
416
+ nk_b128_vec_t b_f32_vec = nk_e4m3x4_to_f32x4_v128relaxed_(b_raw);
350
417
  ab_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, b_f32_vec.v128, ab_f32x4);
351
418
  a2_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, a_f32_vec.v128, a2_f32x4);
352
419
  b2_f32x4 = wasm_f32x4_relaxed_madd(b_f32_vec.v128, b_f32_vec.v128, b2_f32x4);
420
+ if (count_scalars) goto nk_angular_e4m3_v128relaxed_cycle;
353
421
 
354
- if (count_scalars) goto nk_angular_bf16_v128relaxed_cycle;
422
+ nk_f32_t ab = nk_reduce_add_f32x4_v128relaxed_(ab_f32x4);
423
+ nk_f32_t a2 = nk_reduce_add_f32x4_v128relaxed_(a2_f32x4);
424
+ nk_f32_t b2 = nk_reduce_add_f32x4_v128relaxed_(b2_f32x4);
425
+ *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)ab, (nk_f64_t)a2, (nk_f64_t)b2);
426
+ }
427
+
428
+ NK_PUBLIC void nk_sqeuclidean_e5m2_v128relaxed(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
429
+ v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
430
+ nk_e5m2_t const *a_scalars = a, *b_scalars = b;
431
+ nk_size_t count_scalars = n;
432
+ nk_b32_vec_t a_raw, b_raw;
433
+
434
+ nk_sqeuclidean_e5m2_v128relaxed_cycle:
435
+ if (count_scalars < 4) {
436
+ a_raw = nk_partial_load_b8x4_serial_(a_scalars, count_scalars);
437
+ b_raw = nk_partial_load_b8x4_serial_(b_scalars, count_scalars);
438
+ count_scalars = 0;
439
+ }
440
+ else {
441
+ nk_load_b32_serial_(a_scalars, &a_raw);
442
+ nk_load_b32_serial_(b_scalars, &b_raw);
443
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
444
+ }
445
+ nk_b128_vec_t a_f32_vec = nk_e5m2x4_to_f32x4_v128relaxed_(a_raw);
446
+ nk_b128_vec_t b_f32_vec = nk_e5m2x4_to_f32x4_v128relaxed_(b_raw);
447
+ v128_t diff_f32x4 = wasm_f32x4_sub(a_f32_vec.v128, b_f32_vec.v128);
448
+ sum_f32x4 = wasm_f32x4_relaxed_madd(diff_f32x4, diff_f32x4, sum_f32x4);
449
+ if (count_scalars) goto nk_sqeuclidean_e5m2_v128relaxed_cycle;
450
+
451
+ *result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
452
+ }
453
+
454
+ NK_PUBLIC void nk_euclidean_e5m2_v128relaxed(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
455
+ nk_sqeuclidean_e5m2_v128relaxed(a, b, n, result);
456
+ *result = nk_f32_sqrt_v128relaxed(*result);
457
+ }
458
+
459
+ NK_PUBLIC void nk_angular_e5m2_v128relaxed(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
460
+ v128_t ab_f32x4 = wasm_f32x4_splat(0.0f);
461
+ v128_t a2_f32x4 = wasm_f32x4_splat(0.0f);
462
+ v128_t b2_f32x4 = wasm_f32x4_splat(0.0f);
463
+ nk_e5m2_t const *a_scalars = a, *b_scalars = b;
464
+ nk_size_t count_scalars = n;
465
+ nk_b32_vec_t a_raw, b_raw;
466
+
467
+ nk_angular_e5m2_v128relaxed_cycle:
468
+ if (count_scalars < 4) {
469
+ a_raw = nk_partial_load_b8x4_serial_(a_scalars, count_scalars);
470
+ b_raw = nk_partial_load_b8x4_serial_(b_scalars, count_scalars);
471
+ count_scalars = 0;
472
+ }
473
+ else {
474
+ nk_load_b32_serial_(a_scalars, &a_raw);
475
+ nk_load_b32_serial_(b_scalars, &b_raw);
476
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
477
+ }
478
+ nk_b128_vec_t a_f32_vec = nk_e5m2x4_to_f32x4_v128relaxed_(a_raw);
479
+ nk_b128_vec_t b_f32_vec = nk_e5m2x4_to_f32x4_v128relaxed_(b_raw);
480
+ ab_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, b_f32_vec.v128, ab_f32x4);
481
+ a2_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, a_f32_vec.v128, a2_f32x4);
482
+ b2_f32x4 = wasm_f32x4_relaxed_madd(b_f32_vec.v128, b_f32_vec.v128, b2_f32x4);
483
+ if (count_scalars) goto nk_angular_e5m2_v128relaxed_cycle;
355
484
 
356
- // Reduce accumulators
357
485
  nk_f32_t ab = nk_reduce_add_f32x4_v128relaxed_(ab_f32x4);
358
486
  nk_f32_t a2 = nk_reduce_add_f32x4_v128relaxed_(a2_f32x4);
359
487
  nk_f32_t b2 = nk_reduce_add_f32x4_v128relaxed_(b2_f32x4);
488
+ *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)ab, (nk_f64_t)a2, (nk_f64_t)b2);
489
+ }
360
490
 
361
- // Normalize using f64 helper (handles edge cases: zero vectors, perpendicular, clamping)
491
+ NK_PUBLIC void nk_sqeuclidean_e2m3_v128relaxed(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
492
+ v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
493
+ nk_e2m3_t const *a_scalars = a, *b_scalars = b;
494
+ nk_size_t count_scalars = n;
495
+ nk_b32_vec_t a_raw, b_raw;
496
+
497
+ nk_sqeuclidean_e2m3_v128relaxed_cycle:
498
+ if (count_scalars < 4) {
499
+ a_raw = nk_partial_load_b8x4_serial_(a_scalars, count_scalars);
500
+ b_raw = nk_partial_load_b8x4_serial_(b_scalars, count_scalars);
501
+ count_scalars = 0;
502
+ }
503
+ else {
504
+ nk_load_b32_serial_(a_scalars, &a_raw);
505
+ nk_load_b32_serial_(b_scalars, &b_raw);
506
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
507
+ }
508
+ nk_b128_vec_t a_f32_vec = nk_e2m3x4_to_f32x4_v128relaxed_(a_raw);
509
+ nk_b128_vec_t b_f32_vec = nk_e2m3x4_to_f32x4_v128relaxed_(b_raw);
510
+ v128_t diff_f32x4 = wasm_f32x4_sub(a_f32_vec.v128, b_f32_vec.v128);
511
+ sum_f32x4 = wasm_f32x4_relaxed_madd(diff_f32x4, diff_f32x4, sum_f32x4);
512
+ if (count_scalars) goto nk_sqeuclidean_e2m3_v128relaxed_cycle;
513
+
514
+ *result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
515
+ }
516
+
517
+ NK_PUBLIC void nk_euclidean_e2m3_v128relaxed(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
518
+ nk_sqeuclidean_e2m3_v128relaxed(a, b, n, result);
519
+ *result = nk_f32_sqrt_v128relaxed(*result);
520
+ }
521
+
522
+ NK_PUBLIC void nk_angular_e2m3_v128relaxed(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
523
+ v128_t ab_f32x4 = wasm_f32x4_splat(0.0f);
524
+ v128_t a2_f32x4 = wasm_f32x4_splat(0.0f);
525
+ v128_t b2_f32x4 = wasm_f32x4_splat(0.0f);
526
+ nk_e2m3_t const *a_scalars = a, *b_scalars = b;
527
+ nk_size_t count_scalars = n;
528
+ nk_b32_vec_t a_raw, b_raw;
529
+
530
+ nk_angular_e2m3_v128relaxed_cycle:
531
+ if (count_scalars < 4) {
532
+ a_raw = nk_partial_load_b8x4_serial_(a_scalars, count_scalars);
533
+ b_raw = nk_partial_load_b8x4_serial_(b_scalars, count_scalars);
534
+ count_scalars = 0;
535
+ }
536
+ else {
537
+ nk_load_b32_serial_(a_scalars, &a_raw);
538
+ nk_load_b32_serial_(b_scalars, &b_raw);
539
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
540
+ }
541
+ nk_b128_vec_t a_f32_vec = nk_e2m3x4_to_f32x4_v128relaxed_(a_raw);
542
+ nk_b128_vec_t b_f32_vec = nk_e2m3x4_to_f32x4_v128relaxed_(b_raw);
543
+ ab_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, b_f32_vec.v128, ab_f32x4);
544
+ a2_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, a_f32_vec.v128, a2_f32x4);
545
+ b2_f32x4 = wasm_f32x4_relaxed_madd(b_f32_vec.v128, b_f32_vec.v128, b2_f32x4);
546
+ if (count_scalars) goto nk_angular_e2m3_v128relaxed_cycle;
547
+
548
+ nk_f32_t ab = nk_reduce_add_f32x4_v128relaxed_(ab_f32x4);
549
+ nk_f32_t a2 = nk_reduce_add_f32x4_v128relaxed_(a2_f32x4);
550
+ nk_f32_t b2 = nk_reduce_add_f32x4_v128relaxed_(b2_f32x4);
551
+ *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)ab, (nk_f64_t)a2, (nk_f64_t)b2);
552
+ }
553
+
554
+ NK_PUBLIC void nk_sqeuclidean_e3m2_v128relaxed(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
555
+ v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
556
+ nk_e3m2_t const *a_scalars = a, *b_scalars = b;
557
+ nk_size_t count_scalars = n;
558
+ nk_b32_vec_t a_raw, b_raw;
559
+
560
+ nk_sqeuclidean_e3m2_v128relaxed_cycle:
561
+ if (count_scalars < 4) {
562
+ a_raw = nk_partial_load_b8x4_serial_(a_scalars, count_scalars);
563
+ b_raw = nk_partial_load_b8x4_serial_(b_scalars, count_scalars);
564
+ count_scalars = 0;
565
+ }
566
+ else {
567
+ nk_load_b32_serial_(a_scalars, &a_raw);
568
+ nk_load_b32_serial_(b_scalars, &b_raw);
569
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
570
+ }
571
+ nk_b128_vec_t a_f32_vec = nk_e3m2x4_to_f32x4_v128relaxed_(a_raw);
572
+ nk_b128_vec_t b_f32_vec = nk_e3m2x4_to_f32x4_v128relaxed_(b_raw);
573
+ v128_t diff_f32x4 = wasm_f32x4_sub(a_f32_vec.v128, b_f32_vec.v128);
574
+ sum_f32x4 = wasm_f32x4_relaxed_madd(diff_f32x4, diff_f32x4, sum_f32x4);
575
+ if (count_scalars) goto nk_sqeuclidean_e3m2_v128relaxed_cycle;
576
+
577
+ *result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
578
+ }
579
+
580
+ NK_PUBLIC void nk_euclidean_e3m2_v128relaxed(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
581
+ nk_sqeuclidean_e3m2_v128relaxed(a, b, n, result);
582
+ *result = nk_f32_sqrt_v128relaxed(*result);
583
+ }
584
+
585
+ NK_PUBLIC void nk_angular_e3m2_v128relaxed(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
586
+ v128_t ab_f32x4 = wasm_f32x4_splat(0.0f);
587
+ v128_t a2_f32x4 = wasm_f32x4_splat(0.0f);
588
+ v128_t b2_f32x4 = wasm_f32x4_splat(0.0f);
589
+ nk_e3m2_t const *a_scalars = a, *b_scalars = b;
590
+ nk_size_t count_scalars = n;
591
+ nk_b32_vec_t a_raw, b_raw;
592
+
593
+ nk_angular_e3m2_v128relaxed_cycle:
594
+ if (count_scalars < 4) {
595
+ a_raw = nk_partial_load_b8x4_serial_(a_scalars, count_scalars);
596
+ b_raw = nk_partial_load_b8x4_serial_(b_scalars, count_scalars);
597
+ count_scalars = 0;
598
+ }
599
+ else {
600
+ nk_load_b32_serial_(a_scalars, &a_raw);
601
+ nk_load_b32_serial_(b_scalars, &b_raw);
602
+ a_scalars += 4, b_scalars += 4, count_scalars -= 4;
603
+ }
604
+ nk_b128_vec_t a_f32_vec = nk_e3m2x4_to_f32x4_v128relaxed_(a_raw);
605
+ nk_b128_vec_t b_f32_vec = nk_e3m2x4_to_f32x4_v128relaxed_(b_raw);
606
+ ab_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, b_f32_vec.v128, ab_f32x4);
607
+ a2_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, a_f32_vec.v128, a2_f32x4);
608
+ b2_f32x4 = wasm_f32x4_relaxed_madd(b_f32_vec.v128, b_f32_vec.v128, b2_f32x4);
609
+ if (count_scalars) goto nk_angular_e3m2_v128relaxed_cycle;
610
+
611
+ nk_f32_t ab = nk_reduce_add_f32x4_v128relaxed_(ab_f32x4);
612
+ nk_f32_t a2 = nk_reduce_add_f32x4_v128relaxed_(a2_f32x4);
613
+ nk_f32_t b2 = nk_reduce_add_f32x4_v128relaxed_(b2_f32x4);
362
614
  *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_((nk_f64_t)ab, (nk_f64_t)a2, (nk_f64_t)b2);
363
615
  }
364
616
 
365
- #pragma endregion - Smaller Floats
366
- #pragma region - Spatial From-Dot Helpers
617
+ #pragma endregion FP8 Floats
618
+ #pragma region Spatial From Dot Helpers
367
619
 
368
620
  /** @brief Angular from_dot: computes 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs in f32. */
369
621
  NK_INTERNAL void nk_angular_through_f32_from_dot_v128relaxed_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
@@ -437,8 +689,8 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_v128relaxed_(nk_b128_vec_t do
437
689
  results->v128 = wasm_f32x4_sqrt(dist_sq_f32x4);
438
690
  }
439
691
 
440
- #pragma endregion - Spatial From - Dot Helpers
441
- #pragma region - Integer Spatial
692
+ #pragma endregion Spatial From Dot Helpers
693
+ #pragma region I8 and U8 Integers
442
694
 
443
695
  NK_PUBLIC void nk_sqeuclidean_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
444
696
  v128_t sum_u32x4 = wasm_u32x4_splat(0);
@@ -703,7 +955,7 @@ NK_PUBLIC void nk_angular_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_
703
955
  *result = (nk_f32_t)nk_angular_normalize_f64_v128relaxed_(dot_ab, norm_aa, norm_bb);
704
956
  }
705
957
 
706
- #pragma endregion - Integer Spatial
958
+ #pragma endregion I8 and U8 Integers
707
959
 
708
960
  #if defined(__clang__)
709
961
  #pragma clang attribute pop