numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,553 @@
1
+ /**
2
+ * @brief SIMD-accelerated MaxSim (angular distance late-interaction) for Haswell (AVX2).
3
+ * @file include/numkong/maxsim/haswell.h
4
+ * @author Ash Vardanian
5
+ * @date February 28, 2026
6
+ *
7
+ * @sa include/numkong/maxsim.h
8
+ *
9
+ * Uses AVX2 VPMADDUBSW (u8×i8→i16) + VPMADDWD (i16→i32) for coarse i8 screening.
10
+ * Quantization range [-79, 79] ensures no i16 saturation: XOR-0x80 biasing maps queries to at most 79 + 128 = 207, so the worst pair sum = 2 × 207 × 79 = 32706 < 32767.
11
+ * Bias correction via XOR-0x80 converts signed queries to unsigned, then subtracts 128 × sum_quantized.
12
+ *
13
+ * 4x4 register tiling: 4 queries × 4 documents = 16 YMM accumulators per depth loop.
14
+ * Depth steps at 32 bytes (YMM width in bytes).
15
+ */
16
+ #ifndef NK_MAXSIM_HASWELL_H
17
+ #define NK_MAXSIM_HASWELL_H
18
+
19
+ #if NK_TARGET_X86_
20
+ #if NK_TARGET_HASWELL
21
+
22
+ #include "numkong/types.h"
23
+ #include "numkong/maxsim/serial.h" // `nk_maxsim_packed_header_t`
24
+ #include "numkong/dot.h" // `nk_dot_bf16`, `nk_dot_f32`, `nk_dot_f16`
25
+ #include "numkong/cast/haswell.h" // `nk_f16_to_f32_haswell`
26
+ #include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`
27
+
28
+ #if defined(__cplusplus)
29
+ extern "C" {
30
+ #endif
31
+
32
+ #if defined(__clang__)
33
+ #pragma clang attribute push(__attribute__((target("avx2,f16c,fma,bmi,bmi2"))), apply_to = function)
34
+ #elif defined(__GNUC__)
35
+ #pragma GCC push_options
36
+ #pragma GCC target("avx2", "f16c", "fma", "bmi", "bmi2")
37
+ #endif
38
+
39
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_haswell(nk_size_t vector_count, nk_size_t depth) {
40
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_bf16_t), 32);
41
+ }
42
+
43
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_haswell(nk_size_t vector_count, nk_size_t depth) {
44
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f32_t), 32);
45
+ }
46
+
47
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_haswell(nk_size_t vector_count, nk_size_t depth) {
48
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f16_t), 32);
49
+ }
50
+
51
+ NK_PUBLIC void nk_maxsim_pack_bf16_haswell( //
52
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
53
+
54
+ nk_size_t const element_bytes = sizeof(nk_bf16_t);
55
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
56
+
57
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
58
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
59
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
60
+ char *originals = (char *)packed + header->offset_original_data;
61
+ nk_size_t const original_stride = header->original_stride_bytes;
62
+
63
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
64
+ char const *source_row = (char const *)vectors + vector_index * stride;
65
+ nk_f32_t norm_sq;
66
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f,
67
+ (nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
68
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
69
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
70
+ char *destination_original = originals + vector_index * original_stride;
71
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
72
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
73
+ destination_original[byte_index] = 0;
74
+ }
75
+ }
76
+
77
+ NK_PUBLIC void nk_maxsim_pack_f32_haswell( //
78
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
79
+
80
+ nk_size_t const element_bytes = sizeof(nk_f32_t);
81
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
82
+
83
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
84
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
85
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
86
+ char *originals = (char *)packed + header->offset_original_data;
87
+ nk_size_t const original_stride = header->original_stride_bytes;
88
+
89
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
90
+ char const *source_row = (char const *)vectors + vector_index * stride;
91
+ nk_f32_t norm_sq;
92
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f, nk_f32_to_f32_,
93
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
94
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
95
+ char *destination_original = originals + vector_index * original_stride;
96
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
97
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
98
+ destination_original[byte_index] = 0;
99
+ }
100
+ }
101
+
102
+ NK_PUBLIC void nk_maxsim_pack_f16_haswell( //
103
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
104
+
105
+ nk_size_t const element_bytes = sizeof(nk_f16_t);
106
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
107
+
108
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
109
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
110
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
111
+ char *originals = (char *)packed + header->offset_original_data;
112
+ nk_size_t const original_stride = header->original_stride_bytes;
113
+
114
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
115
+ char const *source_row = (char const *)vectors + vector_index * stride;
116
+ nk_f32_t norm_sq;
117
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f,
118
+ (nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
119
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
120
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
121
+ char *destination_original = originals + vector_index * original_stride;
122
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
123
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
124
+ destination_original[byte_index] = 0;
125
+ }
126
+ }
127
+
128
+ /** @brief Reduces 4 YMM i32x8 accumulators to a single __m128i with 4 horizontal sums. */
129
+ NK_INTERNAL __m128i nk_maxsim_reduce_i32x8x4_haswell_( //
130
+ __m256i accumulator_a_i32x8, __m256i accumulator_b_i32x8, //
131
+ __m256i accumulator_c_i32x8, __m256i accumulator_d_i32x8) {
132
+ // Step 1: 8 -> 4 (extract high 128-bit half and add to low half)
133
+ __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(accumulator_a_i32x8),
134
+ _mm256_extracti128_si256(accumulator_a_i32x8, 1));
135
+ __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(accumulator_b_i32x8),
136
+ _mm256_extracti128_si256(accumulator_b_i32x8, 1));
137
+ __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(accumulator_c_i32x8),
138
+ _mm256_extracti128_si256(accumulator_c_i32x8, 1));
139
+ __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(accumulator_d_i32x8),
140
+ _mm256_extracti128_si256(accumulator_d_i32x8, 1));
141
+ // Step 2: 4x4 transpose + reduce -> [sum_a, sum_b, sum_c, sum_d]
142
+ __m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
143
+ __m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
144
+ __m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
145
+ __m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
146
+ __m128i sum_lane_0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
147
+ __m128i sum_lane_1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
148
+ __m128i sum_lane_2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
149
+ __m128i sum_lane_3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
150
+ return _mm_add_epi32(_mm_add_epi32(sum_lane_0_i32x4, sum_lane_1_i32x4),
151
+ _mm_add_epi32(sum_lane_2_i32x4, sum_lane_3_i32x4));
152
+ }
153
+
154
+ /**
155
+ * @brief Factored coarse i8 argmax kernel for Haswell.
156
+ * Uses AVX2 VPMADDUBSW (u8×i8→i16) + VPMADDWD (i16×1→i32) with XOR-0x80 bias.
157
+ * 4Q×4D register tiling with 16 YMM accumulators.
158
+ */
159
+ NK_INTERNAL void nk_maxsim_coarse_argmax_haswell_( //
160
+ nk_i8_t const *query_i8, nk_i8_t const *document_i8, //
161
+ nk_maxsim_vector_metadata_t const *document_metadata, //
162
+ nk_size_t query_count, nk_size_t document_count, //
163
+ nk_size_t depth_i8_padded, nk_u32_t *best_document_indices) {
164
+
165
+ __m256i const xor_mask_u8x32 = _mm256_set1_epi8((char)0x80);
166
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
167
+
168
+ // Primary path: 4-query grouping
169
+ nk_size_t query_block_start_index = 0;
170
+ for (; query_block_start_index + 4 <= query_count; query_block_start_index += 4) {
171
+ __m128i running_max_i32x4 = _mm_set1_epi32(NK_I32_MIN);
172
+ __m128i running_argmax_i32x4 = _mm_setzero_si128();
173
+
174
+ // 4Q×4D document blocking
175
+ nk_size_t document_block_start_index = 0;
176
+ for (; document_block_start_index + 4 <= document_count; document_block_start_index += 4) {
177
+ __m256i accumulator_tiles_i32x8[4][4];
178
+ for (nk_size_t query_tile_index = 0; query_tile_index < 4; query_tile_index++)
179
+ for (nk_size_t document_tile_index = 0; document_tile_index < 4; document_tile_index++)
180
+ accumulator_tiles_i32x8[query_tile_index][document_tile_index] = _mm256_setzero_si256();
181
+
182
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 32) {
183
+ __m256i query_biased_u8x32_0 = _mm256_xor_si256(
184
+ _mm256_loadu_si256(
185
+ (__m256i const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
186
+ xor_mask_u8x32);
187
+ __m256i query_biased_u8x32_1 = _mm256_xor_si256(
188
+ _mm256_loadu_si256(
189
+ (__m256i const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
190
+ xor_mask_u8x32);
191
+ __m256i query_biased_u8x32_2 = _mm256_xor_si256(
192
+ _mm256_loadu_si256(
193
+ (__m256i const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
194
+ xor_mask_u8x32);
195
+ __m256i query_biased_u8x32_3 = _mm256_xor_si256(
196
+ _mm256_loadu_si256(
197
+ (__m256i const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
198
+ xor_mask_u8x32);
199
+
200
+ __m256i document_i8x32, products_i16x16, products_i32x8;
201
+
202
+ // Document 0
203
+ document_i8x32 = _mm256_loadu_si256(
204
+ (__m256i const *)(document_i8 + (document_block_start_index + 0) * depth_i8_padded + depth_index));
205
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_0, document_i8x32);
206
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
207
+ accumulator_tiles_i32x8[0][0] = _mm256_add_epi32(accumulator_tiles_i32x8[0][0], products_i32x8);
208
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_1, document_i8x32);
209
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
210
+ accumulator_tiles_i32x8[1][0] = _mm256_add_epi32(accumulator_tiles_i32x8[1][0], products_i32x8);
211
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_2, document_i8x32);
212
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
213
+ accumulator_tiles_i32x8[2][0] = _mm256_add_epi32(accumulator_tiles_i32x8[2][0], products_i32x8);
214
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_3, document_i8x32);
215
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
216
+ accumulator_tiles_i32x8[3][0] = _mm256_add_epi32(accumulator_tiles_i32x8[3][0], products_i32x8);
217
+
218
+ // Document 1
219
+ document_i8x32 = _mm256_loadu_si256(
220
+ (__m256i const *)(document_i8 + (document_block_start_index + 1) * depth_i8_padded + depth_index));
221
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_0, document_i8x32);
222
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
223
+ accumulator_tiles_i32x8[0][1] = _mm256_add_epi32(accumulator_tiles_i32x8[0][1], products_i32x8);
224
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_1, document_i8x32);
225
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
226
+ accumulator_tiles_i32x8[1][1] = _mm256_add_epi32(accumulator_tiles_i32x8[1][1], products_i32x8);
227
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_2, document_i8x32);
228
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
229
+ accumulator_tiles_i32x8[2][1] = _mm256_add_epi32(accumulator_tiles_i32x8[2][1], products_i32x8);
230
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_3, document_i8x32);
231
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
232
+ accumulator_tiles_i32x8[3][1] = _mm256_add_epi32(accumulator_tiles_i32x8[3][1], products_i32x8);
233
+
234
+ // Document 2
235
+ document_i8x32 = _mm256_loadu_si256(
236
+ (__m256i const *)(document_i8 + (document_block_start_index + 2) * depth_i8_padded + depth_index));
237
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_0, document_i8x32);
238
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
239
+ accumulator_tiles_i32x8[0][2] = _mm256_add_epi32(accumulator_tiles_i32x8[0][2], products_i32x8);
240
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_1, document_i8x32);
241
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
242
+ accumulator_tiles_i32x8[1][2] = _mm256_add_epi32(accumulator_tiles_i32x8[1][2], products_i32x8);
243
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_2, document_i8x32);
244
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
245
+ accumulator_tiles_i32x8[2][2] = _mm256_add_epi32(accumulator_tiles_i32x8[2][2], products_i32x8);
246
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_3, document_i8x32);
247
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
248
+ accumulator_tiles_i32x8[3][2] = _mm256_add_epi32(accumulator_tiles_i32x8[3][2], products_i32x8);
249
+
250
+ // Document 3
251
+ document_i8x32 = _mm256_loadu_si256(
252
+ (__m256i const *)(document_i8 + (document_block_start_index + 3) * depth_i8_padded + depth_index));
253
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_0, document_i8x32);
254
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
255
+ accumulator_tiles_i32x8[0][3] = _mm256_add_epi32(accumulator_tiles_i32x8[0][3], products_i32x8);
256
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_1, document_i8x32);
257
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
258
+ accumulator_tiles_i32x8[1][3] = _mm256_add_epi32(accumulator_tiles_i32x8[1][3], products_i32x8);
259
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_2, document_i8x32);
260
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
261
+ accumulator_tiles_i32x8[2][3] = _mm256_add_epi32(accumulator_tiles_i32x8[2][3], products_i32x8);
262
+ products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32_3, document_i8x32);
263
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
264
+ accumulator_tiles_i32x8[3][3] = _mm256_add_epi32(accumulator_tiles_i32x8[3][3], products_i32x8);
265
+ }
266
+
267
+ // Reduce each query's 4 doc accumulators -> __m128i
268
+ __m128i query_0_coarse_dots_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(
269
+ accumulator_tiles_i32x8[0][0], accumulator_tiles_i32x8[0][1], accumulator_tiles_i32x8[0][2],
270
+ accumulator_tiles_i32x8[0][3]);
271
+ __m128i query_1_coarse_dots_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(
272
+ accumulator_tiles_i32x8[1][0], accumulator_tiles_i32x8[1][1], accumulator_tiles_i32x8[1][2],
273
+ accumulator_tiles_i32x8[1][3]);
274
+ __m128i query_2_coarse_dots_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(
275
+ accumulator_tiles_i32x8[2][0], accumulator_tiles_i32x8[2][1], accumulator_tiles_i32x8[2][2],
276
+ accumulator_tiles_i32x8[2][3]);
277
+ __m128i query_3_coarse_dots_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(
278
+ accumulator_tiles_i32x8[3][0], accumulator_tiles_i32x8[3][1], accumulator_tiles_i32x8[3][2],
279
+ accumulator_tiles_i32x8[3][3]);
280
+
281
+ // Bias correction: subtract 128 × sum_quantized for each document
282
+ __m128i bias_correction_i32x4 = _mm_set_epi32(
283
+ 128 * document_metadata[document_block_start_index + 3].sum_i8_i32,
284
+ 128 * document_metadata[document_block_start_index + 2].sum_i8_i32,
285
+ 128 * document_metadata[document_block_start_index + 1].sum_i8_i32,
286
+ 128 * document_metadata[document_block_start_index + 0].sum_i8_i32);
287
+ query_0_coarse_dots_i32x4 = _mm_sub_epi32(query_0_coarse_dots_i32x4, bias_correction_i32x4);
288
+ query_1_coarse_dots_i32x4 = _mm_sub_epi32(query_1_coarse_dots_i32x4, bias_correction_i32x4);
289
+ query_2_coarse_dots_i32x4 = _mm_sub_epi32(query_2_coarse_dots_i32x4, bias_correction_i32x4);
290
+ query_3_coarse_dots_i32x4 = _mm_sub_epi32(query_3_coarse_dots_i32x4, bias_correction_i32x4);
291
+
292
+ // 4x4 transpose: [query][doc] -> [doc][query] for vectorized argmax
293
+ __m128i transpose_queries_01_low_i32x4 = _mm_unpacklo_epi32(query_0_coarse_dots_i32x4,
294
+ query_1_coarse_dots_i32x4);
295
+ __m128i transpose_queries_23_low_i32x4 = _mm_unpacklo_epi32(query_2_coarse_dots_i32x4,
296
+ query_3_coarse_dots_i32x4);
297
+ __m128i transpose_queries_01_high_i32x4 = _mm_unpackhi_epi32(query_0_coarse_dots_i32x4,
298
+ query_1_coarse_dots_i32x4);
299
+ __m128i transpose_queries_23_high_i32x4 = _mm_unpackhi_epi32(query_2_coarse_dots_i32x4,
300
+ query_3_coarse_dots_i32x4);
301
+ __m128i document_0_dots_i32x4 = _mm_unpacklo_epi64(transpose_queries_01_low_i32x4,
302
+ transpose_queries_23_low_i32x4);
303
+ __m128i document_1_dots_i32x4 = _mm_unpackhi_epi64(transpose_queries_01_low_i32x4,
304
+ transpose_queries_23_low_i32x4);
305
+ __m128i document_2_dots_i32x4 = _mm_unpacklo_epi64(transpose_queries_01_high_i32x4,
306
+ transpose_queries_23_high_i32x4);
307
+ __m128i document_3_dots_i32x4 = _mm_unpackhi_epi64(transpose_queries_01_high_i32x4,
308
+ transpose_queries_23_high_i32x4);
309
+
310
+ // Branchless SIMD argmax
311
+ __m128i comparison_mask_i32x4, document_index_i32x4;
312
+
313
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_0_dots_i32x4, running_max_i32x4);
314
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 0));
315
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_0_dots_i32x4, comparison_mask_i32x4);
316
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
317
+
318
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_1_dots_i32x4, running_max_i32x4);
319
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 1));
320
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_1_dots_i32x4, comparison_mask_i32x4);
321
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
322
+
323
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_2_dots_i32x4, running_max_i32x4);
324
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 2));
325
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_2_dots_i32x4, comparison_mask_i32x4);
326
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
327
+
328
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_3_dots_i32x4, running_max_i32x4);
329
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 3));
330
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_3_dots_i32x4, comparison_mask_i32x4);
331
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
332
+ }
333
+
334
+ // Document tail: 4Q×1D
335
+ for (nk_size_t document_index = document_block_start_index; document_index < document_count; document_index++) {
336
+ nk_i8_t const *document_i8_row = document_i8 + document_index * depth_i8_padded;
337
+
338
+ __m256i accumulator_i32x8_0 = _mm256_setzero_si256();
339
+ __m256i accumulator_i32x8_1 = _mm256_setzero_si256();
340
+ __m256i accumulator_i32x8_2 = _mm256_setzero_si256();
341
+ __m256i accumulator_i32x8_3 = _mm256_setzero_si256();
342
+
343
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 32) {
344
+ __m256i document_i8x32 = _mm256_loadu_si256((__m256i const *)(document_i8_row + depth_index));
345
+ __m256i products_i16x16, products_i32x8;
346
+
347
+ products_i16x16 = _mm256_maddubs_epi16(
348
+ _mm256_xor_si256(
349
+ _mm256_loadu_si256((
350
+ __m256i const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
351
+ xor_mask_u8x32),
352
+ document_i8x32);
353
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
354
+ accumulator_i32x8_0 = _mm256_add_epi32(accumulator_i32x8_0, products_i32x8);
355
+
356
+ products_i16x16 = _mm256_maddubs_epi16(
357
+ _mm256_xor_si256(
358
+ _mm256_loadu_si256((
359
+ __m256i const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
360
+ xor_mask_u8x32),
361
+ document_i8x32);
362
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
363
+ accumulator_i32x8_1 = _mm256_add_epi32(accumulator_i32x8_1, products_i32x8);
364
+
365
+ products_i16x16 = _mm256_maddubs_epi16(
366
+ _mm256_xor_si256(
367
+ _mm256_loadu_si256((
368
+ __m256i const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
369
+ xor_mask_u8x32),
370
+ document_i8x32);
371
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
372
+ accumulator_i32x8_2 = _mm256_add_epi32(accumulator_i32x8_2, products_i32x8);
373
+
374
+ products_i16x16 = _mm256_maddubs_epi16(
375
+ _mm256_xor_si256(
376
+ _mm256_loadu_si256((
377
+ __m256i const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
378
+ xor_mask_u8x32),
379
+ document_i8x32);
380
+ products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
381
+ accumulator_i32x8_3 = _mm256_add_epi32(accumulator_i32x8_3, products_i32x8);
382
+ }
383
+
384
+ __m128i reduced_i32x4 = nk_maxsim_reduce_i32x8x4_haswell_(accumulator_i32x8_0, accumulator_i32x8_1,
385
+ accumulator_i32x8_2, accumulator_i32x8_3);
386
+ nk_i32_t bias_correction_i32 = 128 * document_metadata[document_index].sum_i8_i32;
387
+ __m128i coarse_dots_i32x4 = _mm_sub_epi32(reduced_i32x4, _mm_set1_epi32(bias_correction_i32));
388
+
389
+ __m128i comparison_mask_i32x4 = _mm_cmpgt_epi32(coarse_dots_i32x4, running_max_i32x4);
390
+ __m128i document_index_i32x4 = _mm_set1_epi32((int)document_index);
391
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, coarse_dots_i32x4, comparison_mask_i32x4);
392
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
393
+ }
394
+
395
+ best_document_indices[query_block_start_index + 0] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 0);
396
+ best_document_indices[query_block_start_index + 1] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 1);
397
+ best_document_indices[query_block_start_index + 2] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 2);
398
+ best_document_indices[query_block_start_index + 3] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 3);
399
+ }
400
+
401
+ // Query tail: 1Q×1D
402
+ for (nk_size_t query_index = query_block_start_index; query_index < query_count; query_index++) {
403
+ nk_i8_t const *query_i8_row = query_i8 + query_index * depth_i8_padded;
404
+ nk_i32_t running_max_i32 = NK_I32_MIN;
405
+ nk_u32_t running_argmax_u32 = 0;
406
+
407
+ for (nk_size_t document_index = 0; document_index < document_count; document_index++) {
408
+ nk_i8_t const *document_i8_row = document_i8 + document_index * depth_i8_padded;
409
+ __m256i accumulator_i32x8 = _mm256_setzero_si256();
410
+
411
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 32) {
412
+ __m256i document_i8x32 = _mm256_loadu_si256((__m256i const *)(document_i8_row + depth_index));
413
+ __m256i query_biased_u8x32 = _mm256_xor_si256(
414
+ _mm256_loadu_si256((__m256i const *)(query_i8_row + depth_index)), xor_mask_u8x32);
415
+ __m256i products_i16x16 = _mm256_maddubs_epi16(query_biased_u8x32, document_i8x32);
416
+ __m256i products_i32x8 = _mm256_madd_epi16(products_i16x16, ones_i16x16);
417
+ accumulator_i32x8 = _mm256_add_epi32(accumulator_i32x8, products_i32x8);
418
+ }
419
+
420
+ // Horizontal sum of 8 i32 lanes
421
+ __m128i sum_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(accumulator_i32x8),
422
+ _mm256_extracti128_si256(accumulator_i32x8, 1));
423
+ sum_i32x4 = _mm_add_epi32(sum_i32x4, _mm_shuffle_epi32(sum_i32x4, 0x4E)); // 01001110
424
+ sum_i32x4 = _mm_add_epi32(sum_i32x4, _mm_shuffle_epi32(sum_i32x4, 0xB1)); // 10110001
425
+ nk_i32_t coarse_dot_i32 = _mm_extract_epi32(sum_i32x4, 0) -
426
+ 128 * document_metadata[document_index].sum_i8_i32;
427
+
428
+ if (coarse_dot_i32 > running_max_i32) {
429
+ running_max_i32 = coarse_dot_i32;
430
+ running_argmax_u32 = (nk_u32_t)document_index;
431
+ }
432
+ }
433
+
434
+ best_document_indices[query_index] = running_argmax_u32;
435
+ }
436
+ }
437
+
438
+ NK_PUBLIC void nk_maxsim_packed_bf16_haswell( //
439
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
440
+ nk_size_t depth, nk_f32_t *result) {
441
+
442
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
443
+ nk_f64_t total_angular_distance = 0.0;
444
+
445
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
446
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
447
+ nk_u32_t best_document_indices[256];
448
+
449
+ nk_maxsim_coarse_argmax_haswell_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
450
+ regions.document_quantized, regions.document_metadata, chunk_size,
451
+ document_count, regions.depth_i8_padded, best_document_indices);
452
+
453
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
454
+ nk_u32_t best_document_index = best_document_indices[query_index];
455
+ nk_f32_t dot_result;
456
+ nk_dot_bf16((nk_bf16_t const *)(regions.query_originals +
457
+ (chunk_start + query_index) * regions.query_original_stride),
458
+ (nk_bf16_t const *)(regions.document_originals +
459
+ best_document_index * regions.document_original_stride),
460
+ depth, &dot_result);
461
+ nk_f32_t cosine = dot_result * regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
462
+ regions.document_metadata[best_document_index].inverse_norm_f32;
463
+ nk_f32_t angular = 1.0f - cosine;
464
+ if (angular < 0.0f) angular = 0.0f;
465
+ total_angular_distance += (nk_f64_t)angular;
466
+ }
467
+ }
468
+
469
+ *result = (nk_f32_t)total_angular_distance;
470
+ }
471
+
472
+ NK_PUBLIC void nk_maxsim_packed_f32_haswell( //
473
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
474
+ nk_size_t depth, nk_f64_t *result) {
475
+
476
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
477
+ nk_f64_t total_angular_distance = 0.0;
478
+
479
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
480
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
481
+ nk_u32_t best_document_indices[256];
482
+
483
+ nk_maxsim_coarse_argmax_haswell_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
484
+ regions.document_quantized, regions.document_metadata, chunk_size,
485
+ document_count, regions.depth_i8_padded, best_document_indices);
486
+
487
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
488
+ nk_u32_t best_document_index = best_document_indices[query_index];
489
+ nk_f64_t dot_result;
490
+ nk_dot_f32(
491
+ (nk_f32_t const *)(regions.query_originals +
492
+ (chunk_start + query_index) * regions.query_original_stride),
493
+ (nk_f32_t const *)(regions.document_originals + best_document_index * regions.document_original_stride),
494
+ depth, &dot_result);
495
+ nk_f64_t cosine = dot_result *
496
+ (nk_f64_t)regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
497
+ (nk_f64_t)regions.document_metadata[best_document_index].inverse_norm_f32;
498
+ nk_f64_t angular = 1.0 - cosine;
499
+ if (angular < 0.0) angular = 0.0;
500
+ total_angular_distance += angular;
501
+ }
502
+ }
503
+
504
+ *result = total_angular_distance;
505
+ }
506
+
507
+ NK_PUBLIC void nk_maxsim_packed_f16_haswell( //
508
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
509
+ nk_size_t depth, nk_f32_t *result) {
510
+
511
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
512
+ nk_f64_t total_angular_distance = 0.0;
513
+
514
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
515
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
516
+ nk_u32_t best_document_indices[256];
517
+
518
+ nk_maxsim_coarse_argmax_haswell_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
519
+ regions.document_quantized, regions.document_metadata, chunk_size,
520
+ document_count, regions.depth_i8_padded, best_document_indices);
521
+
522
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
523
+ nk_u32_t best_document_index = best_document_indices[query_index];
524
+ nk_f32_t dot_result;
525
+ nk_dot_f16(
526
+ (nk_f16_t const *)(regions.query_originals +
527
+ (chunk_start + query_index) * regions.query_original_stride),
528
+ (nk_f16_t const *)(regions.document_originals + best_document_index * regions.document_original_stride),
529
+ depth, &dot_result);
530
+ nk_f32_t cosine = dot_result * regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
531
+ regions.document_metadata[best_document_index].inverse_norm_f32;
532
+ nk_f32_t angular = 1.0f - cosine;
533
+ if (angular < 0.0f) angular = 0.0f;
534
+ total_angular_distance += (nk_f64_t)angular;
535
+ }
536
+ }
537
+
538
+ *result = (nk_f32_t)total_angular_distance;
539
+ }
540
+
541
+ #if defined(__clang__)
542
+ #pragma clang attribute pop
543
+ #elif defined(__GNUC__)
544
+ #pragma GCC pop_options
545
+ #endif
546
+
547
+ #if defined(__cplusplus)
548
+ } // extern "C"
549
+ #endif
550
+
551
+ #endif // NK_TARGET_HASWELL
552
+ #endif // NK_TARGET_X86_
553
+ #endif // NK_MAXSIM_HASWELL_H