numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,123 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for RISC-V BF16.
3
+ * @file include/numkong/spatial/rvvbf16.h
4
+ * @author Ash Vardanian
5
+ * @date January 5, 2026
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * Zvfbfwma provides widening bf16 fused multiply-accumulate to f32:
10
+ * vfwmaccbf16: f32 ← bf16 × bf16
11
+ *
12
+ * For L2 distance, we use the identity: (a−b)² = a² + b² − 2 × a × b
13
+ * This allows us to use vfwmaccbf16 for all computations.
14
+ *
15
+ * Requires: RVV 1.0 + Zvfbfwma extension (GCC 14+ or Clang 18+)
16
+ */
17
+ #ifndef NK_SPATIAL_RVVBF16_H
18
+ #define NK_SPATIAL_RVVBF16_H
19
+
20
+ #if NK_TARGET_RISCV_
21
+ #if NK_TARGET_RVVBF16
22
+
23
+ #include "numkong/types.h"
24
+ #include "numkong/spatial/rvv.h" // `nk_f32_sqrt_rvv`
25
+
26
+ #if defined(__clang__)
27
+ #pragma clang attribute push(__attribute__((target("arch=+v,+zvfbfwma"))), apply_to = function)
28
+ #elif defined(__GNUC__)
29
+ #pragma GCC push_options
30
+ #pragma GCC target("arch=+v,+zvfbfwma")
31
+ #endif
32
+
33
+ #if defined(__cplusplus)
34
+ extern "C" {
35
+ #endif
36
+
37
/**
 * @brief Squared Euclidean (L2²) distance between two bf16 vectors.
 *
 * Uses the identity (a−b)² = a² + b² − 2·a·b so that every product can go
 * through the widening `vfwmaccbf16` FMA (bf16 × bf16 → f32). Per-lane f32
 * accumulators persist across the strip-mined loop; a single horizontal
 * reduction runs once after the loop.
 */
NK_PUBLIC void nk_sqeuclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars,
                                           nk_size_t count_scalars, nk_f32_t *result) {
    nk_size_t const lanes_max = __riscv_vsetvlmax_e32m2();
    // Lane-wise running sums: `squares` holds a² + b², `products` holds a·b.
    vfloat32m2_t squares_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);
    vfloat32m2_t products_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);

    while (count_scalars > 0) {
        nk_size_t const step = __riscv_vsetvl_e16m1(count_scalars);
        // Loads go through u16 lanes, then reinterpret the bits as bf16.
        vbfloat16m1_t a_bf16m1 =
            __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vle16_v_u16m1((unsigned short const *)a_scalars, step));
        vbfloat16m1_t b_bf16m1 =
            __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vle16_v_u16m1((unsigned short const *)b_scalars, step));
        // Tail-undisturbed widening FMAs leave inactive accumulator lanes intact.
        squares_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(squares_f32m2, a_bf16m1, a_bf16m1, step);
        squares_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(squares_f32m2, b_bf16m1, b_bf16m1, step);
        products_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(products_f32m2, a_bf16m1, b_bf16m1, step);
        a_scalars += step, b_scalars += step, count_scalars -= step;
    }

    // One horizontal reduction per accumulator after the loop.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t squares =
        __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(squares_f32m2, zero_f32m1, lanes_max));
    nk_f32_t products =
        __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(products_f32m2, zero_f32m1, lanes_max));
    *result = squares - 2.0f * products;
}
64
+
65
/**
 * @brief Euclidean (L2) distance between two bf16 vectors.
 *
 * Delegates to the squared-distance kernel, treats small negative values
 * produced by floating-point cancellation as zero, and takes the square root.
 */
NK_PUBLIC void nk_euclidean_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars,
                                         nk_size_t count_scalars, nk_f32_t *result) {
    nk_sqeuclidean_bf16_rvvbf16(a_scalars, b_scalars, count_scalars, result);
    // Rounding in (a² + b² − 2ab) can yield a tiny negative number; clamp it.
    if (*result > 0.0f) { *result = nk_f32_sqrt_rvv(*result); }
    else { *result = 0.0f; }
}
71
+
72
/**
 * @brief Angular (cosine) distance between two bf16 vectors.
 *
 * Accumulates a·b, a·a, and b·b per lane via the widening `vfwmaccbf16` FMA,
 * reduces each accumulator once after the loop, and normalizes as
 * 1 − dot / sqrt(‖a‖² · ‖b‖²), clipping negatives to zero.
 */
