numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,607 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for Alder Lake.
3
+ * @file include/numkong/spatial/alder.h
4
+ * @author Ash Vardanian
5
+ * @date March 4, 2026
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_alder_instructions AVX-VNNI Instructions Performance
10
+ *
11
+ * Intrinsic Instruction Alder Lake Raptor Lake
12
+ * _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) 4cy @ p05 4cy @ p05
13
+ * _mm256_sad_epu8 VPSADBW (YMM, YMM, YMM) 3cy @ p5 3cy @ p5
14
+ * _mm256_xor_si256 VPXOR (YMM, YMM, YMM) 1cy @ p015 1cy @ p015
15
+ * _mm256_add_epi64 VPADDQ (YMM, YMM, YMM) 1cy @ p015 1cy @ p015
16
+ * _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0 5cy @ p0
17
+ * _mm_sqrt_ss VSQRTSS (XMM, XMM, XMM) 12cy @ p0 12cy @ p0
18
+ *
19
+ * All spatial kernels use the dpbusd norm-decomposition approach:
20
+ * ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
21
+ * This avoids the p5 bottleneck from unpack operations, achieving ~2x throughput
22
+ * over Haswell's subs+unpack+madd approach (16 elem/cy vs 8 elem/cy).
23
+ */
24
+ #ifndef NK_SPATIAL_ALDER_H
25
+ #define NK_SPATIAL_ALDER_H
26
+
27
+ #if NK_TARGET_X86_
28
+ #if NK_TARGET_ALDER
29
+
30
+ #include "numkong/types.h"
31
+ #include "numkong/dot/alder.h" // VEX compat macros + dpbusd helpers
32
+ #include "numkong/scalar/haswell.h" // `nk_f32_sqrt_haswell`
33
+ #include "numkong/reduce/haswell.h"
34
+ #include "numkong/cast/serial.h" // `nk_partial_load_b8x32_serial_`
35
+
36
+ #if defined(__cplusplus)
37
+ extern "C" {
38
+ #endif
39
+
40
+ #if defined(__clang__)
41
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni"))), apply_to = function)
42
+ #elif defined(__GNUC__)
43
+ #pragma GCC push_options
44
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni")
45
+ #endif
46
+
47
+ NK_PUBLIC void nk_angular_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
48
+ // Angular distance using DPBUSD with algebraic transformation for signed x signed.
49
+ //
50
+ // For angular distance we need: dot(a,b), ||a||^2, ||b||^2
51
+ // Using dpbusd(u8, i8) for asymmetric unsigned x signed:
52
+ // a' = a XOR 0x80 (signed -> unsigned), then dpbusd(a', b) = (a+128)*b
53
+ // a*b = dpbusd(a',b) - 128*sum(b)
54
+ //
55
+ // For norms: dpbusd(a', a) = (a+128)*a, so a^2 = dpbusd(a',a) - 128*sum(a)
56
+ // Similarly for b: dpbusd(b', b) = (b+128)*b
57
+ //
58
+ __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
59
+ __m256i const zeros_u8x32 = _mm256_setzero_si256();
60
+ __m256i dot_product_i32x8 = _mm256_setzero_si256();
61
+ __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
62
+ __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
63
+ __m256i sum_a_biased_i64x4 = _mm256_setzero_si256();
64
+ __m256i sum_b_biased_i64x4 = _mm256_setzero_si256();
65
+
66
+ nk_size_t i = 0;
67
+ for (; i + 32 <= n; i += 32) {
68
+ __m256i a_i8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
69
+ __m256i b_i8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
70
+
71
+ // Convert to unsigned for dpbusd
72
+ __m256i a_unsigned_u8x32 = _mm256_xor_si256(a_i8x32, xor_mask_u8x32);
73
+ __m256i b_unsigned_u8x32 = _mm256_xor_si256(b_i8x32, xor_mask_u8x32);
74
+
75
+ // dpbusd: (a+128)*b, (a+128)*a, (b+128)*b
76
+ dot_product_i32x8 = _mm256_dpbusd_avx_epi32(dot_product_i32x8, a_unsigned_u8x32, b_i8x32);
77
+ a_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_sq_i32x8, a_unsigned_u8x32, a_i8x32);
78
+ b_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_sq_i32x8, b_unsigned_u8x32, b_i8x32);
79
+
80
+ // Accumulate biased sums for correction: sum(a+128), sum(b+128) via SAD
81
+ sum_a_biased_i64x4 = _mm256_add_epi64(sum_a_biased_i64x4, _mm256_sad_epu8(a_unsigned_u8x32, zeros_u8x32));
82
+ sum_b_biased_i64x4 = _mm256_add_epi64(sum_b_biased_i64x4, _mm256_sad_epu8(b_unsigned_u8x32, zeros_u8x32));
83
+ }
84
+
85
+ // Reduce and apply corrections inline:
86
+ // correction_x = 128 * sum_x_biased - 16384 * elements_processed
87
+ // value = reduce(accumulator) - correction
88
+ nk_i64_t sum_a_biased = nk_reduce_add_i64x4_haswell_(sum_a_biased_i64x4);
89
+ nk_i64_t sum_b_biased = nk_reduce_add_i64x4_haswell_(sum_b_biased_i64x4);
90
+ nk_i64_t correction_a = 128LL * sum_a_biased - 16384LL * (nk_i64_t)i;
91
+ nk_i64_t correction_b = 128LL * sum_b_biased - 16384LL * (nk_i64_t)i;
92
+
93
+ nk_i32_t dot_product_i32 = nk_reduce_add_i32x8_haswell_(dot_product_i32x8) - (nk_i32_t)correction_b;
94
+ nk_i32_t a_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8) - (nk_i32_t)correction_a;
95
+ nk_i32_t b_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8) - (nk_i32_t)correction_b;
96
+
97
+ // Scalar tail
98
+ for (; i < n; ++i) {
99
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
100
+ dot_product_i32 += a_element_i32 * b_element_i32;
101
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
102
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
103
+ }
104
+
105
+ *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
106
+ }
107
+
108
+ NK_PUBLIC void nk_sqeuclidean_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
109
+ // Squared Euclidean distance for i8 using DPBUSD with norm decomposition.
110
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
111
+ //
112
+ __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
113
+ __m256i const zeros_u8x32 = _mm256_setzero_si256();
114
+ __m256i dot_product_i32x8 = _mm256_setzero_si256();
115
+ __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
116
+ __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
117
+ __m256i sum_a_biased_i64x4 = _mm256_setzero_si256();
118
+ __m256i sum_b_biased_i64x4 = _mm256_setzero_si256();
119
+
120
+ nk_size_t i = 0;
121
+ for (; i + 32 <= n; i += 32) {
122
+ __m256i a_i8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
123
+ __m256i b_i8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
124
+ __m256i a_unsigned_u8x32 = _mm256_xor_si256(a_i8x32, xor_mask_u8x32);
125
+ __m256i b_unsigned_u8x32 = _mm256_xor_si256(b_i8x32, xor_mask_u8x32);
126
+
127
+ dot_product_i32x8 = _mm256_dpbusd_avx_epi32(dot_product_i32x8, a_unsigned_u8x32, b_i8x32);
128
+ a_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_sq_i32x8, a_unsigned_u8x32, a_i8x32);
129
+ b_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_sq_i32x8, b_unsigned_u8x32, b_i8x32);
130
+
131
+ sum_a_biased_i64x4 = _mm256_add_epi64(sum_a_biased_i64x4, _mm256_sad_epu8(a_unsigned_u8x32, zeros_u8x32));
132
+ sum_b_biased_i64x4 = _mm256_add_epi64(sum_b_biased_i64x4, _mm256_sad_epu8(b_unsigned_u8x32, zeros_u8x32));
133
+ }
134
+
135
+ nk_i64_t sum_a_biased = nk_reduce_add_i64x4_haswell_(sum_a_biased_i64x4);
136
+ nk_i64_t sum_b_biased = nk_reduce_add_i64x4_haswell_(sum_b_biased_i64x4);
137
+ nk_i64_t correction_a = 128LL * sum_a_biased - 16384LL * (nk_i64_t)i;
138
+ nk_i64_t correction_b = 128LL * sum_b_biased - 16384LL * (nk_i64_t)i;
139
+
140
+ nk_i32_t dot_product_i32 = nk_reduce_add_i32x8_haswell_(dot_product_i32x8) - (nk_i32_t)correction_b;
141
+ nk_i32_t a_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8) - (nk_i32_t)correction_a;
142
+ nk_i32_t b_norm_sq_i32 = nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8) - (nk_i32_t)correction_b;
143
+
144
+ for (; i < n; ++i) {
145
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
146
+ dot_product_i32 += a_element_i32 * b_element_i32;
147
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
148
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
149
+ }
150
+
151
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
152
+ *result = (nk_u32_t)(a_norm_sq_i32 + b_norm_sq_i32 - 2 * dot_product_i32);
153
+ }
154
+
155
+ NK_PUBLIC void nk_euclidean_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
156
+ nk_u32_t distance_sq_u32;
157
+ nk_sqeuclidean_i8_alder(a, b, n, &distance_sq_u32);
158
+ *result = nk_f32_sqrt_haswell((nk_f32_t)distance_sq_u32);
159
+ }
160
+
161
+ NK_PUBLIC void nk_sqeuclidean_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
162
+ // Squared Euclidean distance for u8 using DPBUSD with norm decomposition.
163
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
164
+ //
165
+ // For u8 x u8: dpbusd(a, b'^0x80) = a*(b-128), so dot(a,b) = dpbusd(a,b') + 128*sum(a)
166
+ // For norms: dpbusd(a, a'^0x80) = a*(a-128), so ||a||^2 = dpbusd(a,a') + 128*sum(a)
167
+ //
168
+ __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
169
+ __m256i const zeros_u8x32 = _mm256_setzero_si256();
170
+ __m256i dot_product_i32x8 = _mm256_setzero_si256();
171
+ __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
172
+ __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
173
+ __m256i sum_a_u64x4 = _mm256_setzero_si256();
174
+ __m256i sum_b_u64x4 = _mm256_setzero_si256();
175
+
176
+ nk_size_t i = 0;
177
+ for (; i + 32 <= n; i += 32) {
178
+ __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
179
+ __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
180
+ __m256i a_signed_i8x32 = _mm256_xor_si256(a_u8x32, xor_mask_u8x32);
181
+ __m256i b_signed_i8x32 = _mm256_xor_si256(b_u8x32, xor_mask_u8x32);
182
+
183
+ // dpbusd(a, b-128) = a*(b-128), dpbusd(a, a-128) = a*(a-128), etc.
184
+ dot_product_i32x8 = _mm256_dpbusd_avx_epi32(dot_product_i32x8, a_u8x32, b_signed_i8x32);
185
+ a_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_sq_i32x8, a_u8x32, a_signed_i8x32);
186
+ b_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_sq_i32x8, b_u8x32, b_signed_i8x32);
187
+
188
+ sum_a_u64x4 = _mm256_add_epi64(sum_a_u64x4, _mm256_sad_epu8(a_u8x32, zeros_u8x32));
189
+ sum_b_u64x4 = _mm256_add_epi64(sum_b_u64x4, _mm256_sad_epu8(b_u8x32, zeros_u8x32));
190
+ }
191
+
192
+ // Corrections: x*(y-128) + 128*sum(x) = x*y
193
+ nk_i64_t sum_a_i64 = nk_reduce_add_i64x4_haswell_(sum_a_u64x4);
194
+ nk_i64_t sum_b_i64 = nk_reduce_add_i64x4_haswell_(sum_b_u64x4);
195
+ nk_i32_t dot_product_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(dot_product_i32x8) +
196
+ 128LL * sum_a_i64);
197
+ nk_i32_t a_norm_sq_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8) + 128LL * sum_a_i64);
198
+ nk_i32_t b_norm_sq_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8) + 128LL * sum_b_i64);
199
+
200
+ for (; i < n; ++i) {
201
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
202
+ dot_product_i32 += a_element_i32 * b_element_i32;
203
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
204
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
205
+ }
206
+
207
+ *result = (nk_u32_t)(a_norm_sq_i32 + b_norm_sq_i32 - 2 * dot_product_i32);
208
+ }
209
+
210
+ NK_PUBLIC void nk_euclidean_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
211
+ nk_u32_t distance_sq_u32;
212
+ nk_sqeuclidean_u8_alder(a, b, n, &distance_sq_u32);
213
+ *result = nk_f32_sqrt_haswell((nk_f32_t)distance_sq_u32);
214
+ }
215
+
216
+ NK_PUBLIC void nk_angular_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
217
+ // Angular distance for u8 using DPBUSD with algebraic transformation.
218
+ // dpbusd(a, b'^0x80) = a*(b-128), so dot(a,b) = dpbusd(a,b') + 128*sum(a)
219
+ //
220
+ __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
221
+ __m256i const zeros_u8x32 = _mm256_setzero_si256();
222
+ __m256i dot_product_i32x8 = _mm256_setzero_si256();
223
+ __m256i a_norm_sq_i32x8 = _mm256_setzero_si256();
224
+ __m256i b_norm_sq_i32x8 = _mm256_setzero_si256();
225
+ __m256i sum_a_u64x4 = _mm256_setzero_si256();
226
+ __m256i sum_b_u64x4 = _mm256_setzero_si256();
227
+
228
+ nk_size_t i = 0;
229
+ for (; i + 32 <= n; i += 32) {
230
+ __m256i a_u8x32 = _mm256_loadu_si256((__m256i const *)(a + i));
231
+ __m256i b_u8x32 = _mm256_loadu_si256((__m256i const *)(b + i));
232
+ __m256i a_signed_i8x32 = _mm256_xor_si256(a_u8x32, xor_mask_u8x32);
233
+ __m256i b_signed_i8x32 = _mm256_xor_si256(b_u8x32, xor_mask_u8x32);
234
+
235
+ dot_product_i32x8 = _mm256_dpbusd_avx_epi32(dot_product_i32x8, a_u8x32, b_signed_i8x32);
236
+ a_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_sq_i32x8, a_u8x32, a_signed_i8x32);
237
+ b_norm_sq_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_sq_i32x8, b_u8x32, b_signed_i8x32);
238
+
239
+ sum_a_u64x4 = _mm256_add_epi64(sum_a_u64x4, _mm256_sad_epu8(a_u8x32, zeros_u8x32));
240
+ sum_b_u64x4 = _mm256_add_epi64(sum_b_u64x4, _mm256_sad_epu8(b_u8x32, zeros_u8x32));
241
+ }
242
+
243
+ nk_i64_t sum_a_i64 = nk_reduce_add_i64x4_haswell_(sum_a_u64x4);
244
+ nk_i64_t sum_b_i64 = nk_reduce_add_i64x4_haswell_(sum_b_u64x4);
245
+ nk_i32_t dot_product_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(dot_product_i32x8) +
246
+ 128LL * sum_a_i64);
247
+ nk_i32_t a_norm_sq_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(a_norm_sq_i32x8) + 128LL * sum_a_i64);
248
+ nk_i32_t b_norm_sq_i32 = (nk_i32_t)((nk_i64_t)nk_reduce_add_i32x8_haswell_(b_norm_sq_i32x8) + 128LL * sum_b_i64);
249
+
250
+ for (; i < n; ++i) {
251
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
252
+ dot_product_i32 += a_element_i32 * b_element_i32;
253
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
254
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
255
+ }
256
+
257
+ *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
258
+ }
259
+
260
+ NK_PUBLIC void nk_angular_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
261
+ nk_f32_t *result) {
262
+ // Angular distance for e2m3 using dual-VPSHUFB LUT + VPDPBUSD norm decomposition.
263
+ // Every e2m3 value × 16 is an exact integer in [-120, +120].
264
+ // We compute dot(a,b), ||a||^2, ||b||^2 in scaled integer domain,
265
+ // then normalize: angular = 1 - dot / sqrt(||a||^2 * ||b||^2).
266
+ // Final division by 256.0f for dot and norms cancels in the ratio.
267
+ //
268
+ __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
269
+ 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
270
+ __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
271
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
272
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
273
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
274
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
275
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
276
+ __m256i dot_i32x8 = _mm256_setzero_si256();
277
+ __m256i a_norm_i32x8 = _mm256_setzero_si256();
278
+ __m256i b_norm_i32x8 = _mm256_setzero_si256();
279
+ __m256i a_e2m3_u8x32, b_e2m3_u8x32;
280
+
281
+ nk_angular_e2m3_alder_cycle:
282
+ if (count_scalars < 32) {
283
+ nk_b256_vec_t a_vec, b_vec;
284
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
285
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
286
+ a_e2m3_u8x32 = a_vec.ymm;
287
+ b_e2m3_u8x32 = b_vec.ymm;
288
+ count_scalars = 0;
289
+ }
290
+ else {
291
+ a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
292
+ b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
293
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
294
+ }
295
+
296
+ // Decode a: extract magnitude, dual-VPSHUFB LUT
297
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
298
+ __m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
299
+ __m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
300
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
301
+ _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
302
+
303
+ // Decode b: same LUT decode
304
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
305
+ __m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
306
+ __m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
307
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
308
+ _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
309
+
310
+ // Dot product with sign: combined sign from (a XOR b) & 0x20
311
+ __m256i sign_combined = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
312
+ __m256i negate_mask = _mm256_cmpeq_epi8(sign_combined, sign_mask_u8x32);
313
+ __m256i b_negated = _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32);
314
+ __m256i b_dot_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32, b_negated, negate_mask);
315
+
316
+ // DPBUSD: a_unsigned[u8] × b_signed[i8] → i32 for dot product
317
+ dot_i32x8 = _mm256_dpbusd_avx_epi32(dot_i32x8, a_unsigned_u8x32, b_dot_i8x32);
318
+ // Norms: magnitude² is always positive, DPBUSD(unsigned, unsigned-as-signed) works since max=120 < 127
319
+ a_norm_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_i32x8, a_unsigned_u8x32, a_unsigned_u8x32);
320
+ b_norm_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_i32x8, b_unsigned_u8x32, b_unsigned_u8x32);
321
+
322
+ if (count_scalars) goto nk_angular_e2m3_alder_cycle;
323
+
324
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
325
+ nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
326
+ nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
327
+ // The 256.0f factor cancels in the angular normalization ratio
328
+ *result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
329
+ }
330
+
331
+ NK_PUBLIC void nk_sqeuclidean_e2m3_alder(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
332
+ nk_size_t count_scalars, nk_f32_t *result) {
333
+ // Squared Euclidean distance for e2m3 using norm decomposition:
334
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
335
+ // Each value × 16 is exact integer, so result = integer_result / 256.0f
336
+ //
337
+ __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
338
+ 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
339
+ __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
340
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
341
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
342
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
343
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
344
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
345
+ __m256i dot_i32x8 = _mm256_setzero_si256();
346
+ __m256i a_norm_i32x8 = _mm256_setzero_si256();
347
+ __m256i b_norm_i32x8 = _mm256_setzero_si256();
348
+ __m256i a_e2m3_u8x32, b_e2m3_u8x32;
349
+
350
+ nk_sqeuclidean_e2m3_alder_cycle:
351
+ if (count_scalars < 32) {
352
+ nk_b256_vec_t a_vec, b_vec;
353
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
354
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
355
+ a_e2m3_u8x32 = a_vec.ymm;
356
+ b_e2m3_u8x32 = b_vec.ymm;
357
+ count_scalars = 0;
358
+ }
359
+ else {
360
+ a_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
361
+ b_e2m3_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
362
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
363
+ }
364
+
365
+ // Decode a and b magnitudes via LUT
366
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
367
+ __m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
368
+ __m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
369
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
370
+ _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
371
+
372
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
373
+ __m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
374
+ __m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
375
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
376
+ _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
377
+
378
+ // Signed dot product: combined sign from (a XOR b) & 0x20
379
+ __m256i sign_combined = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
380
+ __m256i negate_mask = _mm256_cmpeq_epi8(sign_combined, sign_mask_u8x32);
381
+ __m256i b_negated = _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32);
382
+ __m256i b_dot_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32, b_negated, negate_mask);
383
+
384
+ dot_i32x8 = _mm256_dpbusd_avx_epi32(dot_i32x8, a_unsigned_u8x32, b_dot_i8x32);
385
+ a_norm_i32x8 = _mm256_dpbusd_avx_epi32(a_norm_i32x8, a_unsigned_u8x32, a_unsigned_u8x32);
386
+ b_norm_i32x8 = _mm256_dpbusd_avx_epi32(b_norm_i32x8, b_unsigned_u8x32, b_unsigned_u8x32);
387
+
388
+ if (count_scalars) goto nk_sqeuclidean_e2m3_alder_cycle;
389
+
390
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
391
+ nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
392
+ nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
393
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b), scaled by 256
394
+ *result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
395
+ }
396
+
397
+ NK_PUBLIC void nk_euclidean_e2m3_alder(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
398
+ nk_sqeuclidean_e2m3_alder(a, b, n, result);
399
+ *result = nk_f32_sqrt_haswell(*result);
400
+ }
401
+
402
+ NK_PUBLIC void nk_angular_e3m2_alder(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
403
+ nk_f32_t *result) {
404
+ // Angular distance for e3m2 using dual-VPSHUFB LUT decode to i16 + VPDPWSSD norm decomposition.
405
+ // Every e3m2 value × 16 is an exact integer (max magnitude 448), requiring i16.
406
+ // VPDPWSSD replaces Haswell's VPMADDWD + VPADDD, saving one instruction per accumulation.
407
+ //
408
+ __m256i const lut_lo_lower_u8x32 = _mm256_set_epi8( //
409
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
410
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
411
+ __m256i const lut_lo_upper_u8x32 = _mm256_set_epi8( //
412
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
413
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
414
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
415
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
416
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
417
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
418
+ __m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
419
+ __m256i const ones_u8x32 = _mm256_set1_epi8(1);
420
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
421
+ __m256i dot_i32x8 = _mm256_setzero_si256();
422
+ __m256i a_norm_i32x8 = _mm256_setzero_si256();
423
+ __m256i b_norm_i32x8 = _mm256_setzero_si256();
424
+ __m256i a_e3m2_u8x32, b_e3m2_u8x32;
425
+
426
+ nk_angular_e3m2_alder_cycle:
427
+ if (count_scalars < 32) {
428
+ nk_b256_vec_t a_vec, b_vec;
429
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
430
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
431
+ a_e3m2_u8x32 = a_vec.ymm;
432
+ b_e3m2_u8x32 = b_vec.ymm;
433
+ count_scalars = 0;
434
+ }
435
+ else {
436
+ a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
437
+ b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
438
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
439
+ }
440
+
441
+ // Extract 5-bit magnitude, split into low 4 bits and bit 4
442
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
443
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
444
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
445
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
446
+ __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
447
+ half_select_u8x32);
448
+ __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
449
+ half_select_u8x32);
450
+
451
+ // Dual VPSHUFB: lookup low bytes in both halves, blend based on bit 4
452
+ __m256i a_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, a_shuffle_index_u8x32),
453
+ _mm256_shuffle_epi8(lut_lo_upper_u8x32, a_shuffle_index_u8x32),
454
+ a_upper_select_u8x32);
455
+ __m256i b_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, b_shuffle_index_u8x32),
456
+ _mm256_shuffle_epi8(lut_lo_upper_u8x32, b_shuffle_index_u8x32),
457
+ b_upper_select_u8x32);
458
+
459
+ // High byte: 1 iff magnitude >= 28 (signed compare safe: 27 < 128)
460
+ __m256i a_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
461
+ __m256i b_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
462
+
463
+ // Interleave low and high bytes into i16 (little-endian: low byte first)
464
+ __m256i a_lo_i16x16 = _mm256_unpacklo_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
465
+ __m256i a_hi_i16x16 = _mm256_unpackhi_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
466
+ __m256i b_lo_i16x16 = _mm256_unpacklo_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
467
+ __m256i b_hi_i16x16 = _mm256_unpackhi_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
468
+
469
+ // Combined sign: (a ^ b) & 0x20, widen to i16 via unpack, create +1/-1 sign vector
470
+ __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
471
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
472
+ __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
473
+ __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
474
+ __m256i sign_lo_i16x16 = _mm256_or_si256(negate_lo_i16x16, ones_i16x16);
475
+ __m256i sign_hi_i16x16 = _mm256_or_si256(negate_hi_i16x16, ones_i16x16);
476
+ __m256i b_signed_lo_i16x16 = _mm256_sign_epi16(b_lo_i16x16, sign_lo_i16x16);
477
+ __m256i b_signed_hi_i16x16 = _mm256_sign_epi16(b_hi_i16x16, sign_hi_i16x16);
478
+
479
+ // VPDPWSSD: i16×i16→i32 fused dot-product-accumulate (replaces VPMADDWD + VPADDD)
480
+ dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_lo_i16x16, b_signed_lo_i16x16);
481
+ dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_hi_i16x16, b_signed_hi_i16x16);
482
+ a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_lo_i16x16, a_lo_i16x16);
483
+ a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_hi_i16x16, a_hi_i16x16);
484
+ b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_lo_i16x16, b_lo_i16x16);
485
+ b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_hi_i16x16, b_hi_i16x16);
486
+
487
+ if (count_scalars) goto nk_angular_e3m2_alder_cycle;
488
+
489
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
490
+ nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
491
+ nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
492
+ // The 256.0f factor cancels in the angular normalization ratio
493
+ *result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
494
+ }
495
+
496
+ NK_PUBLIC void nk_sqeuclidean_e3m2_alder(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
497
+ nk_size_t count_scalars, nk_f32_t *result) {
498
+ // Squared Euclidean distance for e3m2 using norm decomposition + VPDPWSSD:
499
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
500
+ // Each value × 16 is exact integer, so result = integer_result / 256.0f
501
+ //
502
+ __m256i const lut_lo_lower_u8x32 = _mm256_set_epi8( //
503
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
504
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
505
+ __m256i const lut_lo_upper_u8x32 = _mm256_set_epi8( //
506
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
507
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
508
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
509
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
510
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
511
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
512
+ __m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
513
+ __m256i const ones_u8x32 = _mm256_set1_epi8(1);
514
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
515
+ __m256i dot_i32x8 = _mm256_setzero_si256();
516
+ __m256i a_norm_i32x8 = _mm256_setzero_si256();
517
+ __m256i b_norm_i32x8 = _mm256_setzero_si256();
518
+ __m256i a_e3m2_u8x32, b_e3m2_u8x32;
519
+
520
+ nk_sqeuclidean_e3m2_alder_cycle:
521
+ if (count_scalars < 32) {
522
+ nk_b256_vec_t a_vec, b_vec;
523
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
524
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
525
+ a_e3m2_u8x32 = a_vec.ymm;
526
+ b_e3m2_u8x32 = b_vec.ymm;
527
+ count_scalars = 0;
528
+ }
529
+ else {
530
+ a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
531
+ b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
532
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
533
+ }
534
+
535
+ // Extract 5-bit magnitude, split into low 4 bits and bit 4
536
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
537
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
538
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
539
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
540
+ __m256i a_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
541
+ half_select_u8x32);
542
+ __m256i b_upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
543
+ half_select_u8x32);
544
+
545
+ // Dual VPSHUFB: lookup low bytes in both halves, blend based on bit 4
546
+ __m256i a_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, a_shuffle_index_u8x32),
547
+ _mm256_shuffle_epi8(lut_lo_upper_u8x32, a_shuffle_index_u8x32),
548
+ a_upper_select_u8x32);
549
+ __m256i b_lo_bytes_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lo_lower_u8x32, b_shuffle_index_u8x32),
550
+ _mm256_shuffle_epi8(lut_lo_upper_u8x32, b_shuffle_index_u8x32),
551
+ b_upper_select_u8x32);
552
+
553
+ // High byte: 1 iff magnitude >= 28
554
+ __m256i a_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
555
+ __m256i b_hi_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32), ones_u8x32);
556
+
557
+ // Interleave low and high bytes into i16
558
+ __m256i a_lo_i16x16 = _mm256_unpacklo_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
559
+ __m256i a_hi_i16x16 = _mm256_unpackhi_epi8(a_lo_bytes_u8x32, a_hi_bytes_u8x32);
560
+ __m256i b_lo_i16x16 = _mm256_unpacklo_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
561
+ __m256i b_hi_i16x16 = _mm256_unpackhi_epi8(b_lo_bytes_u8x32, b_hi_bytes_u8x32);
562
+
563
+ // Combined sign for dot product
564
+ __m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
565
+ __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
566
+ __m256i negate_lo_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
567
+ __m256i negate_hi_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
568
+ __m256i sign_lo_i16x16 = _mm256_or_si256(negate_lo_i16x16, ones_i16x16);
569
+ __m256i sign_hi_i16x16 = _mm256_or_si256(negate_hi_i16x16, ones_i16x16);
570
+ __m256i b_signed_lo_i16x16 = _mm256_sign_epi16(b_lo_i16x16, sign_lo_i16x16);
571
+ __m256i b_signed_hi_i16x16 = _mm256_sign_epi16(b_hi_i16x16, sign_hi_i16x16);
572
+
573
+ // VPDPWSSD: i16×i16→i32 fused dot-product-accumulate
574
+ dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_lo_i16x16, b_signed_lo_i16x16);
575
+ dot_i32x8 = _mm256_dpwssd_avx_epi32(dot_i32x8, a_hi_i16x16, b_signed_hi_i16x16);
576
+ a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_lo_i16x16, a_lo_i16x16);
577
+ a_norm_i32x8 = _mm256_dpwssd_avx_epi32(a_norm_i32x8, a_hi_i16x16, a_hi_i16x16);
578
+ b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_lo_i16x16, b_lo_i16x16);
579
+ b_norm_i32x8 = _mm256_dpwssd_avx_epi32(b_norm_i32x8, b_hi_i16x16, b_hi_i16x16);
580
+
581
+ if (count_scalars) goto nk_sqeuclidean_e3m2_alder_cycle;
582
+
583
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
584
+ nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
585
+ nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
586
+ // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b), scaled by 256
587
+ *result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
588
+ }
589
+
590
+ NK_PUBLIC void nk_euclidean_e3m2_alder(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
591
+ nk_sqeuclidean_e3m2_alder(a, b, n, result);
592
+ *result = nk_f32_sqrt_haswell(*result);
593
+ }
594
+
595
+ #if defined(__clang__)
596
+ #pragma clang attribute pop
597
+ #elif defined(__GNUC__)
598
+ #pragma GCC pop_options
599
+ #endif
600
+
601
+ #if defined(__cplusplus)
602
+ } // extern "C"
603
+ #endif
604
+
605
+ #endif // NK_TARGET_ALDER
606
+ #endif // NK_TARGET_X86_
607
+ #endif // NK_SPATIAL_ALDER_H