numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,346 @@
1
+ /**
2
+ * @brief SWAR-accelerated Spatial Similarity Measures for SIMD-free CPUs.
3
+ * @file include/numkong/spatial/serial.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ */
9
+ #ifndef NK_SPATIAL_SERIAL_H
10
+ #define NK_SPATIAL_SERIAL_H
11
+
12
+ #include "numkong/types.h"
13
+ #include "numkong/scalar/serial.h" // `nk_f32_rsqrt_serial`
14
+ #include "numkong/cast/serial.h"
15
+ #include "numkong/dot/serial.h" // `nk_dot_f64x2_state_serial_t`
16
+
17
+ #if defined(__cplusplus)
18
+ extern "C" {
19
+ #endif
20
+
21
/**
 *  @brief Macro generating `nk_sqeuclidean_<input_type>_serial`: squared L2 distance
 *         accumulated with Neumaier compensated summation.
 *
 *  Neumaier's variant of Kahan-Babuška summation keeps a separate correction term that
 *  collects the low-order bits lost by every addition. Unlike plain Kahan summation it
 *  also copes with a term that is larger in magnitude than the running sum, which gives
 *  O(1) error growth regardless of the vector dimension.
 *
 *  Performance vs accuracy tradeoff:
 *  - Roughly ~30% overhead (3 extra FP operations per iteration) over naive summation.
 *  - Relative error drops from ~10⁻⁵ to ~10⁻⁷ at n=100K for f32.
 *  - All floating-point types benefit: f64, f32, f16, bf16.
 *  - Integer types (i8) stay perfectly accurate either way.
 *
 *  Algorithm, for every term with t = sum + term:
 *  - if |sum| ≥ |term|: c += (sum − t) + term   (low-order bits of `term` were lost)
 *  - otherwise:         c += (term − t) + sum   (low-order bits of `sum` were lost)
 *  The value written to `*result` is `sum + c`, cast to the output type.
 *
 *  @param input_type        Suffix of the element type of both input arrays.
 *  @param accumulator_type  Suffix of the type used for differences, squares and sums.
 *  @param output_type       Suffix of the type written through `result`.
 *  @param load_and_convert  Callable `(src_ptr, dst_ptr)` loading one element into the accumulator type.
 *
 *  @see Neumaier, A. (1974). "Rundungsfehleranalyse einiger Verfahren zur Summation endlicher Summen"
 */
#define nk_define_sqeuclidean_(input_type, accumulator_type, output_type, load_and_convert)                         \
    NK_PUBLIC void nk_sqeuclidean_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                        nk_size_t n, nk_##output_type##_t *result) {                \
        nk_##accumulator_type##_t a_element, b_element;                                                             \
        nk_##accumulator_type##_t running = 0, lost_bits = 0;                                                       \
        for (nk_size_t i = 0; i != n; ++i) {                                                                        \
            load_and_convert(a + i, &a_element);                                                                    \
            load_and_convert(b + i, &b_element);                                                                    \
            nk_##accumulator_type##_t delta = a_element - b_element;                                                \
            nk_##accumulator_type##_t squared = delta * delta, next = running + squared;                            \
            /* Recover whichever operand lost its low-order bits in `running + squared`. */                         \
            if (nk_##accumulator_type##_abs_(running) >= nk_##accumulator_type##_abs_(squared)) {                   \
                lost_bits += ((running - next) + squared);                                                          \
            }                                                                                                       \
            else { lost_bits += ((squared - next) + running); }                                                     \
            running = next;                                                                                         \
        }                                                                                                           \
        *result = (nk_##output_type##_t)(running + lost_bits);                                                      \
    }
56
+
57
/*  Generates `nk_euclidean_<input_type>_serial`: calls the matching
 *  `nk_sqeuclidean_<input_type>_serial`, casts the squared distance to the output type
 *  and applies `compute_sqrt`. The `accumulator_type` and `load_and_convert` arguments
 *  are accepted but not used by this expansion. */
#define nk_define_euclidean_(input_type, accumulator_type, l2sq_output_type, output_type, load_and_convert,       \
                             compute_sqrt)                                                                        \
    NK_PUBLIC void nk_euclidean_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                      nk_size_t n, nk_##output_type##_t *result) {                \
        nk_##l2sq_output_type##_t squared_distance;                                                               \
        nk_sqeuclidean_##input_type##_serial(a, b, n, &squared_distance);                                         \
        nk_##output_type##_t const widened = (nk_##output_type##_t)squared_distance;                              \
        *result = compute_sqrt(widened);                                                                          \
    }