NK_PUBLIC void nk_angular_bf16_rvvbf16(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
                                       nk_f32_t *result) {
    nk_size_t const lanes_max = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t products_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);
    vfloat32m2_t a_squares_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);
    vfloat32m2_t b_squares_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);

    while (count_scalars > 0) {
        nk_size_t const step = __riscv_vsetvl_e16m1(count_scalars);
        // Loads go through u16 lanes, then reinterpret the bits as bf16.
        vbfloat16m1_t a_bf16m1 =
            __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vle16_v_u16m1((unsigned short const *)a_scalars, step));
        vbfloat16m1_t b_bf16m1 =
            __riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vle16_v_u16m1((unsigned short const *)b_scalars, step));
        // Three widening FMAs per step: dot product plus both squared norms.
        products_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(products_f32m2, a_bf16m1, b_bf16m1, step);
        a_squares_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(a_squares_f32m2, a_bf16m1, a_bf16m1, step);
        b_squares_f32m2 = __riscv_vfwmaccbf16_vv_f32m2_tu(b_squares_f32m2, b_bf16m1, b_bf16m1, step);
        a_scalars += step, b_scalars += step, count_scalars -= step;
    }

    // One horizontal reduction per accumulator after the loop.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(products_f32m2, zero_f32m1, lanes_max));
    nk_f32_t a_norm_sq =
        __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(a_squares_f32m2, zero_f32m1, lanes_max));
    nk_f32_t b_norm_sq =
        __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(b_squares_f32m2, zero_f32m1, lanes_max));

    // Edge cases: two zero vectors count as identical; a zero dot product as orthogonal.
    if (a_norm_sq == 0.0f && b_norm_sq == 0.0f) { *result = 0.0f; }
    else if (dot == 0.0f) { *result = 1.0f; }
    else {
        nk_f32_t distance = 1.0f - dot * nk_f32_rsqrt_rvv(a_norm_sq) * nk_f32_rsqrt_rvv(b_norm_sq);
        *result = distance > 0.0f ? distance : 0.0f;
    }
}
110
+
111
+ #if defined(__cplusplus)
112
+ } // extern "C"
113
+ #endif
114
+
115
+ #if defined(__clang__)
116
+ #pragma clang attribute pop
117
+ #elif defined(__GNUC__)
118
+ #pragma GCC pop_options
119
+ #endif
120
+
121
+ #endif // NK_TARGET_RVVBF16
122
+ #endif // NK_TARGET_RISCV_
123
+ #endif // NK_SPATIAL_RVVBF16_H
@@ -0,0 +1,117 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for RISC-V FP16.
3
+ * @file include/numkong/spatial/rvvhalf.h
4
+ * @author Ash Vardanian
5
+ * @date January 5, 2026
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * Zvfh provides native half-precision (f16) vector operations.
10
+ * Uses widening operations (f16 → f32) for precision accumulation.
11
+ *
12
+ * Requires: RVV 1.0 + Zvfh extension (GCC 14+ or Clang 18+)
13
+ */
14
+ #ifndef NK_SPATIAL_RVVHALF_H
15
+ #define NK_SPATIAL_RVVHALF_H
16
+
17
+ #if NK_TARGET_RISCV_
18
+ #if NK_TARGET_RVVHALF
19
+
20
+ #include "numkong/types.h"
21
+ #include "numkong/spatial/rvv.h" // `nk_f32_sqrt_rvv`
22
+
23
+ #if defined(__clang__)
24
+ #pragma clang attribute push(__attribute__((target("arch=+v,+zvfh"))), apply_to = function)
25
+ #elif defined(__GNUC__)
26
+ #pragma GCC push_options
27
+ #pragma GCC target("arch=+v,+zvfh")
28
+ #endif
29
+
30
+ #if defined(__cplusplus)
31
+ extern "C" {
32
+ #endif
33
+
34
/**
 * @brief Squared Euclidean (L2²) distance between two f16 vectors.
 *
 * Widens both operands to f32 before subtracting — subtracting in f16 first
 * would lose bits to catastrophic cancellation — then squares and accumulates
 * per lane, reducing horizontally once after the loop.
 */
NK_PUBLIC void nk_sqeuclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
                                          nk_f32_t *result) {
    nk_size_t const lanes_max = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t squares_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);

    while (count_scalars > 0) {
        nk_size_t const step = __riscv_vsetvl_e16m1(count_scalars);
        // Loads go through u16 lanes, then reinterpret the bits as f16.
        vfloat16m1_t a_f16m1 =
            __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vle16_v_u16m1((unsigned short const *)a_scalars, step));
        vfloat16m1_t b_f16m1 =
            __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vle16_v_u16m1((unsigned short const *)b_scalars, step));
        // Widen to f32 first so the subtraction keeps full precision.
        vfloat32m2_t difference_f32m2 = __riscv_vfsub_vv_f32m2(__riscv_vfwcvt_f_f_v_f32m2(a_f16m1, step),
                                                               __riscv_vfwcvt_f_f_v_f32m2(b_f16m1, step), step);
        // Accumulate diff² in f32 with a tail-undisturbed FMA.
        squares_f32m2 = __riscv_vfmacc_vv_f32m2_tu(squares_f32m2, difference_f32m2, difference_f32m2, step);
        a_scalars += step, b_scalars += step, count_scalars -= step;
    }

    // Single horizontal reduction after the loop.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(squares_f32m2, zero_f32m1, lanes_max));
}
59
+
60
/**
 * @brief Euclidean (L2) distance between two f16 vectors.
 *
 * The squared distance is a sum of squares and therefore non-negative,
 * so the square root needs no clamping here.
 */
