numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,511 @@
1
+ /**
2
+ * @brief SIMD-accelerated MaxSim (angular distance late-interaction) for Alder Lake (AVX-VNNI).
3
+ * @file include/numkong/maxsim/alder.h
4
+ * @author Ash Vardanian
5
+ * @date March 5, 2026
6
+ *
7
+ * @sa include/numkong/maxsim.h
8
+ *
9
+ * Uses AVX-VNNI VPDPBUSD (u8×i8→i32 with accumulate) for coarse i8 screening.
10
+ * Unlike Haswell's VPMADDUBSW+VPMADDWD, DPBUSD has no i16 intermediate, so no saturation concern.
11
+ * Quantization range [-127, 127] (vs Haswell's [-79, 79]) for better precision.
12
+ * Bias correction via XOR-0x80 converts signed queries to unsigned, then subtracts 128 × sum_quantized.
13
+ *
14
+ * 4x4 register tiling: 4 queries × 4 documents = 16 YMM accumulators per depth loop.
15
+ * Depth steps at 32 bytes (YMM width in bytes).
16
+ */
17
+ #ifndef NK_MAXSIM_ALDER_H
18
+ #define NK_MAXSIM_ALDER_H
19
+
20
+ #if NK_TARGET_X86_
21
+ #if NK_TARGET_ALDER
22
+
23
+ #include "numkong/types.h"
24
+ #include "numkong/maxsim/serial.h" // `nk_maxsim_packed_header_t`
25
+ #include "numkong/maxsim/haswell.h" // `nk_maxsim_reduce_i32x8x4_haswell_`
26
+ #include "numkong/dot.h" // `nk_dot_bf16`, `nk_dot_f32`, `nk_dot_f16`
27
+ #include "numkong/cast/haswell.h" // `nk_f16_to_f32_haswell`
28
+ #include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`
29
+
30
+ // On GCC/Clang, VEX encoding is handled by target attributes.
31
+ // Alias the MSVC-specific _avx intrinsic names to standard names.
32
+ #if !defined(_MSC_VER) && !defined(_mm256_dpbusd_avx_epi32)
33
+ #define _mm256_dpbusd_avx_epi32 _mm256_dpbusd_epi32
34
+ #endif
35
+
36
+ #if defined(__cplusplus)
37
+ extern "C" {
38
+ #endif
39
+
40
+ #if defined(__clang__)
41
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2,avxvnni"))), apply_to = function)
42
+ #elif defined(__GNUC__)
43
+ #pragma GCC push_options
44
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2", "avxvnni")
45
+ #endif
46
+
47
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_alder(nk_size_t vector_count, nk_size_t depth) {
48
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_bf16_t), 32);
49
+ }
50
+
51
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_alder(nk_size_t vector_count, nk_size_t depth) {
52
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f32_t), 32);
53
+ }
54
+
55
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_alder(nk_size_t vector_count, nk_size_t depth) {
56
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f16_t), 32);
57
+ }
58
+
59
+ NK_PUBLIC void nk_maxsim_pack_bf16_alder( //
60
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
61
+
62
+ nk_size_t const element_bytes = sizeof(nk_bf16_t);
63
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
64
+
65
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
66
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
67
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
68
+ char *originals = (char *)packed + header->offset_original_data;
69
+ nk_size_t const original_stride = header->original_stride_bytes;
70
+
71
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
72
+ char const *source_row = (char const *)vectors + vector_index * stride;
73
+ nk_f32_t norm_sq;
74
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
75
+ (nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
76
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
77
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
78
+ char *destination_original = originals + vector_index * original_stride;
79
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
80
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
81
+ destination_original[byte_index] = 0;
82
+ }
83
+ }
84
+
85
+ NK_PUBLIC void nk_maxsim_pack_f32_alder( //
86
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
87
+
88
+ nk_size_t const element_bytes = sizeof(nk_f32_t);
89
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
90
+
91
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
92
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
93
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
94
+ char *originals = (char *)packed + header->offset_original_data;
95
+ nk_size_t const original_stride = header->original_stride_bytes;
96
+
97
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
98
+ char const *source_row = (char const *)vectors + vector_index * stride;
99
+ nk_f32_t norm_sq;
100
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
101
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
102
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
103
+ char *destination_original = originals + vector_index * original_stride;
104
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
105
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
106
+ destination_original[byte_index] = 0;
107
+ }
108
+ }
109
+
110
+ NK_PUBLIC void nk_maxsim_pack_f16_alder( //
111
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
112
+
113
+ nk_size_t const element_bytes = sizeof(nk_f16_t);
114
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
115
+
116
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
117
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
118
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
119
+ char *originals = (char *)packed + header->offset_original_data;
120
+ nk_size_t const original_stride = header->original_stride_bytes;
121
+
122
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
123
+ char const *source_row = (char const *)vectors + vector_index * stride;
124
+ nk_f32_t norm_sq;
125
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
126
+ (nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
127
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
128
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
129
+ char *destination_original = originals + vector_index * original_stride;
130
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
131
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
132
+ destination_original[byte_index] = 0;
133
+ }
134
+ }
135
+
136
+ /**
137
+ * @brief Factored coarse i8 argmax kernel for Alder Lake using DPBUSD.
138
+ * Uses single VPDPBUSD instruction per query×doc pair (no i16 intermediate).
139
+ * 4Q×4D register tiling with 16 YMM accumulators.
140
+ */
141
+ NK_INTERNAL void nk_maxsim_coarse_argmax_alder_( //
142
+ nk_i8_t const *query_i8, nk_i8_t const *document_i8, //
143
+ nk_maxsim_vector_metadata_t const *document_metadata, //
144
+ nk_size_t query_count, nk_size_t document_count, //
145
+ nk_size_t depth_i8_padded, nk_u32_t *best_document_indices) {
146
+
147
+ __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
148
+
149
+ // Primary path: 4-query grouping
150
+ nk_size_t query_block_start_index = 0;
151
+ for (; query_block_start_index + 4 <= query_count; query_block_start_index += 4) {
152
+ __m128i running_max_i32x4 = _mm_set1_epi32(NK_I32_MIN);
153
+ __m128i running_argmax_i32x4 = _mm_setzero_si128();
154
+
155
+ // 4Q×4D document blocking
156
+ nk_size_t document_block_start_index = 0;
157
+ for (; document_block_start_index + 4 <= document_count; document_block_start_index += 4) {
158
+ __m256i accumulator_tiles_i32x8[4][4];
159
+ for (nk_size_t query_tile_index = 0; query_tile_index < 4; query_tile_index++)
160
+ for (nk_size_t document_tile_index = 0; document_tile_index < 4; document_tile_index++)
161
+ accumulator_tiles_i32x8[query_tile_index][document_tile_index] = _mm256_setzero_si256();
162
+
163
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 32) {
164
+ __m256i query_biased_u8x32_0 = _mm256_xor_si256(
165
+ _mm256_loadu_si256(
166
+ (__m256i const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
167
+ xor_mask_u8x32);
168
+ __m256i query_biased_u8x32_1 = _mm256_xor_si256(
169
+ _mm256_loadu_si256(
170
+ (__m256i const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
171
+ xor_mask_u8x32);
172
+ __m256i query_biased_u8x32_2 = _mm256_xor_si256(
173
+ _mm256_loadu_si256(
174
+ (__m256i const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
175
+ xor_mask_u8x32);
176
+ __m256i query_biased_u8x32_3 = _mm256_xor_si256(
177
+ _mm256_loadu_si256(
178
+ (__m256i const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
179
+ xor_mask_u8x32);
180
+
181
+ __m256i document_i8x32;
182
+
183
+ // Document 0
184
+ document_i8x32 = _mm256_loadu_si256(
185
+ (__m256i const *)(document_i8 + (document_block_start_index + 0) * depth_i8_padded + depth_index));
186
+ accumulator_tiles_i32x8[0][0] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[0][0],
187
+ query_biased_u8x32_0, document_i8x32);
188
+ accumulator_tiles_i32x8[1][0] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[1][0],
189
+ query_biased_u8x32_1, document_i8x32);
190
+ accumulator_tiles_i32x8[2][0] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[2][0],
191
+ query_biased_u8x32_2, document_i8x32);
192
+ accumulator_tiles_i32x8[3][0] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[3][0],
193
+ query_biased_u8x32_3, document_i8x32);
194
+
195
+ // Document 1
196
+ document_i8x32 = _mm256_loadu_si256(
197
+ (__m256i const *)(document_i8 + (document_block_start_index + 1) * depth_i8_padded + depth_index));
198
+ accumulator_tiles_i32x8[0][1] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[0][1],
199
+ query_biased_u8x32_0, document_i8x32);
200
+ accumulator_tiles_i32x8[1][1] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[1][1],
201
+ query_biased_u8x32_1, document_i8x32);
202
+ accumulator_tiles_i32x8[2][1] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[2][1],
203
+ query_biased_u8x32_2, document_i8x32);
204
+ accumulator_tiles_i32x8[3][1] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[3][1],
205
+ query_biased_u8x32_3, document_i8x32);
206
+
207
+ // Document 2
208
+ document_i8x32 = _mm256_loadu_si256(
209
+ (__m256i const *)(document_i8 + (document_block_start_index + 2) * depth_i8_padded + depth_index));
210
+ accumulator_tiles_i32x8[0][2] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[0][2],
211
+ query_biased_u8x32_0, document_i8x32);
212
+ accumulator_tiles_i32x8[1][2] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[1][2],
213
+ query_biased_u8x32_1, document_i8x32);
214
+ accumulator_tiles_i32x8[2][2] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[2][2],
215
+ query_biased_u8x32_2, document_i8x32);
216
+ accumulator_tiles_i32x8[3][2] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[3][2],
217
+ query_biased_u8x32_3, document_i8x32);
218
+
219
+ // Document 3
220
+ document_i8x32 = _mm256_loadu_si256(
221
+ (__m256i const *)(document_i8 + (document_block_start_index + 3) * depth_i8_padded + depth_index));
222
+ accumulator_tiles_i32x8[0][3] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[0][3],
223
+ query_biased_u8x32_0, document_i8x32);
224
+ accumulator_tiles_i32x8[1][3] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[1][3],
225
+ query_biased_u8x32_1, document_i8x32);
226
+ accumulator_tiles_i32x8[2][3] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[2][3],
227
+ query_biased_u8x32_2, document_i8x32);
228
+ accumulator_tiles_i32x8[3][3] = _mm256_dpbusd_avx_epi32(accumulator_tiles_i32x8[3][3],
229
+ query_biased_u8x32_3, document_i8x32);
230
+ }
231
+
232
+ // Reduce each query's 4 doc accumulators -> __m128i
233
+ __m128i query_0_coarse_dots_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(
234
+ accumulator_tiles_i32x8[0][0], accumulator_tiles_i32x8[0][1], accumulator_tiles_i32x8[0][2],
235
+ accumulator_tiles_i32x8[0][3]);
236
+ __m128i query_1_coarse_dots_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(
237
+ accumulator_tiles_i32x8[1][0], accumulator_tiles_i32x8[1][1], accumulator_tiles_i32x8[1][2],
238
+ accumulator_tiles_i32x8[1][3]);
239
+ __m128i query_2_coarse_dots_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(
240
+ accumulator_tiles_i32x8[2][0], accumulator_tiles_i32x8[2][1], accumulator_tiles_i32x8[2][2],
241
+ accumulator_tiles_i32x8[2][3]);
242
+ __m128i query_3_coarse_dots_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(
243
+ accumulator_tiles_i32x8[3][0], accumulator_tiles_i32x8[3][1], accumulator_tiles_i32x8[3][2],
244
+ accumulator_tiles_i32x8[3][3]);
245
+
246
+ // Bias correction: subtract 128 × sum_quantized for each document
247
+ __m128i bias_correction_i32x4 = _mm_set_epi32(
248
+ 128 * document_metadata[document_block_start_index + 3].sum_i8_i32,
249
+ 128 * document_metadata[document_block_start_index + 2].sum_i8_i32,
250
+ 128 * document_metadata[document_block_start_index + 1].sum_i8_i32,
251
+ 128 * document_metadata[document_block_start_index + 0].sum_i8_i32);
252
+ query_0_coarse_dots_i32x4 = _mm_sub_epi32(query_0_coarse_dots_i32x4, bias_correction_i32x4);
253
+ query_1_coarse_dots_i32x4 = _mm_sub_epi32(query_1_coarse_dots_i32x4, bias_correction_i32x4);
254
+ query_2_coarse_dots_i32x4 = _mm_sub_epi32(query_2_coarse_dots_i32x4, bias_correction_i32x4);
255
+ query_3_coarse_dots_i32x4 = _mm_sub_epi32(query_3_coarse_dots_i32x4, bias_correction_i32x4);
256
+
257
+ // 4x4 transpose: [query][doc] -> [doc][query] for vectorized argmax
258
+ __m128i transpose_queries_01_low_i32x4 = _mm_unpacklo_epi32(query_0_coarse_dots_i32x4,
259
+ query_1_coarse_dots_i32x4);
260
+ __m128i transpose_queries_23_low_i32x4 = _mm_unpacklo_epi32(query_2_coarse_dots_i32x4,
261
+ query_3_coarse_dots_i32x4);
262
+ __m128i transpose_queries_01_high_i32x4 = _mm_unpackhi_epi32(query_0_coarse_dots_i32x4,
263
+ query_1_coarse_dots_i32x4);
264
+ __m128i transpose_queries_23_high_i32x4 = _mm_unpackhi_epi32(query_2_coarse_dots_i32x4,
265
+ query_3_coarse_dots_i32x4);
266
+ __m128i document_0_dots_i32x4 = _mm_unpacklo_epi64(transpose_queries_01_low_i32x4,
267
+ transpose_queries_23_low_i32x4);
268
+ __m128i document_1_dots_i32x4 = _mm_unpackhi_epi64(transpose_queries_01_low_i32x4,
269
+ transpose_queries_23_low_i32x4);
270
+ __m128i document_2_dots_i32x4 = _mm_unpacklo_epi64(transpose_queries_01_high_i32x4,
271
+ transpose_queries_23_high_i32x4);
272
+ __m128i document_3_dots_i32x4 = _mm_unpackhi_epi64(transpose_queries_01_high_i32x4,
273
+ transpose_queries_23_high_i32x4);
274
+
275
+ // Branchless SIMD argmax
276
+ __m128i comparison_mask_i32x4, document_index_i32x4;
277
+
278
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_0_dots_i32x4, running_max_i32x4);
279
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 0));
280
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_0_dots_i32x4, comparison_mask_i32x4);
281
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
282
+
283
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_1_dots_i32x4, running_max_i32x4);
284
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 1));
285
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_1_dots_i32x4, comparison_mask_i32x4);
286
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
287
+
288
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_2_dots_i32x4, running_max_i32x4);
289
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 2));
290
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_2_dots_i32x4, comparison_mask_i32x4);
291
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
292
+
293
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_3_dots_i32x4, running_max_i32x4);
294
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 3));
295
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_3_dots_i32x4, comparison_mask_i32x4);
296
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
297
+ }
298
+
299
+ // Document tail: 4Q×1D
300
+ for (nk_size_t document_index = document_block_start_index; document_index < document_count; document_index++) {
301
+ nk_i8_t const *document_i8_row = document_i8 + document_index * depth_i8_padded;
302
+
303
+ __m256i accumulator_i32x8_0 = _mm256_setzero_si256();
304
+ __m256i accumulator_i32x8_1 = _mm256_setzero_si256();
305
+ __m256i accumulator_i32x8_2 = _mm256_setzero_si256();
306
+ __m256i accumulator_i32x8_3 = _mm256_setzero_si256();
307
+
308
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 32) {
309
+ __m256i document_i8x32 = _mm256_loadu_si256((__m256i const *)(document_i8_row + depth_index));
310
+
311
+ accumulator_i32x8_0 = _mm256_dpbusd_avx_epi32(
312
+ accumulator_i32x8_0,
313
+ _mm256_xor_si256(
314
+ _mm256_loadu_si256((
315
+ __m256i const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
316
+ xor_mask_u8x32),
317
+ document_i8x32);
318
+
319
+ accumulator_i32x8_1 = _mm256_dpbusd_avx_epi32(
320
+ accumulator_i32x8_1,
321
+ _mm256_xor_si256(
322
+ _mm256_loadu_si256((
323
+ __m256i const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
324
+ xor_mask_u8x32),
325
+ document_i8x32);
326
+
327
+ accumulator_i32x8_2 = _mm256_dpbusd_avx_epi32(
328
+ accumulator_i32x8_2,
329
+ _mm256_xor_si256(
330
+ _mm256_loadu_si256((
331
+ __m256i const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
332
+ xor_mask_u8x32),
333
+ document_i8x32);
334
+
335
+ accumulator_i32x8_3 = _mm256_dpbusd_avx_epi32(
336
+ accumulator_i32x8_3,
337
+ _mm256_xor_si256(
338
+ _mm256_loadu_si256((
339
+ __m256i const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
340
+ xor_mask_u8x32),
341
+ document_i8x32);
342
+ }
343
+
344
+ __m128i reduced_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(accumulator_i32x8_0, accumulator_i32x8_1,
345
+ accumulator_i32x8_2, accumulator_i32x8_3);
346
+ nk_i32_t bias_correction_i32 = 128 * document_metadata[document_index].sum_i8_i32;
347
+ __m128i coarse_dots_i32x4 = _mm_sub_epi32(reduced_i32x4, _mm_set1_epi32(bias_correction_i32));
348
+
349
+ __m128i comparison_mask_i32x4 = _mm_cmpgt_epi32(coarse_dots_i32x4, running_max_i32x4);
350
+ __m128i document_index_i32x4 = _mm_set1_epi32((int)document_index);
351
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, coarse_dots_i32x4, comparison_mask_i32x4);
352
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
353
+ }
354
+
355
+ best_document_indices[query_block_start_index + 0] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 0);
356
+ best_document_indices[query_block_start_index + 1] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 1);
357
+ best_document_indices[query_block_start_index + 2] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 2);
358
+ best_document_indices[query_block_start_index + 3] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 3);
359
+ }
360
+
361
+ // Query tail: 1Q×1D
362
+ for (nk_size_t query_index = query_block_start_index; query_index < query_count; query_index++) {
363
+ nk_i8_t const *query_i8_row = query_i8 + query_index * depth_i8_padded;
364
+ nk_i32_t running_max_i32 = NK_I32_MIN;
365
+ nk_u32_t running_argmax_u32 = 0;
366
+
367
+ for (nk_size_t document_index = 0; document_index < document_count; document_index++) {
368
+ nk_i8_t const *document_i8_row = document_i8 + document_index * depth_i8_padded;
369
+ __m256i accumulator_i32x8 = _mm256_setzero_si256();
370
+
371
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 32) {
372
+ __m256i document_i8x32 = _mm256_loadu_si256((__m256i const *)(document_i8_row + depth_index));
373
+ __m256i query_biased_u8x32 = _mm256_xor_si256(
374
+ _mm256_loadu_si256((__m256i const *)(query_i8_row + depth_index)), xor_mask_u8x32);
375
+ accumulator_i32x8 = _mm256_dpbusd_avx_epi32(accumulator_i32x8, query_biased_u8x32, document_i8x32);
376
+ }
377
+
378
+ // Horizontal sum of 8 i32 lanes
379
+ __m128i sum_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(accumulator_i32x8),
380
+ _mm256_extracti128_si256(accumulator_i32x8, 1));
381
+ sum_i32x4 = _mm_add_epi32(sum_i32x4, _mm_shuffle_epi32(sum_i32x4, 0x4E)); // 01001110
382
+ sum_i32x4 = _mm_add_epi32(sum_i32x4, _mm_shuffle_epi32(sum_i32x4, 0xB1)); // 10110001
383
+ nk_i32_t coarse_dot_i32 = _mm_extract_epi32(sum_i32x4, 0) -
384
+ 128 * document_metadata[document_index].sum_i8_i32;
385
+
386
+ if (coarse_dot_i32 > running_max_i32) {
387
+ running_max_i32 = coarse_dot_i32;
388
+ running_argmax_u32 = (nk_u32_t)document_index;
389
+ }
390
+ }
391
+
392
+ best_document_indices[query_index] = running_argmax_u32;
393
+ }
394
+ }
395
+
396
+ NK_PUBLIC void nk_maxsim_packed_bf16_alder( //
397
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
398
+ nk_size_t depth, nk_f32_t *result) {
399
+
400
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
401
+ nk_f64_t total_angular_distance = 0.0;
402
+
403
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
404
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
405
+ nk_u32_t best_document_indices[256];
406
+
407
+ nk_maxsim_coarse_argmax_alder_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
408
+ regions.document_quantized, regions.document_metadata, chunk_size,
409
+ document_count, regions.depth_i8_padded, best_document_indices);
410
+
411
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
412
+ nk_u32_t best_document_index = best_document_indices[query_index];
413
+ nk_f32_t dot_result;
414
+ nk_dot_bf16((nk_bf16_t const *)(regions.query_originals +
415
+ (chunk_start + query_index) * regions.query_original_stride),
416
+ (nk_bf16_t const *)(regions.document_originals +
417
+ best_document_index * regions.document_original_stride),
418
+ depth, &dot_result);
419
+ nk_f32_t cosine = dot_result * regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
420
+ regions.document_metadata[best_document_index].inverse_norm_f32;
421
+ nk_f32_t angular = 1.0f - cosine;
422
+ if (angular < 0.0f) angular = 0.0f;
423
+ total_angular_distance += (nk_f64_t)angular;
424
+ }
425
+ }
426
+
427
+ *result = (nk_f32_t)total_angular_distance;
428
+ }
429
+
430
+ NK_PUBLIC void nk_maxsim_packed_f32_alder( //
431
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
432
+ nk_size_t depth, nk_f64_t *result) {
433
+
434
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
435
+ nk_f64_t total_angular_distance = 0.0;
436
+
437
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
438
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
439
+ nk_u32_t best_document_indices[256];
440
+
441
+ nk_maxsim_coarse_argmax_alder_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
442
+ regions.document_quantized, regions.document_metadata, chunk_size,
443
+ document_count, regions.depth_i8_padded, best_document_indices);
444
+
445
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
446
+ nk_u32_t best_document_index = best_document_indices[query_index];
447
+ nk_f64_t dot_result;
448
+ nk_dot_f32(
449
+ (nk_f32_t const *)(regions.query_originals +
450
+ (chunk_start + query_index) * regions.query_original_stride),
451
+ (nk_f32_t const *)(regions.document_originals + best_document_index * regions.document_original_stride),
452
+ depth, &dot_result);
453
+ nk_f64_t cosine = dot_result *
454
+ (nk_f64_t)regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
455
+ (nk_f64_t)regions.document_metadata[best_document_index].inverse_norm_f32;
456
+ nk_f64_t angular = 1.0 - cosine;
457
+ if (angular < 0.0) angular = 0.0;
458
+ total_angular_distance += angular;
459
+ }
460
+ }
461
+
462
+ *result = total_angular_distance;
463
+ }
464
+
465
+ NK_PUBLIC void nk_maxsim_packed_f16_alder( //
466
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
467
+ nk_size_t depth, nk_f32_t *result) {
468
+
469
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
470
+ nk_f64_t total_angular_distance = 0.0;
471
+
472
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
473
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
474
+ nk_u32_t best_document_indices[256];
475
+
476
+ nk_maxsim_coarse_argmax_alder_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
477
+ regions.document_quantized, regions.document_metadata, chunk_size,
478
+ document_count, regions.depth_i8_padded, best_document_indices);
479
+
480
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
481
+ nk_u32_t best_document_index = best_document_indices[query_index];
482
+ nk_f32_t dot_result;
483
+ nk_dot_f16(
484
+ (nk_f16_t const *)(regions.query_originals +
485
+ (chunk_start + query_index) * regions.query_original_stride),
486
+ (nk_f16_t const *)(regions.document_originals + best_document_index * regions.document_original_stride),
487
+ depth, &dot_result);
488
+ nk_f32_t cosine = dot_result * regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
489
+ regions.document_metadata[best_document_index].inverse_norm_f32;
490
+ nk_f32_t angular = 1.0f - cosine;
491
+ if (angular < 0.0f) angular = 0.0f;
492
+ total_angular_distance += (nk_f64_t)angular;
493
+ }
494
+ }
495
+
496
+ *result = (nk_f32_t)total_angular_distance;
497
+ }
498
+
499
+ #if defined(__clang__)
500
+ #pragma clang attribute pop
501
+ #elif defined(__GNUC__)
502
+ #pragma GCC pop_options
503
+ #endif
504
+
505
+ #if defined(__cplusplus)
506
+ } // extern "C"
507
+ #endif
508
+
509
+ #endif // NK_TARGET_ALDER
510
+ #endif // NK_TARGET_X86_
511
+ #endif // NK_MAXSIM_ALDER_H
@@ -0,0 +1,115 @@
1
+ /**
2
+ * @brief SIMD-accelerated MaxSim (ColBERT late-interaction) for Genoa — bf16 only.
3
+ * @file include/numkong/maxsim/genoa.h
4
+ * @author Ash Vardanian
5
+ * @date February 17, 2026
6
+ *
7
+ * @sa include/numkong/maxsim.h
8
+ *
9
+ * Uses AVX-512 VNNI (VPDPBUSD) for coarse i8 screening via icelake.h, and VDPBF16PS for bf16 refinement.
10
+ * f32/f16 MaxSim variants live in icelake.h — this file only provides bf16 pack and compute.
11
+ *
12
+ * Intrinsic Instruction Genoa (Zen4)
13
+ * _mm512_dpbf16_ps VDPBF16PS 6cy @ p01 (512-bit)
14
+ */
15
+ #ifndef NK_MAXSIM_GENOA_H
16
+ #define NK_MAXSIM_GENOA_H
17
+
18
+ #if NK_TARGET_X86_
19
+ #if NK_TARGET_GENOA
20
+
21
+ #include "numkong/types.h"
22
+ #include "numkong/maxsim/icelake.h" // `nk_maxsim_coarse_argmax_icelake_`
23
+ #include "numkong/dot.h" // `nk_dot_bf16`
24
+
25
+ #if defined(__cplusplus)
26
+ extern "C" {
27
+ #endif
28
+
29
+ #if defined(__clang__)
30
+ #pragma clang attribute push( \
31
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,avx512bf16,f16c,fma,bmi,bmi2"))), \
32
+ apply_to = function)
33
+ #elif defined(__GNUC__)
34
+ #pragma GCC push_options
35
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "avx512bf16", "f16c", "fma", \
36
+ "bmi", "bmi2")
37
+ #endif
38
+
39
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_genoa(nk_size_t vector_count, nk_size_t depth) {
40
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_bf16_t), 64);
41
+ }
42
+
43
+ NK_PUBLIC void nk_maxsim_pack_bf16_genoa( //
44
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
45
+
46
+ nk_size_t const element_bytes = sizeof(nk_bf16_t);
47
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
48
+
49
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
50
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
51
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
52
+ char *originals = (char *)packed + header->offset_original_data;
53
+ nk_size_t const original_stride = header->original_stride_bytes;
54
+
55
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
56
+ char const *source_row = (char const *)vectors + vector_index * stride;
57
+ nk_f32_t norm_sq;
58
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
59
+ (nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
60
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
61
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
62
+ char *destination_original = originals + vector_index * original_stride;
63
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
64
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
65
+ destination_original[byte_index] = 0;
66
+ }
67
+ }
68
+
69
+ NK_PUBLIC void nk_maxsim_packed_bf16_genoa( //
70
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
71
+ nk_size_t depth, nk_f32_t *result) {
72
+
73
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
74
+ nk_f64_t total_angular_distance = 0.0;
75
+
76
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
77
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
78
+ nk_u32_t best_document_indices[256];
79
+
80
+ nk_maxsim_coarse_argmax_icelake_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
81
+ regions.document_quantized, regions.document_metadata, chunk_size,
82
+ document_count, regions.depth_i8_padded, best_document_indices);
83
+
84
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
85
+ nk_u32_t best_document_index = best_document_indices[query_index];
86
+ nk_f32_t dot_result;
87
+ nk_dot_bf16((nk_bf16_t const *)(regions.query_originals +
88
+ (chunk_start + query_index) * regions.query_original_stride),
89
+ (nk_bf16_t const *)(regions.document_originals +
90
+ best_document_index * regions.document_original_stride),
91
+ depth, &dot_result);
92
+ nk_f32_t cosine = dot_result * regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
93
+ regions.document_metadata[best_document_index].inverse_norm_f32;
94
+ nk_f32_t angular = 1.0f - cosine;
95
+ if (angular < 0.0f) angular = 0.0f;
96
+ total_angular_distance += (nk_f64_t)angular;
97
+ }
98
+ }
99
+
100
+ *result = (nk_f32_t)total_angular_distance;
101
+ }
102
+
103
+ #if defined(__clang__)
104
+ #pragma clang attribute pop
105
+ #elif defined(__GNUC__)
106
+ #pragma GCC pop_options
107
+ #endif
108
+
109
+ #if defined(__cplusplus)
110
+ } // extern "C"
111
+ #endif
112
+
113
+ #endif // NK_TARGET_GENOA
114
+ #endif // NK_TARGET_X86_
115
+ #endif // NK_MAXSIM_GENOA_H