65
+
66
/**
 *  @brief Macro generating `nk_angular_<input_type>_serial`: cosine/angular distance
 *         with Neumaier compensated summation.
 *
 *  Three running sums are kept — the dot product and the squared norms of `a` and `b` —
 *  and each one has its own Neumaier compensation term, giving O(1) error growth
 *  regardless of the vector dimension.
 *
 *  Result written through `result`:
 *  - 0 when both squared norms are zero;
 *  - 1 when the dot product is zero;
 *  - otherwise 1 − dot × rsqrt(‖a‖²) × rsqrt(‖b‖²), clamped from below at 0.
 *
 *  @param input_type        Suffix of the element type of both input arrays.
 *  @param accumulator_type  Suffix of the type used for products and sums.
 *  @param output_type       Suffix of the type written through `result`.
 *  @param load_and_convert  Callable `(src_ptr, dst_ptr)` loading one element into the accumulator type.
 *  @param compute_rsqrt     Callable returning the reciprocal square root of its argument.
 *
 *  @see nk_define_sqeuclidean_ for detailed documentation on Neumaier summation.
 */
#define nk_define_angular_(input_type, accumulator_type, output_type, load_and_convert, compute_rsqrt)          \
    NK_PUBLIC void nk_angular_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
                                                    nk_size_t n, nk_##output_type##_t *result) {                \
        nk_##accumulator_type##_t a_element, b_element;                                                         \
        nk_##accumulator_type##_t ab_running = 0, aa_running = 0, bb_running = 0;                               \
        nk_##accumulator_type##_t ab_lost = 0, aa_lost = 0, bb_lost = 0;                                        \
        for (nk_size_t i = 0; i != n; ++i) {                                                                    \
            load_and_convert(a + i, &a_element);                                                                \
            load_and_convert(b + i, &b_element);                                                                \
            nk_##accumulator_type##_t ab_term = a_element * b_element, ab_next = ab_running + ab_term;          \
            nk_##accumulator_type##_t aa_term = a_element * a_element, aa_next = aa_running + aa_term;          \
            nk_##accumulator_type##_t bb_term = b_element * b_element, bb_next = bb_running + bb_term;          \
            /* Per accumulator: recover the low-order bits dropped by `running + term`. */                      \
            if (nk_##accumulator_type##_abs_(ab_running) >= nk_##accumulator_type##_abs_(ab_term)) {            \
                ab_lost += ((ab_running - ab_next) + ab_term);                                                  \
            }                                                                                                   \
            else { ab_lost += ((ab_term - ab_next) + ab_running); }                                             \
            if (nk_##accumulator_type##_abs_(aa_running) >= nk_##accumulator_type##_abs_(aa_term)) {            \
                aa_lost += ((aa_running - aa_next) + aa_term);                                                  \
            }                                                                                                   \
            else { aa_lost += ((aa_term - aa_next) + aa_running); }                                             \
            if (nk_##accumulator_type##_abs_(bb_running) >= nk_##accumulator_type##_abs_(bb_term)) {            \
                bb_lost += ((bb_running - bb_next) + bb_term);                                                  \
            }                                                                                                   \
            else { bb_lost += ((bb_term - bb_next) + bb_running); }                                             \
            ab_running = ab_next;                                                                               \
            aa_running = aa_next;                                                                               \
            bb_running = bb_next;                                                                               \
        }                                                                                                       \
        nk_##accumulator_type##_t dot_product = ab_running + ab_lost;                                           \
        nk_##accumulator_type##_t a_norm_sq = aa_running + aa_lost;                                             \
        nk_##accumulator_type##_t b_norm_sq = bb_running + bb_lost;                                             \
        if (a_norm_sq == 0 && b_norm_sq == 0) {                                                                 \
            *result = 0;                                                                                        \
            return;                                                                                             \
        }                                                                                                       \
        if (dot_product == 0) {                                                                                 \
            *result = 1;                                                                                        \
            return;                                                                                             \
        }                                                                                                       \
        nk_##output_type##_t unclipped_distance = 1 - dot_product * compute_rsqrt(a_norm_sq) *                  \
                                                          compute_rsqrt(b_norm_sq);                             \
        *result = unclipped_distance > 0 ? unclipped_distance : 0;                                              \
    }