NK_PUBLIC void nk_euclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
                                        nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_f16_rvvhalf(a_scalars, b_scalars, count_scalars, &squared_distance);
    *result = nk_f32_sqrt_rvv(squared_distance);
}
65
+
66
/**
 * @brief Angular (cosine) distance between two f16 vectors.
 *
 * Accumulates a·b, a·a, and b·b per lane via widening FMAs (f16 × f16 → f32),
 * reduces each accumulator once after the loop, and normalizes as
 * 1 − dot / sqrt(‖a‖² · ‖b‖²), clipping negatives to zero.
 */
NK_PUBLIC void nk_angular_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
                                      nk_f32_t *result) {
    nk_size_t const lanes_max = __riscv_vsetvlmax_e32m2();
    vfloat32m2_t products_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);
    vfloat32m2_t a_squares_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);
    vfloat32m2_t b_squares_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, lanes_max);

    while (count_scalars > 0) {
        nk_size_t const step = __riscv_vsetvl_e16m1(count_scalars);
        // Loads go through u16 lanes, then reinterpret the bits as f16.
        vfloat16m1_t a_f16m1 =
            __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vle16_v_u16m1((unsigned short const *)a_scalars, step));
        vfloat16m1_t b_f16m1 =
            __riscv_vreinterpret_v_u16m1_f16m1(__riscv_vle16_v_u16m1((unsigned short const *)b_scalars, step));
        // Three widening FMAs per step: dot product plus both squared norms.
        products_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(products_f32m2, a_f16m1, b_f16m1, step);
        a_squares_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(a_squares_f32m2, a_f16m1, a_f16m1, step);
        b_squares_f32m2 = __riscv_vfwmacc_vv_f32m2_tu(b_squares_f32m2, b_f16m1, b_f16m1, step);
        a_scalars += step, b_scalars += step, count_scalars -= step;
    }

    // One horizontal reduction per accumulator after the loop.
    vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
    nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(products_f32m2, zero_f32m1, lanes_max));
    nk_f32_t a_norm_sq =
        __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(a_squares_f32m2, zero_f32m1, lanes_max));
    nk_f32_t b_norm_sq =
        __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(b_squares_f32m2, zero_f32m1, lanes_max));

    // Edge cases: two zero vectors count as identical; a zero dot product as orthogonal.
    if (a_norm_sq == 0.0f && b_norm_sq == 0.0f) { *result = 0.0f; }
    else if (dot == 0.0f) { *result = 1.0f; }
    else {
        nk_f32_t distance = 1.0f - dot * nk_f32_rsqrt_rvv(a_norm_sq) * nk_f32_rsqrt_rvv(b_norm_sq);
        *result = distance > 0.0f ? distance : 0.0f;
    }
}
104
+
105
+ #if defined(__cplusplus)
106
+ } // extern "C"
107
+ #endif
108
+
109
+ #if defined(__clang__)
110
+ #pragma clang attribute pop
111
+ #elif defined(__GNUC__)
112
+ #pragma GCC pop_options
113
+ #endif
114
+
115
+ #endif // NK_TARGET_RVVHALF
116
+ #endif // NK_TARGET_RISCV_
117
+ #endif // NK_SPATIAL_RVVHALF_H
@@ -0,0 +1,343 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for Sapphire Rapids.
3
+ * @file include/numkong/spatial/sapphire.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * Sapphire Rapids adds native FP16 support via AVX-512 FP16 extension.
10
+ * For e4m3 L2 distance, we can leverage F16 for the subtraction step:
11
+ * - e4m3 differences fit in F16 (max |a−b| = 896 < 65504)
12
+ * - But squared differences overflow F16 (896² = 802816 > 65504)
13
+ * - So: subtract in F16, convert to F32, then square and accumulate
14
+ *
15
+ * For e2m3/e3m2 L2 distance, squared differences fit in FP16:
16
+ * - E2M3: max |a−b| = 15, max (a−b)² = 225 < 65504, flush cadence = 4 (conservative for uniformity)
17
+ * - E3M2: max |a−b| = 56, max (a−b)² = 3136 < 65504, flush cadence = 4
18
+ * So the entire sub+square+accumulate stays in FP16 with periodic F32 flush.
19
+ *
20
+ * @section spatial_sapphire_instructions Relevant Instructions
21
+ *
22
+ * Intrinsic Instruction Sapphire Genoa
23
+ * _mm256_sub_ph VSUBPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
24
+ * _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 5cy @ p01
25
+ * _mm512_fmadd_ps VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
26
+ * _mm512_reduce_add_ps (pseudo: VHADDPS chain) ~8cy ~8cy
27
+ * _mm_maskz_loadu_epi8 VMOVDQU8 (XMM {K}, M128) 7cy @ p23 7cy @ p23
28
+ */
29
+ #ifndef NK_SPATIAL_SAPPHIRE_H
30
+ #define NK_SPATIAL_SAPPHIRE_H
31
+
32
+ #if NK_TARGET_X86_
33
+ #if NK_TARGET_SAPPHIRE
34
+
35
+ #include "numkong/types.h"
36
+ #include "numkong/cast/sapphire.h" // `nk_e4m3x16_to_f16x16_sapphire_`
37
+ #include "numkong/dot/sapphire.h" // `nk_e2m3x32_to_f16x32_sapphire_`, `nk_flush_f16_to_f32_sapphire_`
38
+ #include "numkong/spatial/haswell.h" // `nk_angular_normalize_f32_haswell_`, `nk_f32_sqrt_haswell`
39
+
40
+ #if defined(__cplusplus)
41
+ extern "C" {
42
+ #endif
43
+
44
+ #if defined(__clang__)
45
+ #pragma clang attribute push( \
46
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,f16c,fma,bmi,bmi2"))), \
47
+ apply_to = function)
48
+ #elif defined(__GNUC__)
49
+ #pragma GCC push_options
50
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
51
+ #endif
52
+
53
+ NK_PUBLIC void nk_sqeuclidean_e4m3_sapphire(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars,
54
+ nk_size_t count_scalars, nk_f32_t *result) {
55
+ __m512 sum_f32x16 = _mm512_setzero_ps();
56
+
57
+ while (count_scalars > 0) {
58
+ nk_size_t const n = count_scalars < 16 ? count_scalars : 16;
59
+ __mmask16 const mask = (__mmask16)_bzhi_u32(0xFFFF, n);
60
+ __m128i a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
61
+ __m128i b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
62
+
63
+ // Convert e4m3 → f16
64
+ __m256h a_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(a_e4m3x16);
65
+ __m256h b_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(b_e4m3x16);
66
+
67
+ // Subtract in F16 − differences fit (max 896 < 65504)
68
+ __m256h diff_f16x16 = _mm256_sub_ph(a_f16x16, b_f16x16);
69
+
70
+ // Convert to F32 before squaring (896² = 802816 overflows F16!)
71
+ __m512 diff_f32x16 = _mm512_cvtph_ps(_mm256_castph_si256(diff_f16x16));
72
+
73
+ // Square and accumulate in F32
74
+ sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
75
+ a_scalars += n, b_scalars += n, count_scalars -= n;
76
+ }
77
+
78
+ *result = _mm512_reduce_add_ps(sum_f32x16);
79
+ }
80
+
81
+ NK_PUBLIC void nk_euclidean_e4m3_sapphire(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars,
82
+ nk_size_t count_scalars, nk_f32_t *result) {
83
+ nk_sqeuclidean_e4m3_sapphire(a_scalars, b_scalars, count_scalars, result);
84
+ *result = nk_f32_sqrt_haswell(*result);
85
+ }
86
+
87
+ NK_PUBLIC void nk_sqeuclidean_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
88
+ nk_size_t count_scalars, nk_f32_t *result) {
89
+ __m512 sum_f32x16 = _mm512_setzero_ps();
90
+
91
+ // Main loop: 4-way unrolled, 128 elements per flush
92
+ while (count_scalars >= 128) {
93
+ __m512h acc_f16x32 = _mm512_setzero_ph();
94
+ __m512h a_f16x32, b_f16x32, diff_f16x32;
95
+ // Iteration 1
96
+ a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
97
+ b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
98
+ diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
99
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
100
+ // Iteration 2
101
+ a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
102
+ b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
103
+ diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
104
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
105
+ // Iteration 3
106
+ a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
107
+ b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
108
+ diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
109
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
110
+ // Iteration 4
111
+ a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
112
+ b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
113
+ diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
114
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
115
+ // Flush to F32
116
+ sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
117
+ a_scalars += 128, b_scalars += 128, count_scalars -= 128;
118
+ }
119
+
120
+ // Tail: remaining 0–127 elements, 32 at a time via masked loads
121
+ __m512h acc_f16x32 = _mm512_setzero_ph();
122
+ while (count_scalars > 0) {
123
+ nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
124
+ __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
125
+ __m512h a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
126
+ __m512h b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
127
+ __m512h diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
128
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
129
+ a_scalars += n, b_scalars += n, count_scalars -= n;
130
+ }
131
+ sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
132
+
133
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
134
+ }
135
+
136
+ NK_PUBLIC void nk_sqeuclidean_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
137
+ nk_size_t count_scalars, nk_f32_t *result) {
138
+ __m512 sum_f32x16 = _mm512_setzero_ps();
139
+
140
+ // Main loop: 4-way unrolled, 128 elements per flush
141
+ while (count_scalars >= 128) {
142
+ __m512h acc_f16x32 = _mm512_setzero_ph();
143
+ __m512h a_f16x32, b_f16x32, diff_f16x32;
144
+ // Iteration 1
145
+ a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
146
+ b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
147
+ diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
148
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
149
+ // Iteration 2
150
+ a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
151
+ b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
152
+ diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
153
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
154
+ // Iteration 3
155
+ a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
156
+ b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
157
+ diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
158
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
159
+ // Iteration 4
160
+ a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
161
+ b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
162
+ diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
163
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
164
+ // Flush to F32
165
+ sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
166
+ a_scalars += 128, b_scalars += 128, count_scalars -= 128;
167
+ }
168
+
169
+ // Tail: remaining 0–127 elements, 32 at a time via masked loads
170
+ __m512h acc_f16x32 = _mm512_setzero_ph();
171
+ while (count_scalars > 0) {
172
+ nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
173
+ __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
174
+ __m512h a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
175
+ __m512h b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
176
+ __m512h diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
177
+ acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
178
+ a_scalars += n, b_scalars += n, count_scalars -= n;
179
+ }
180
+ sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
181
+
182
+ *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
183
+ }
184
+
185
+ NK_PUBLIC void nk_euclidean_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
186
+ nk_size_t count_scalars, nk_f32_t *result) {
187
+ nk_sqeuclidean_e2m3_sapphire(a_scalars, b_scalars, count_scalars, result);
188
+ *result = nk_f32_sqrt_haswell(*result);
189
+ }
190
+
191
+ NK_PUBLIC void nk_euclidean_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
192
+ nk_size_t count_scalars, nk_f32_t *result) {
193
+ nk_sqeuclidean_e3m2_sapphire(a_scalars, b_scalars, count_scalars, result);
194
+ *result = nk_f32_sqrt_haswell(*result);
195
+ }
196
+
197
+ NK_PUBLIC void nk_angular_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
198
+ nk_f32_t *result) {
199
+ __m512 sum_dot_f32x16 = _mm512_setzero_ps();
200
+ __m512 sum_a_f32x16 = _mm512_setzero_ps();
201
+ __m512 sum_b_f32x16 = _mm512_setzero_ps();
202
+
203
+ // Main loop: 4-way unrolled, 128 elements per flush
204
+ while (count_scalars >= 128) {
205
+ __m512h dot_acc = _mm512_setzero_ph();
206
+ __m512h a_norm_acc = _mm512_setzero_ph();
207
+ __m512h b_norm_acc = _mm512_setzero_ph();
208
+ __m512h a_f16x32, b_f16x32;
209
+ // Iteration 1
210
+ a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
211
+ b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
212
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
213
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
214
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
215
+ // Iteration 2
216
+ a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
217
+ b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
218
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
219
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
220
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
221
+ // Iteration 3
222
+ a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
223
+ b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
224
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
225
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
226
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
227
+ // Iteration 4
228
+ a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
229
+ b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
230
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
231
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
232
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
233
+ // Flush to F32
234
+ sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
235
+ sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
236
+ sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
237
+ a_scalars += 128, b_scalars += 128, count_scalars -= 128;
238
+ }
239
+
240
+ // Tail: remaining 0–127 elements, 32 at a time via masked loads
241
+ __m512h dot_acc = _mm512_setzero_ph();
242
+ __m512h a_norm_acc = _mm512_setzero_ph();
243
+ __m512h b_norm_acc = _mm512_setzero_ph();
244
+ while (count_scalars > 0) {
245
+ nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
246
+ __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
247
+ __m512h a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
248
+ __m512h b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
249
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
250
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
251
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
252
+ a_scalars += n, b_scalars += n, count_scalars -= n;
253
+ }
254
+ sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
255
+ sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
256
+ sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
257
+
258
+ nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(sum_dot_f32x16);
259
+ nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_a_f32x16);
260
+ nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_b_f32x16);
261
+ *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
262
+ }
263
+
264
+ NK_PUBLIC void nk_angular_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
265
+ nk_f32_t *result) {
266
+ __m512 sum_dot_f32x16 = _mm512_setzero_ps();
267
+ __m512 sum_a_f32x16 = _mm512_setzero_ps();
268
+ __m512 sum_b_f32x16 = _mm512_setzero_ps();
269
+
270
+ // Main loop: 4-way unrolled, 128 elements per flush
271
+ while (count_scalars >= 128) {
272
+ __m512h dot_acc = _mm512_setzero_ph();
273
+ __m512h a_norm_acc = _mm512_setzero_ph();
274
+ __m512h b_norm_acc = _mm512_setzero_ph();
275
+ __m512h a_f16x32, b_f16x32;
276
+ // Iteration 1
277
+ a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
278
+ b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
279
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
280
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
281
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
282
+ // Iteration 2
283
+ a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
284
+ b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
285
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
286
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
287
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
288
+ // Iteration 3
289
+ a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
290
+ b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
291
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
292
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
293
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
294
+ // Iteration 4
295
+ a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
296
+ b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
297
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
298
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
299
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
300
+ // Flush to F32
301
+ sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
302
+ sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
303
+ sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
304
+ a_scalars += 128, b_scalars += 128, count_scalars -= 128;
305
+ }
306
+
307
+ // Tail: remaining 0–127 elements, 32 at a time via masked loads
308
+ __m512h dot_acc = _mm512_setzero_ph();
309
+ __m512h a_norm_acc = _mm512_setzero_ph();
310
+ __m512h b_norm_acc = _mm512_setzero_ph();
311
+ while (count_scalars > 0) {
312
+ nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
313
+ __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
314
+ __m512h a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
315
+ __m512h b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
316
+ dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
317
+ a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
318
+ b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
319
+ a_scalars += n, b_scalars += n, count_scalars -= n;
320
+ }
321
+ sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
322
+ sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
323
+ sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
324
+
325
+ nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(sum_dot_f32x16);
326
+ nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_a_f32x16);
327
+ nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_b_f32x16);
328
+ *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
329
+ }
330
+
331
+ #if defined(__clang__)
332
+ #pragma clang attribute pop
333
+ #elif defined(__GNUC__)
334
+ #pragma GCC pop_options
335
+ #endif
336
+
337
+ #if defined(__cplusplus)
338
+ } // extern "C"
339
+ #endif
340
+
341
+ #endif // NK_TARGET_SAPPHIRE
342
+ #endif // NK_TARGET_X86_
343
+ #endif // NK_SPATIAL_SAPPHIRE_H