109
+
110
+ nk_define_angular_(f64, f64, f64, nk_assign_from_to_, nk_f64_rsqrt_serial) // nk_angular_f64_serial: f64 in, f64 accumulators, f64 out
111
+ nk_define_sqeuclidean_(f64, f64, f64, nk_assign_from_to_) // nk_sqeuclidean_f64_serial: f64 in, f64 accumulator, f64 out
112
+ nk_define_euclidean_(f64, f64, f64, f64, nk_assign_from_to_, nk_f64_sqrt_serial) // nk_euclidean_f64_serial: f64 squared distance, f64 out
113
+
114
+ nk_define_angular_(f32, f64, f64, nk_assign_from_to_, nk_f64_rsqrt_serial) // nk_angular_f32_serial: f32 in, f64 accumulators, f64 out
115
+ nk_define_sqeuclidean_(f32, f64, f64, nk_assign_from_to_) // nk_sqeuclidean_f32_serial: f32 in, f64 accumulator, f64 out
116
+ nk_define_euclidean_(f32, f64, f64, f64, nk_assign_from_to_, nk_f64_sqrt_serial) // nk_euclidean_f32_serial: f64 squared distance, f64 out
117
+
118
+ nk_define_angular_(f16, f32, f32, nk_f16_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_f16_serial: f16 in, f32 accumulators, f32 out
119
+ nk_define_sqeuclidean_(f16, f32, f32, nk_f16_to_f32_serial) // nk_sqeuclidean_f16_serial: f16 in, f32 accumulator, f32 out
120
+ nk_define_euclidean_(f16, f32, f32, f32, nk_f16_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_f16_serial: f32 squared distance, f32 out
121
+
122
+ nk_define_angular_(bf16, f32, f32, nk_bf16_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_bf16_serial: bf16 in, f32 accumulators, f32 out
123
+ nk_define_sqeuclidean_(bf16, f32, f32, nk_bf16_to_f32_serial) // nk_sqeuclidean_bf16_serial: bf16 in, f32 accumulator, f32 out
124
+ nk_define_euclidean_(bf16, f32, f32, f32, nk_bf16_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_bf16_serial: f32 squared distance, f32 out
125
+
126
+ nk_define_angular_(e4m3, f32, f32, nk_e4m3_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_e4m3_serial: e4m3 in, f32 accumulators, f32 out
127
+ nk_define_sqeuclidean_(e4m3, f32, f32, nk_e4m3_to_f32_serial) // nk_sqeuclidean_e4m3_serial: e4m3 in, f32 accumulator, f32 out
128
+ nk_define_euclidean_(e4m3, f32, f32, f32, nk_e4m3_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_e4m3_serial: f32 squared distance, f32 out
129
+
130
+ nk_define_angular_(e5m2, f32, f32, nk_e5m2_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_e5m2_serial: e5m2 in, f32 accumulators, f32 out
131
+ nk_define_sqeuclidean_(e5m2, f32, f32, nk_e5m2_to_f32_serial) // nk_sqeuclidean_e5m2_serial: e5m2 in, f32 accumulator, f32 out
132
+ nk_define_euclidean_(e5m2, f32, f32, f32, nk_e5m2_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_e5m2_serial: f32 squared distance, f32 out
133
+
134
+ nk_define_angular_(e2m3, f32, f32, nk_e2m3_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_e2m3_serial: e2m3 in, f32 accumulators, f32 out
135
+ nk_define_sqeuclidean_(e2m3, f32, f32, nk_e2m3_to_f32_serial) // nk_sqeuclidean_e2m3_serial: e2m3 in, f32 accumulator, f32 out
136
+ nk_define_euclidean_(e2m3, f32, f32, f32, nk_e2m3_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_e2m3_serial: f32 squared distance, f32 out
137
+
138
+ nk_define_angular_(e3m2, f32, f32, nk_e3m2_to_f32_serial, nk_f32_rsqrt_serial) // nk_angular_e3m2_serial: e3m2 in, f32 accumulators, f32 out
139
+ nk_define_sqeuclidean_(e3m2, f32, f32, nk_e3m2_to_f32_serial) // nk_sqeuclidean_e3m2_serial: e3m2 in, f32 accumulator, f32 out
140
+ nk_define_euclidean_(e3m2, f32, f32, f32, nk_e3m2_to_f32_serial, nk_f32_sqrt_serial) // nk_euclidean_e3m2_serial: f32 squared distance, f32 out
141
+
142
+ nk_define_angular_(i8, i32, f32, nk_assign_from_to_, nk_f32_rsqrt_serial) // nk_angular_i8_serial: i8 in, i32 accumulators, f32 out
143
+ nk_define_sqeuclidean_(i8, i32, u32, nk_assign_from_to_) // nk_sqeuclidean_i8_serial: i8 in, i32 accumulator, u32 out
144
+ nk_define_euclidean_(i8, i32, u32, f32, nk_assign_from_to_, nk_f32_sqrt_serial) // nk_euclidean_i8_serial: u32 squared distance, f32 out
145
+
146
+ nk_define_angular_(u8, u32, f32, nk_assign_from_to_, nk_f32_rsqrt_serial) // nk_angular_u8_serial: u8 in, u32 accumulators, f32 out
147
+ nk_define_sqeuclidean_(u8, u32, u32, nk_assign_from_to_) // nk_sqeuclidean_u8_serial: u8 in, u32 accumulator, u32 out
148
+ nk_define_euclidean_(u8, u32, u32, f32, nk_assign_from_to_, nk_f32_sqrt_serial) // nk_euclidean_u8_serial: u32 squared distance, f32 out
149
+
150
+ #undef nk_define_sqeuclidean_
151
+ #undef nk_define_euclidean_
152
+ #undef nk_define_angular_
153
+
154
+ NK_PUBLIC void nk_sqeuclidean_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result) {
155
+ // i4 values are packed as nibbles: two 4-bit signed values per byte.
156
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
157
+ // Sign extension: (nibble ^ 8) - 8 maps [0,15] to [-8,7]
158
+ n = nk_size_round_up_to_multiple_(n, 2);
159
+ nk_size_t n_bytes = n / 2;
160
+ nk_i32_t sum = 0;
161
+ for (nk_size_t i = 0; i < n_bytes; ++i) {
162
+ nk_i32_t a_low = (nk_i32_t)nk_i4x2_low_(a[i]);
163
+ nk_i32_t b_low = (nk_i32_t)nk_i4x2_low_(b[i]);
164
+ nk_i32_t a_high = (nk_i32_t)nk_i4x2_high_(a[i]);
165
+ nk_i32_t b_high = (nk_i32_t)nk_i4x2_high_(b[i]);
166
+ nk_i32_t diff_low = a_low - b_low, diff_high = a_high - b_high;
167
+ sum += diff_low * diff_low + diff_high * diff_high;
168
+ }
169
+ *result = (nk_u32_t)sum;
170
+ }
171
+
172
+ NK_PUBLIC void nk_euclidean_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
173
+ nk_u32_t distance_sq;
174
+ nk_sqeuclidean_i4_serial(a, b, n, &distance_sq);
175
+ *result = nk_f32_sqrt_serial((nk_f32_t)distance_sq);
176
+ }
177
+
178
+ NK_PUBLIC void nk_angular_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
179
+ n = nk_size_round_up_to_multiple_(n, 2);
180
+ nk_size_t n_bytes = n / 2;
181
+ nk_i32_t dot_sum = 0, a_norm_sq = 0, b_norm_sq = 0;
182
+ for (nk_size_t i = 0; i < n_bytes; ++i) {
183
+ nk_i32_t a_low = (nk_i32_t)nk_i4x2_low_(a[i]);
184
+ nk_i32_t b_low = (nk_i32_t)nk_i4x2_low_(b[i]);
185
+ nk_i32_t a_high = (nk_i32_t)nk_i4x2_high_(a[i]);
186
+ nk_i32_t b_high = (nk_i32_t)nk_i4x2_high_(b[i]);
187
+ dot_sum += a_low * b_low + a_high * b_high;
188
+ a_norm_sq += a_low * a_low + a_high * a_high;
189
+ b_norm_sq += b_low * b_low + b_high * b_high;
190
+ }
191
+ if (a_norm_sq == 0 && b_norm_sq == 0) { *result = 0; }
192
+ else if (dot_sum == 0) { *result = 1; }
193
+ else {
194
+ nk_f32_t unclipped = 1.0f - (nk_f32_t)dot_sum * nk_f32_rsqrt_serial((nk_f32_t)a_norm_sq) *
195
+ nk_f32_rsqrt_serial((nk_f32_t)b_norm_sq);
196
+ *result = unclipped > 0 ? unclipped : 0;
197
+ }
198
+ }
199
+
200
+ NK_PUBLIC void nk_sqeuclidean_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
201
+ // u4 values are packed as nibbles: two 4-bit unsigned values per byte.
202
+ // Parameter `n` is the number of 4-bit values (dimensions), not bytes.
203
+ // No sign extension needed - values are in [0,15].
204
+ n = nk_size_round_up_to_multiple_(n, 2);
205
+ nk_size_t n_bytes = n / 2;
206
+ nk_u32_t sum = 0;
207
+ for (nk_size_t i = 0; i < n_bytes; ++i) {
208
+ nk_i32_t a_low = (nk_i32_t)nk_u4x2_low_(a[i]);
209
+ nk_i32_t b_low = (nk_i32_t)nk_u4x2_low_(b[i]);
210
+ nk_i32_t a_high = (nk_i32_t)nk_u4x2_high_(a[i]);
211
+ nk_i32_t b_high = (nk_i32_t)nk_u4x2_high_(b[i]);
212
+ nk_i32_t diff_low = a_low - b_low, diff_high = a_high - b_high;
213
+ sum += (nk_u32_t)(diff_low * diff_low + diff_high * diff_high);
214
+ }
215
+ *result = sum;
216
+ }
217
+
218
+ NK_PUBLIC void nk_euclidean_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
219
+ nk_u32_t distance_sq;
220
+ nk_sqeuclidean_u4_serial(a, b, n, &distance_sq);
221
+ *result = nk_f32_sqrt_serial((nk_f32_t)distance_sq);
222
+ }
223
+
224
+ NK_PUBLIC void nk_angular_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
225
+ n = nk_size_round_up_to_multiple_(n, 2);
226
+ nk_size_t n_bytes = n / 2;
227
+ nk_u32_t dot_sum = 0, a_norm_sq = 0, b_norm_sq = 0;
228
+ for (nk_size_t i = 0; i < n_bytes; ++i) {
229
+ nk_u32_t a_low = (nk_u32_t)nk_u4x2_low_(a[i]);
230
+ nk_u32_t b_low = (nk_u32_t)nk_u4x2_low_(b[i]);
231
+ nk_u32_t a_high = (nk_u32_t)nk_u4x2_high_(a[i]);
232
+ nk_u32_t b_high = (nk_u32_t)nk_u4x2_high_(b[i]);
233
+ dot_sum += a_low * b_low + a_high * b_high;
234
+ a_norm_sq += a_low * a_low + a_high * a_high;
235
+ b_norm_sq += b_low * b_low + b_high * b_high;
236
+ }
237
+ if (a_norm_sq == 0 && b_norm_sq == 0) { *result = 0; }
238
+ else if (dot_sum == 0) { *result = 1; }
239
+ else {
240
+ nk_f32_t unclipped = 1.0f - (nk_f32_t)dot_sum * nk_f32_rsqrt_serial((nk_f32_t)a_norm_sq) *
241
+ nk_f32_rsqrt_serial((nk_f32_t)b_norm_sq);
242
+ *result = unclipped > 0 ? unclipped : 0;
243
+ }
244
+ }
245
+
246
+ /** @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs (serial). */
247
+ NK_INTERNAL void nk_angular_through_f32_from_dot_serial_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
248
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
249
+ for (int i = 0; i < 4; ++i) {
250
+ nk_f32_t product = query_sumsq * target_sumsqs.f32s[i];
251
+ if (product > 0) {
252
+ nk_f32_t rsqrt_val = nk_f32_rsqrt_serial(product);
253
+ nk_f32_t normalized = dots.f32s[i] * rsqrt_val;
254
+ nk_f32_t result = 1.0f - normalized;
255
+ results->f32s[i] = result > 0 ? result : 0;
256
+ }
257
+ else { results->f32s[i] = (dots.f32s[i] == 0) ? 0.0f : 1.0f; }
258
+ }
259
+ }
260
+
261
+ /** @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs (serial). */
262
+ NK_INTERNAL void nk_euclidean_through_f32_from_dot_serial_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
263
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
264
+ for (int i = 0; i < 4; ++i) {
265
+ nk_f32_t dist_sq = query_sumsq + target_sumsqs.f32s[i] - 2.0f * dots.f32s[i];
266
+ results->f32s[i] = dist_sq > 0 ? nk_f32_sqrt_serial(dist_sq) : 0.0f;
267
+ }
268
+ }
269
+
270
+ /** @brief Angular from_dot for f64 precision. */
271
+ NK_INTERNAL void nk_angular_through_f64_from_dot_serial_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
272
+ nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
273
+ for (int i = 0; i < 4; ++i) {
274
+ nk_f64_t product = query_sumsq * target_sumsqs.f64s[i];
275
+ if (product > 0) {
276
+ nk_f64_t rsqrt_val = nk_f64_rsqrt_serial(product);
277
+ nk_f64_t normalized = dots.f64s[i] * rsqrt_val;
278
+ nk_f64_t result = 1.0 - normalized;
279
+ results->f64s[i] = result > 0 ? result : 0;
280
+ }
281
+ else { results->f64s[i] = (dots.f64s[i] == 0) ? 0.0 : 1.0; }
282
+ }
283
+ }
284
+
285
+ /** @brief Euclidean from_dot for f64 precision. */
286
+ NK_INTERNAL void nk_euclidean_through_f64_from_dot_serial_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
287
+ nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
288
+ for (int i = 0; i < 4; ++i) {
289
+ nk_f64_t dist_sq = query_sumsq + target_sumsqs.f64s[i] - 2.0 * dots.f64s[i];
290
+ results->f64s[i] = dist_sq > 0 ? nk_f64_sqrt_serial(dist_sq) : 0.0;
291
+ }
292
+ }
293
+
294
+ /** @brief Angular from_dot for i32 accumulators: cast to f32, then same math as f32 variant. */
295
+ NK_INTERNAL void nk_angular_through_i32_from_dot_serial_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
296
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
297
+ for (int i = 0; i < 4; ++i) {
298
+ nk_f32_t product = (nk_f32_t)query_sumsq * (nk_f32_t)target_sumsqs.i32s[i];
299
+ if (product > 0) {
300
+ nk_f32_t rsqrt_val = nk_f32_rsqrt_serial(product);
301
+ nk_f32_t normalized = (nk_f32_t)dots.i32s[i] * rsqrt_val;
302
+ nk_f32_t result = 1.0f - normalized;
303
+ results->f32s[i] = result > 0 ? result : 0;
304
+ }
305
+ else { results->f32s[i] = (dots.i32s[i] == 0) ? 0.0f : 1.0f; }
306
+ }
307
+ }
308
+
309
+ /** @brief Euclidean from_dot for i32 accumulators: cast to f32, then same math as f32 variant. */
310
+ NK_INTERNAL void nk_euclidean_through_i32_from_dot_serial_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
311
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
312
+ for (int i = 0; i < 4; ++i) {
313
+ nk_f32_t dist_sq = (nk_f32_t)query_sumsq + (nk_f32_t)target_sumsqs.i32s[i] - 2.0f * (nk_f32_t)dots.i32s[i];
314
+ results->f32s[i] = dist_sq > 0 ? nk_f32_sqrt_serial(dist_sq) : 0.0f;
315
+ }
316
+ }
317
+
318
+ /** @brief Angular from_dot for u32 accumulators: cast to f32, then same math as f32 variant. */
319
+ NK_INTERNAL void nk_angular_through_u32_from_dot_serial_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
320
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
321
+ for (int i = 0; i < 4; ++i) {
322
+ nk_f32_t product = (nk_f32_t)query_sumsq * (nk_f32_t)target_sumsqs.u32s[i];
323
+ if (product > 0) {
324
+ nk_f32_t rsqrt_val = nk_f32_rsqrt_serial(product);
325
+ nk_f32_t normalized = (nk_f32_t)dots.u32s[i] * rsqrt_val;
326
+ nk_f32_t result = 1.0f - normalized;
327
+ results->f32s[i] = result > 0 ? result : 0;
328
+ }
329
+ else { results->f32s[i] = (dots.u32s[i] == 0) ? 0.0f : 1.0f; }
330
+ }
331
+ }
332
+
333
+ /** @brief Euclidean from_dot for u32 accumulators: cast to f32, then same math as f32 variant. */
334
+ NK_INTERNAL void nk_euclidean_through_u32_from_dot_serial_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
335
+ nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
336
+ for (int i = 0; i < 4; ++i) {
337
+ nk_f32_t dist_sq = (nk_f32_t)query_sumsq + (nk_f32_t)target_sumsqs.u32s[i] - 2.0f * (nk_f32_t)dots.u32s[i];
338
+ results->f32s[i] = dist_sq > 0 ? nk_f32_sqrt_serial(dist_sq) : 0.0f;
339
+ }
340
+ }
341
+
342
+ #if defined(__cplusplus)
343
+ } // extern "C"
344
+ #endif
345
+
346
+ #endif // NK_SPATIAL_SERIAL_H
@@ -0,0 +1,323 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for Sierra Forest.
3
+ * @file include/numkong/spatial/sierra.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_sierra_instructions AVXVNNIINT8 Instructions Performance
10
+ *
11
+ * Intrinsic Instruction Sierra Forest
12
+ * _mm256_dpbssds_epi32 VPDPBSSDS (YMM, YMM, YMM) 4cy @ p05
13
+ * _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) 4cy @ p05
14
+ * _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) 4cy @ p05
15
+ * _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0
16
+ * _mm_sqrt_ss VSQRTSS (XMM, XMM, XMM) 12cy @ p0
17
+ *
18
+ * Sierra Forest (AVXVNNIINT8) provides native signed x signed and unsigned x unsigned
19
+ * dot products, eliminating the need for algebraic corrections required on Alder Lake.
20
+ * This gives ~2.6x throughput over Haswell and ~1.3x over Alder for spatial kernels.
21
+ */
22
+ #ifndef NK_SPATIAL_SIERRA_H
23
+ #define NK_SPATIAL_SIERRA_H
24
+
25
+ #if NK_TARGET_X86_
26
+ #if NK_TARGET_SIERRA
27
+
28
+ #include "numkong/types.h"
29
+ #include "numkong/scalar/haswell.h" // `nk_f32_sqrt_haswell`
30
+ #include "numkong/reduce/haswell.h" // `nk_reduce_add_i32x8_haswell_`
31
+ #include "numkong/cast/serial.h" // `nk_partial_load_b8x32_serial_`
32
+
33
+ #if defined(__cplusplus)
34
+ extern "C" {
35
+ #endif
36
+
37
+ #if defined(__clang__)
38
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni,avxvnniint8"))), apply_to = function)
39
+ #elif defined(__GNUC__)
40
+ #pragma GCC push_options
41
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni", "avxvnniint8")
42
+ #endif
43
+
44
+ NK_PUBLIC void nk_angular_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
45
+
46
+ __m256i dot_product_i32x8 = _mm256_setzero_si256();
47
+ __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
48
+ __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
49
+
50
+ nk_size_t i = 0;
51
+ for (; i + 32 <= n; i += 32) {
52
+ __m256i a_i8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
53
+ __m256i b_i8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
54
+ dot_product_i32x8 = _mm256_dpbssds_epi32(dot_product_i32x8, a_i8x32, b_i8x32);
55
+ a_norm_sq_i32x8 = _mm256_dpbssds_epi32(a_norm_sq_i32x8, a_i8x32, a_i8x32);
56
+ b_norm_sq_i32x8 = _mm256_dpbssds_epi32(b_norm_sq_i32x8, b_i8x32, b_i8x32);
57
+ }
58
+
59
+ nk_i32_t dot_product_i32 = nk_reduce_add_i32x8_haswell_(dot_product_i32x8);
60
+ nk_i32_t a_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8);
61
+ nk_i32_t b_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8);
62
+
63
+ for (; i < n; ++i) {
64
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
65
+ dot_product_i32 += a_element_i32 * b_element_i32;
66
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
67
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
68
+ }
69
+
70
+ *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
71
+ }
72
+
73
+ NK_PUBLIC void nk_sqeuclidean_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
74
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b) using dpbssds (signed x signed)
75
+
76
+ __m256i dot_product_i32x8 = _mm256_setzero_si256();
77
+ __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
78
+ __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
79
+
80
+ nk_size_t i = 0;
81
+ for (; i + 32 <= n; i += 32) {
82
+ __m256i a_i8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
83
+ __m256i b_i8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
84
+ dot_product_i32x8 = _mm256_dpbssds_epi32(dot_product_i32x8, a_i8x32, b_i8x32);
85
+ a_norm_sq_i32x8 = _mm256_dpbssds_epi32(a_norm_sq_i32x8, a_i8x32, a_i8x32);
86
+ b_norm_sq_i32x8 = _mm256_dpbssds_epi32(b_norm_sq_i32x8, b_i8x32, b_i8x32);
87
+ }
88
+
89
+ nk_i32_t dot_product_i32 = nk_reduce_add_i32x8_haswell_(dot_product_i32x8);
90
+ nk_i32_t a_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8);
91
+ nk_i32_t b_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8);
92
+
93
+ for (; i < n; ++i) {
94
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
95
+ dot_product_i32 += a_element_i32 * b_element_i32;
96
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
97
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
98
+ }
99
+
100
+ *result = (nk_u32_t)(a_norm_sq_i32 + b_norm_sq_i32 - 2 * dot_product_i32);
101
+ }
102
+
103
+ NK_PUBLIC void nk_euclidean_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
104
+ nk_u32_t distance_sq_u32;
105
+ nk_sqeuclidean_i8_sierra(a, b, n, &distance_sq_u32);
106
+ *result = nk_f32_sqrt_haswell((nk_f32_t)distance_sq_u32);
107
+ }
108
+
109
+ NK_PUBLIC void nk_angular_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
110
+
111
+ __m256i dot_product_u32x8 = _mm256_setzero_si256();
112
+ __m256i a_norm_sq_u32x8 = _mm256_setzero_si256();
113
+ __m256i b_norm_sq_u32x8 = _mm256_setzero_si256();
114
+
115
+ nk_size_t i = 0;
116
+ for (; i + 32 <= n; i += 32) {
117
+ __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
118
+ __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
119
+ dot_product_u32x8 = _mm256_dpbuud_epi32(dot_product_u32x8, a_u8x32, b_u8x32);
120
+ a_norm_sq_u32x8 = _mm256_dpbuud_epi32(a_norm_sq_u32x8, a_u8x32, a_u8x32);
121
+ b_norm_sq_u32x8 = _mm256_dpbuud_epi32(b_norm_sq_u32x8, b_u8x32, b_u8x32);
122
+ }
123
+
124
+ nk_i32_t dot_product_i32 = nk_reduce_add_i32x8_haswell_(dot_product_u32x8);
125
+ nk_i32_t a_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(a_norm_sq_u32x8);
126
+ nk_i32_t b_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(b_norm_sq_u32x8);
127
+
128
+ for (; i < n; ++i) {
129
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
130
+ dot_product_i32 += a_element_i32 * b_element_i32;
131
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
132
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
133
+ }
134
+
135
+ *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
136
+ }
137
+
138
+ NK_PUBLIC void nk_sqeuclidean_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
139
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b) using dpbuud (unsigned x unsigned)
140
+
141
+ __m256i dot_product_u32x8 = _mm256_setzero_si256();
142
+ __m256i a_norm_sq_u32x8 = _mm256_setzero_si256();
143
+ __m256i b_norm_sq_u32x8 = _mm256_setzero_si256();
144
+
145
+ nk_size_t i = 0;
146
+ for (; i + 32 <= n; i += 32) {
147
+ __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
148
+ __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
149
+ dot_product_u32x8 = _mm256_dpbuud_epi32(dot_product_u32x8, a_u8x32, b_u8x32);
150
+ a_norm_sq_u32x8 = _mm256_dpbuud_epi32(a_norm_sq_u32x8, a_u8x32, a_u8x32);
151
+ b_norm_sq_u32x8 = _mm256_dpbuud_epi32(b_norm_sq_u32x8, b_u8x32, b_u8x32);
152
+ }
153
+
154
+ nk_i32_t dot_product_i32 = nk_reduce_add_i32x8_haswell_(dot_product_u32x8);
155
+ nk_i32_t a_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(a_norm_sq_u32x8);
156
+ nk_i32_t b_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(b_norm_sq_u32x8);
157
+
158
+ for (; i < n; ++i) {
159
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
160
+ dot_product_i32 += a_element_i32 * b_element_i32;
161
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
162
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
163
+ }
164
+
165
+ *result = (nk_u32_t)(a_norm_sq_i32 + b_norm_sq_i32 - 2 * dot_product_i32);
166
+ }
167
+
168
+ NK_PUBLIC void nk_euclidean_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
169
+ nk_u32_t distance_sq_u32;
170
+ nk_sqeuclidean_u8_sierra(a, b, n, &distance_sq_u32);
171
+ *result = nk_f32_sqrt_haswell((nk_f32_t)distance_sq_u32);
172
+ }
173
+
174
+ NK_PUBLIC void nk_angular_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
175
+ nk_f32_t *result) {
176
+ // Angular distance for e2m3 using dual-VPSHUFB LUT + VPDPBSSD norm decomposition.
177
+ // Every e2m3 value × 16 is an exact integer in [-120, +120].
178
+ // DPBSSD(signed, signed) eliminates the need for unsigned conversion tricks.
179
+ //
180
+ __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
181
+ 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
182
+ __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
183
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
184
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
185
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
186
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
187
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
188
+ __m256i dot_i32x8 = _mm256_setzero_si256();
189
+ __m256i a_norm_i32x8 = _mm256_setzero_si256();
190
+ __m256i b_norm_i32x8 = _mm256_setzero_si256();
191
+ __m256i a_e2m3_u8x32, b_e2m3_u8x32;
192
+
193
+ nk_angular_e2m3_sierra_cycle:
194
+ if (count_scalars < 32) {
195
+ nk_b256_vec_t a_vec, b_vec;
196
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
197
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
198
+ a_e2m3_u8x32 = a_vec.ymm;
199
+ b_e2m3_u8x32 = b_vec.ymm;
200
+ count_scalars = 0;
201
+ }
202
+ else {
203
+ a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
204
+ b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
205
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
206
+ }
207
+
208
+ // Decode a: extract magnitude, dual-VPSHUFB LUT, apply sign
209
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
210
+ __m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
211
+ __m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
212
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
213
+ _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
214
+ __m256i a_negate = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
215
+ __m256i a_signed_i8x32 = _mm256_blendv_epi8(a_unsigned_u8x32,
216
+ _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate);
217
+
218
+ // Decode b: same LUT decode + sign
219
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
220
+ __m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
221
+ __m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
222
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
223
+ _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
224
+ __m256i b_negate = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
225
+ __m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32,
226
+ _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate);
227
+
228
+ // VPDPBSSD: signed × signed → i32
229
+ dot_i32x8 = _mm256_dpbssd_epi32(dot_i32x8, a_signed_i8x32, b_signed_i8x32);
230
+ a_norm_i32x8 = _mm256_dpbssd_epi32(a_norm_i32x8, a_signed_i8x32, a_signed_i8x32);
231
+ b_norm_i32x8 = _mm256_dpbssd_epi32(b_norm_i32x8, b_signed_i8x32, b_signed_i8x32);
232
+
233
+ if (count_scalars) goto nk_angular_e2m3_sierra_cycle;
234
+
235
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
236
+ nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
237
+ nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
238
+ *result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
239
+ }
240
+
241
+ NK_PUBLIC void nk_sqeuclidean_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
242
+ nk_size_t count_scalars, nk_f32_t *result) {
243
+ // Squared Euclidean distance for e2m3 using norm decomposition + VPDPBSSD.
244
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
245
+ //
246
+ __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
247
+ 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
248
+ __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
249
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
250
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
251
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
252
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
253
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
254
+ __m256i dot_i32x8 = _mm256_setzero_si256();
255
+ __m256i a_norm_i32x8 = _mm256_setzero_si256();
256
+ __m256i b_norm_i32x8 = _mm256_setzero_si256();
257
+ __m256i a_e2m3_u8x32, b_e2m3_u8x32;
258
+
259
+ nk_sqeuclidean_e2m3_sierra_cycle:
260
+ if (count_scalars < 32) {
261
+ nk_b256_vec_t a_vec, b_vec;
262
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
263
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
264
+ a_e2m3_u8x32 = a_vec.ymm;
265
+ b_e2m3_u8x32 = b_vec.ymm;
266
+ count_scalars = 0;
267
+ }
268
+ else {
269
+ a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
270
+ b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
271
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
272
+ }
273
+
274
+ // Decode a
275
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
276
+ __m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
277
+ __m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
278
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
279
+ _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
280
+ __m256i a_negate = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
281
+ __m256i a_signed_i8x32 = _mm256_blendv_epi8(a_unsigned_u8x32,
282
+ _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate);
283
+
284
+ // Decode b
285
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
286
+ __m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
287
+ __m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
288
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
289
+ _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
290
+ __m256i b_negate = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
291
+ __m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32,
292
+ _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate);
293
+
294
+ dot_i32x8 = _mm256_dpbssd_epi32(dot_i32x8, a_signed_i8x32, b_signed_i8x32);
295
+ a_norm_i32x8 = _mm256_dpbssd_epi32(a_norm_i32x8, a_signed_i8x32, a_signed_i8x32);
296
+ b_norm_i32x8 = _mm256_dpbssd_epi32(b_norm_i32x8, b_signed_i8x32, b_signed_i8x32);
297
+
298
+ if (count_scalars) goto nk_sqeuclidean_e2m3_sierra_cycle;
299
+
300
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
301
+ nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
302
+ nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
303
+ *result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
304
+ }
305
+
306
+ NK_PUBLIC void nk_euclidean_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
307
+ nk_sqeuclidean_e2m3_sierra(a, b, n, result);
308
+ *result = nk_f32_sqrt_haswell(*result);
309
+ }
310
+
311
+ #if defined(__clang__)
312
+ #pragma clang attribute pop
313
+ #elif defined(__GNUC__)
314
+ #pragma GCC pop_options
315
+ #endif
316
+
317
+ #if defined(__cplusplus)
318
+ } // extern "C"
319
+ #endif
320
+
321
+ #endif // NK_TARGET_SIERRA
322
+ #endif // NK_TARGET_X86_
323
+ #endif // NK_SPATIAL_SIERRA_H