numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,480 @@
1
+ /**
2
+ * @brief SIMD-accelerated MaxSim (ColBERT late-interaction) for Ice Lake.
3
+ * @file include/numkong/maxsim/icelake.h
4
+ * @author Ash Vardanian
5
+ * @date February 28, 2026
6
+ *
7
+ * @sa include/numkong/maxsim.h
8
+ *
9
+ * Uses AVX-512 VNNI (VPDPBUSD) for coarse i8 screening. The coarse argmax kernel and reduce helper
10
+ * are shared with genoa.h — genoa.h imports them from this file for its bf16 compute path.
11
+ *
12
+ * VPDPBUSD sums 4 adjacent (u8 x i8) products into each i32 lane, processing 64 i8 pairs
13
+ * per ZMM register operation. Bias correction via XOR with 0x80 converts signed queries
14
+ * to unsigned (u8 = i8 + 128); the resulting extra 128 * sum(document_i8) term is subtracted after the depth loop.
15
+ *
16
+ * 4x4 register tiling: 4 queries x 4 documents = 16 ZMM accumulators per depth loop.
17
+ * Each document load is amortized across 4 VPDPBUSDs, and each query load across 4 documents.
18
+ *
19
+ * Intrinsic Instruction Icelake Genoa (Zen4)
20
+ * _mm512_dpbusd_epi32 VPDPBUSD 5cy @ p0 4cy @ p01 (512-bit)
21
+ */
22
+ #ifndef NK_MAXSIM_ICELAKE_H
23
+ #define NK_MAXSIM_ICELAKE_H
24
+
25
+ #if NK_TARGET_X86_
26
+ #if NK_TARGET_ICELAKE
27
+
28
+ #include "numkong/types.h"
29
+ #include "numkong/maxsim/serial.h" // `nk_maxsim_packed_header_t`
30
+ #include "numkong/dot.h" // `nk_dot_f32`, `nk_dot_f16`
31
+ #include "numkong/cast/haswell.h" // `nk_f16_to_f32_haswell`
32
+ #include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`
33
+
34
+ #if defined(__cplusplus)
35
+ extern "C" {
36
+ #endif
37
+
38
+ #if defined(__clang__)
39
+ #pragma clang attribute push( \
40
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,f16c,fma,bmi,bmi2"))), \
41
+ apply_to = function)
42
+ #elif defined(__GNUC__)
43
+ #pragma GCC push_options
44
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "f16c", "fma", "bmi", "bmi2")
45
+ #endif
46
+
47
+ #pragma region Single Precision Floats
48
+
49
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_icelake(nk_size_t vector_count, nk_size_t depth) {
50
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f32_t), 64);
51
+ }
52
+
53
+ NK_PUBLIC void nk_maxsim_pack_f32_icelake( //
54
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
55
+
56
+ nk_size_t const element_bytes = sizeof(nk_f32_t);
57
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
58
+
59
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
60
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
61
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
62
+ char *originals = (char *)packed + header->offset_original_data;
63
+ nk_size_t const original_stride = header->original_stride_bytes;
64
+
65
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
66
+ char const *source_row = (char const *)vectors + vector_index * stride;
67
+ nk_f32_t norm_sq;
68
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
69
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
70
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
71
+ char *destination_original = originals + vector_index * original_stride;
72
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
73
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
74
+ destination_original[byte_index] = 0;
75
+ }
76
+ }
77
+
78
+ #pragma endregion
79
+
80
+ #pragma region Half Precision Floats
81
+
82
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_icelake(nk_size_t vector_count, nk_size_t depth) {
83
+ return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f16_t), 64);
84
+ }
85
+
86
+ NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
87
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
88
+
89
+ nk_size_t const element_bytes = sizeof(nk_f16_t);
90
+ nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
91
+
92
+ nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
93
+ nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
94
+ nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
95
+ char *originals = (char *)packed + header->offset_original_data;
96
+ nk_size_t const original_stride = header->original_stride_bytes;
97
+
98
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
99
+ char const *source_row = (char const *)vectors + vector_index * stride;
100
+ nk_f32_t norm_sq;
101
+ nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
102
+ (nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
103
+ &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
104
+ metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? (1.0f / nk_f32_sqrt_haswell(norm_sq)) : 0.0f;
105
+ char *destination_original = originals + vector_index * original_stride;
106
+ nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
107
+ for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
108
+ destination_original[byte_index] = 0;
109
+ }
110
+ }
111
+
112
+ #pragma endregion
113
+
114
+ #pragma region Coarse Argmax
115
+
116
+ /** @brief Reduces 4 ZMM i32x16 accumulators to a single __m128i with 4 horizontal sums. */
117
+ NK_INTERNAL __m128i nk_maxsim_reduce_i32x16x4_icelake_( //
118
+ __m512i accumulator_a_i32x16, __m512i accumulator_b_i32x16, //
119
+ __m512i accumulator_c_i32x16, __m512i accumulator_d_i32x16) {
120
+ // Step 1: 16 -> 8 (extract high 256-bit half and add to low half)
121
+ __m256i sum_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_a_i32x16),
122
+ _mm512_extracti32x8_epi32(accumulator_a_i32x16, 1));
123
+ __m256i sum_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_b_i32x16),
124
+ _mm512_extracti32x8_epi32(accumulator_b_i32x16, 1));
125
+ __m256i sum_c_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_c_i32x16),
126
+ _mm512_extracti32x8_epi32(accumulator_c_i32x16, 1));
127
+ __m256i sum_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_d_i32x16),
128
+ _mm512_extracti32x8_epi32(accumulator_d_i32x16, 1));
129
+ // Step 2: 8 -> 4 (extract high 128-bit half and add to low half)
130
+ __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_a_i32x8), _mm256_extracti128_si256(sum_a_i32x8, 1));
131
+ __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_b_i32x8), _mm256_extracti128_si256(sum_b_i32x8, 1));
132
+ __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_c_i32x8), _mm256_extracti128_si256(sum_c_i32x8, 1));
133
+ __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
134
+ // Step 3: 4x4 transpose + reduce -> [sum_a, sum_b, sum_c, sum_d]
135
+ __m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
136
+ __m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
137
+ __m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
138
+ __m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(sum_c_i32x4, sum_d_i32x4);
139
+ __m128i sum_lane_0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
140
+ __m128i sum_lane_1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
141
+ __m128i sum_lane_2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
142
+ __m128i sum_lane_3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
143
+ return _mm_add_epi32(_mm_add_epi32(sum_lane_0_i32x4, sum_lane_1_i32x4),
144
+ _mm_add_epi32(sum_lane_2_i32x4, sum_lane_3_i32x4));
145
+ }
146
+
147
+ /**
148
+ * @brief Factored coarse i8 argmax kernel for Ice Lake / Genoa.
149
+ * Uses AVX-512 VNNI VPDPBUSD with XOR-0x80 bias and 128*sum_quantized correction.
150
+ */
151
+ NK_INTERNAL void nk_maxsim_coarse_argmax_icelake_( //
152
+ nk_i8_t const *query_i8, nk_i8_t const *document_i8, //
153
+ nk_maxsim_vector_metadata_t const *document_metadata, //
154
+ nk_size_t query_count, nk_size_t document_count, //
155
+ nk_size_t depth_i8_padded, nk_u32_t *best_document_indices) {
156
+
157
+ __m512i const xor_mask_u8x64 = _mm512_set1_epi8((char)0x80);
158
+
159
+ // Primary path: 4-query grouping
160
+ nk_size_t query_block_start_index = 0;
161
+ for (; query_block_start_index + 4 <= query_count; query_block_start_index += 4) {
162
+ __m128i running_max_i32x4 = _mm_set1_epi32(NK_I32_MIN);
163
+ __m128i running_argmax_i32x4 = _mm_setzero_si128();
164
+
165
+ // 4x4 document blocking
166
+ nk_size_t document_block_start_index = 0;
167
+ for (; document_block_start_index + 4 <= document_count; document_block_start_index += 4) {
168
+ __m512i accumulator_tiles_i32x16[4][4];
169
+ for (nk_size_t query_tile_index = 0; query_tile_index < 4; query_tile_index++)
170
+ for (nk_size_t document_tile_index = 0; document_tile_index < 4; document_tile_index++)
171
+ accumulator_tiles_i32x16[query_tile_index][document_tile_index] = _mm512_setzero_si512();
172
+
173
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 64) {
174
+ __m512i query_biased_u8x64_0 = _mm512_xor_si512(
175
+ _mm512_loadu_si512(
176
+ (__m512i const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
177
+ xor_mask_u8x64);
178
+ __m512i query_biased_u8x64_1 = _mm512_xor_si512(
179
+ _mm512_loadu_si512(
180
+ (__m512i const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
181
+ xor_mask_u8x64);
182
+ __m512i query_biased_u8x64_2 = _mm512_xor_si512(
183
+ _mm512_loadu_si512(
184
+ (__m512i const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
185
+ xor_mask_u8x64);
186
+ __m512i query_biased_u8x64_3 = _mm512_xor_si512(
187
+ _mm512_loadu_si512(
188
+ (__m512i const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
189
+ xor_mask_u8x64);
190
+
191
+ __m512i document_i8x64;
192
+
193
+ document_i8x64 = _mm512_loadu_si512(
194
+ (__m512i const *)(document_i8 + (document_block_start_index + 0) * depth_i8_padded + depth_index));
195
+ accumulator_tiles_i32x16[0][0] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[0][0],
196
+ query_biased_u8x64_0, document_i8x64);
197
+ accumulator_tiles_i32x16[1][0] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[1][0],
198
+ query_biased_u8x64_1, document_i8x64);
199
+ accumulator_tiles_i32x16[2][0] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[2][0],
200
+ query_biased_u8x64_2, document_i8x64);
201
+ accumulator_tiles_i32x16[3][0] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[3][0],
202
+ query_biased_u8x64_3, document_i8x64);
203
+
204
+ document_i8x64 = _mm512_loadu_si512(
205
+ (__m512i const *)(document_i8 + (document_block_start_index + 1) * depth_i8_padded + depth_index));
206
+ accumulator_tiles_i32x16[0][1] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[0][1],
207
+ query_biased_u8x64_0, document_i8x64);
208
+ accumulator_tiles_i32x16[1][1] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[1][1],
209
+ query_biased_u8x64_1, document_i8x64);
210
+ accumulator_tiles_i32x16[2][1] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[2][1],
211
+ query_biased_u8x64_2, document_i8x64);
212
+ accumulator_tiles_i32x16[3][1] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[3][1],
213
+ query_biased_u8x64_3, document_i8x64);
214
+
215
+ document_i8x64 = _mm512_loadu_si512(
216
+ (__m512i const *)(document_i8 + (document_block_start_index + 2) * depth_i8_padded + depth_index));
217
+ accumulator_tiles_i32x16[0][2] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[0][2],
218
+ query_biased_u8x64_0, document_i8x64);
219
+ accumulator_tiles_i32x16[1][2] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[1][2],
220
+ query_biased_u8x64_1, document_i8x64);
221
+ accumulator_tiles_i32x16[2][2] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[2][2],
222
+ query_biased_u8x64_2, document_i8x64);
223
+ accumulator_tiles_i32x16[3][2] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[3][2],
224
+ query_biased_u8x64_3, document_i8x64);
225
+
226
+ document_i8x64 = _mm512_loadu_si512(
227
+ (__m512i const *)(document_i8 + (document_block_start_index + 3) * depth_i8_padded + depth_index));
228
+ accumulator_tiles_i32x16[0][3] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[0][3],
229
+ query_biased_u8x64_0, document_i8x64);
230
+ accumulator_tiles_i32x16[1][3] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[1][3],
231
+ query_biased_u8x64_1, document_i8x64);
232
+ accumulator_tiles_i32x16[2][3] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[2][3],
233
+ query_biased_u8x64_2, document_i8x64);
234
+ accumulator_tiles_i32x16[3][3] = _mm512_dpbusd_epi32(accumulator_tiles_i32x16[3][3],
235
+ query_biased_u8x64_3, document_i8x64);
236
+ }
237
+
238
+ __m128i query_0_coarse_dots_i32x4 = nk_maxsim_reduce_i32x16x4_icelake_(
239
+ accumulator_tiles_i32x16[0][0], accumulator_tiles_i32x16[0][1], accumulator_tiles_i32x16[0][2],
240
+ accumulator_tiles_i32x16[0][3]);
241
+ __m128i query_1_coarse_dots_i32x4 = nk_maxsim_reduce_i32x16x4_icelake_(
242
+ accumulator_tiles_i32x16[1][0], accumulator_tiles_i32x16[1][1], accumulator_tiles_i32x16[1][2],
243
+ accumulator_tiles_i32x16[1][3]);
244
+ __m128i query_2_coarse_dots_i32x4 = nk_maxsim_reduce_i32x16x4_icelake_(
245
+ accumulator_tiles_i32x16[2][0], accumulator_tiles_i32x16[2][1], accumulator_tiles_i32x16[2][2],
246
+ accumulator_tiles_i32x16[2][3]);
247
+ __m128i query_3_coarse_dots_i32x4 = nk_maxsim_reduce_i32x16x4_icelake_(
248
+ accumulator_tiles_i32x16[3][0], accumulator_tiles_i32x16[3][1], accumulator_tiles_i32x16[3][2],
249
+ accumulator_tiles_i32x16[3][3]);
250
+
251
+ __m128i bias_correction_i32x4 = _mm_set_epi32(
252
+ 128 * document_metadata[document_block_start_index + 3].sum_i8_i32,
253
+ 128 * document_metadata[document_block_start_index + 2].sum_i8_i32,
254
+ 128 * document_metadata[document_block_start_index + 1].sum_i8_i32,
255
+ 128 * document_metadata[document_block_start_index + 0].sum_i8_i32);
256
+ query_0_coarse_dots_i32x4 = _mm_sub_epi32(query_0_coarse_dots_i32x4, bias_correction_i32x4);
257
+ query_1_coarse_dots_i32x4 = _mm_sub_epi32(query_1_coarse_dots_i32x4, bias_correction_i32x4);
258
+ query_2_coarse_dots_i32x4 = _mm_sub_epi32(query_2_coarse_dots_i32x4, bias_correction_i32x4);
259
+ query_3_coarse_dots_i32x4 = _mm_sub_epi32(query_3_coarse_dots_i32x4, bias_correction_i32x4);
260
+
261
+ // 4x4 transpose: [query][doc] -> [doc][query] for vectorized argmax
262
+ __m128i transpose_queries_01_low_i32x4 = _mm_unpacklo_epi32(query_0_coarse_dots_i32x4,
263
+ query_1_coarse_dots_i32x4);
264
+ __m128i transpose_queries_23_low_i32x4 = _mm_unpacklo_epi32(query_2_coarse_dots_i32x4,
265
+ query_3_coarse_dots_i32x4);
266
+ __m128i transpose_queries_01_high_i32x4 = _mm_unpackhi_epi32(query_0_coarse_dots_i32x4,
267
+ query_1_coarse_dots_i32x4);
268
+ __m128i transpose_queries_23_high_i32x4 = _mm_unpackhi_epi32(query_2_coarse_dots_i32x4,
269
+ query_3_coarse_dots_i32x4);
270
+ __m128i document_0_dots_i32x4 = _mm_unpacklo_epi64(transpose_queries_01_low_i32x4,
271
+ transpose_queries_23_low_i32x4);
272
+ __m128i document_1_dots_i32x4 = _mm_unpackhi_epi64(transpose_queries_01_low_i32x4,
273
+ transpose_queries_23_low_i32x4);
274
+ __m128i document_2_dots_i32x4 = _mm_unpacklo_epi64(transpose_queries_01_high_i32x4,
275
+ transpose_queries_23_high_i32x4);
276
+ __m128i document_3_dots_i32x4 = _mm_unpackhi_epi64(transpose_queries_01_high_i32x4,
277
+ transpose_queries_23_high_i32x4);
278
+
279
+ __m128i comparison_mask_i32x4, document_index_i32x4;
280
+
281
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_0_dots_i32x4, running_max_i32x4);
282
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 0));
283
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_0_dots_i32x4, comparison_mask_i32x4);
284
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
285
+
286
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_1_dots_i32x4, running_max_i32x4);
287
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 1));
288
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_1_dots_i32x4, comparison_mask_i32x4);
289
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
290
+
291
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_2_dots_i32x4, running_max_i32x4);
292
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 2));
293
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_2_dots_i32x4, comparison_mask_i32x4);
294
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
295
+
296
+ comparison_mask_i32x4 = _mm_cmpgt_epi32(document_3_dots_i32x4, running_max_i32x4);
297
+ document_index_i32x4 = _mm_set1_epi32((int)(document_block_start_index + 3));
298
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, document_3_dots_i32x4, comparison_mask_i32x4);
299
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
300
+ }
301
+
302
+ // Document tail: 4Q×1D
303
+ for (nk_size_t document_index = document_block_start_index; document_index < document_count; document_index++) {
304
+ nk_i8_t const *document_i8_row = document_i8 + document_index * depth_i8_padded;
305
+
306
+ __m512i accumulator_i32x16_0 = _mm512_setzero_si512();
307
+ __m512i accumulator_i32x16_1 = _mm512_setzero_si512();
308
+ __m512i accumulator_i32x16_2 = _mm512_setzero_si512();
309
+ __m512i accumulator_i32x16_3 = _mm512_setzero_si512();
310
+
311
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 64) {
312
+ __m512i document_i8x64 = _mm512_loadu_si512((__m512i const *)(document_i8_row + depth_index));
313
+
314
+ accumulator_i32x16_0 = _mm512_dpbusd_epi32(
315
+ accumulator_i32x16_0,
316
+ _mm512_xor_si512(
317
+ _mm512_loadu_si512((
318
+ __m512i const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
319
+ xor_mask_u8x64),
320
+ document_i8x64);
321
+ accumulator_i32x16_1 = _mm512_dpbusd_epi32(
322
+ accumulator_i32x16_1,
323
+ _mm512_xor_si512(
324
+ _mm512_loadu_si512((
325
+ __m512i const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
326
+ xor_mask_u8x64),
327
+ document_i8x64);
328
+ accumulator_i32x16_2 = _mm512_dpbusd_epi32(
329
+ accumulator_i32x16_2,
330
+ _mm512_xor_si512(
331
+ _mm512_loadu_si512((
332
+ __m512i const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
333
+ xor_mask_u8x64),
334
+ document_i8x64);
335
+ accumulator_i32x16_3 = _mm512_dpbusd_epi32(
336
+ accumulator_i32x16_3,
337
+ _mm512_xor_si512(
338
+ _mm512_loadu_si512((
339
+ __m512i const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
340
+ xor_mask_u8x64),
341
+ document_i8x64);
342
+ }
343
+
344
+ nk_i32_t bias_correction_i32 = 128 * document_metadata[document_index].sum_i8_i32;
345
+ __m128i coarse_dots_i32x4 = _mm_set_epi32(
346
+ _mm512_reduce_add_epi32(accumulator_i32x16_3) - bias_correction_i32,
347
+ _mm512_reduce_add_epi32(accumulator_i32x16_2) - bias_correction_i32,
348
+ _mm512_reduce_add_epi32(accumulator_i32x16_1) - bias_correction_i32,
349
+ _mm512_reduce_add_epi32(accumulator_i32x16_0) - bias_correction_i32);
350
+
351
+ __m128i comparison_mask_i32x4 = _mm_cmpgt_epi32(coarse_dots_i32x4, running_max_i32x4);
352
+ __m128i document_index_i32x4 = _mm_set1_epi32((int)document_index);
353
+ running_max_i32x4 = _mm_blendv_epi8(running_max_i32x4, coarse_dots_i32x4, comparison_mask_i32x4);
354
+ running_argmax_i32x4 = _mm_blendv_epi8(running_argmax_i32x4, document_index_i32x4, comparison_mask_i32x4);
355
+ }
356
+
357
+ best_document_indices[query_block_start_index + 0] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 0);
358
+ best_document_indices[query_block_start_index + 1] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 1);
359
+ best_document_indices[query_block_start_index + 2] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 2);
360
+ best_document_indices[query_block_start_index + 3] = (nk_u32_t)_mm_extract_epi32(running_argmax_i32x4, 3);
361
+ }
362
+
363
+ // Query tail: 1Q×1D
364
+ for (nk_size_t query_index = query_block_start_index; query_index < query_count; query_index++) {
365
+ nk_i8_t const *query_i8_row = query_i8 + query_index * depth_i8_padded;
366
+ nk_i32_t running_max_i32 = NK_I32_MIN;
367
+ nk_u32_t running_argmax_u32 = 0;
368
+
369
+ for (nk_size_t document_index = 0; document_index < document_count; document_index++) {
370
+ nk_i8_t const *document_i8_row = document_i8 + document_index * depth_i8_padded;
371
+ __m512i accumulator_i32x16 = _mm512_setzero_si512();
372
+
373
+ for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 64) {
374
+ __m512i document_i8x64 = _mm512_loadu_si512((__m512i const *)(document_i8_row + depth_index));
375
+ __m512i query_biased_u8x64 = _mm512_xor_si512(
376
+ _mm512_loadu_si512((__m512i const *)(query_i8_row + depth_index)), xor_mask_u8x64);
377
+ accumulator_i32x16 = _mm512_dpbusd_epi32(accumulator_i32x16, query_biased_u8x64, document_i8x64);
378
+ }
379
+
380
+ nk_i32_t coarse_dot_i32 = _mm512_reduce_add_epi32(accumulator_i32x16) -
381
+ 128 * document_metadata[document_index].sum_i8_i32;
382
+
383
+ if (coarse_dot_i32 > running_max_i32) {
384
+ running_max_i32 = coarse_dot_i32;
385
+ running_argmax_u32 = (nk_u32_t)document_index;
386
+ }
387
+ }
388
+
389
+ best_document_indices[query_index] = running_argmax_u32;
390
+ }
391
+ }
392
+
393
+ #pragma endregion
394
+
395
+ #pragma region Compute Functions
396
+
397
+ NK_PUBLIC void nk_maxsim_packed_f32_icelake( //
398
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
399
+ nk_size_t depth, nk_f64_t *result) {
400
+
401
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
402
+ nk_f64_t total_angular_distance = 0.0;
403
+
404
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
405
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
406
+ nk_u32_t best_document_indices[256];
407
+
408
+ nk_maxsim_coarse_argmax_icelake_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
409
+ regions.document_quantized, regions.document_metadata, chunk_size,
410
+ document_count, regions.depth_i8_padded, best_document_indices);
411
+
412
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
413
+ nk_u32_t best_document_index = best_document_indices[query_index];
414
+ nk_f64_t dot_result;
415
+ nk_dot_f32(
416
+ (nk_f32_t const *)(regions.query_originals +
417
+ (chunk_start + query_index) * regions.query_original_stride),
418
+ (nk_f32_t const *)(regions.document_originals + best_document_index * regions.document_original_stride),
419
+ depth, &dot_result);
420
+ nk_f64_t cosine = dot_result *
421
+ (nk_f64_t)regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
422
+ (nk_f64_t)regions.document_metadata[best_document_index].inverse_norm_f32;
423
+ nk_f64_t angular = 1.0 - cosine;
424
+ if (angular < 0.0) angular = 0.0;
425
+ total_angular_distance += angular;
426
+ }
427
+ }
428
+
429
+ *result = total_angular_distance;
430
+ }
431
+
432
+ NK_PUBLIC void nk_maxsim_packed_f16_icelake( //
433
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
434
+ nk_size_t depth, nk_f32_t *result) {
435
+
436
+ nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
437
+ nk_f64_t total_angular_distance = 0.0;
438
+
439
+ for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += 256) {
440
+ nk_size_t chunk_size = query_count - chunk_start < 256 ? query_count - chunk_start : 256;
441
+ nk_u32_t best_document_indices[256];
442
+
443
+ nk_maxsim_coarse_argmax_icelake_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
444
+ regions.document_quantized, regions.document_metadata, chunk_size,
445
+ document_count, regions.depth_i8_padded, best_document_indices);
446
+
447
+ for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
448
+ nk_u32_t best_document_index = best_document_indices[query_index];
449
+ nk_f32_t dot_result;
450
+ nk_dot_f16(
451
+ (nk_f16_t const *)(regions.query_originals +
452
+ (chunk_start + query_index) * regions.query_original_stride),
453
+ (nk_f16_t const *)(regions.document_originals + best_document_index * regions.document_original_stride),
454
+ depth, &dot_result);
455
+ nk_f32_t cosine = dot_result * regions.query_metadata[chunk_start + query_index].inverse_norm_f32 *
456
+ regions.document_metadata[best_document_index].inverse_norm_f32;
457
+ nk_f32_t angular = 1.0f - cosine;
458
+ if (angular < 0.0f) angular = 0.0f;
459
+ total_angular_distance += (nk_f64_t)angular;
460
+ }
461
+ }
462
+
463
+ *result = (nk_f32_t)total_angular_distance;
464
+ }
465
+
466
+ #pragma endregion
467
+
468
+ #if defined(__clang__)
469
+ #pragma clang attribute pop
470
+ #elif defined(__GNUC__)
471
+ #pragma GCC pop_options
472
+ #endif
473
+
474
+ #if defined(__cplusplus)
475
+ } // extern "C"
476
+ #endif
477
+
478
+ #endif // NK_TARGET_ICELAKE
479
+ #endif // NK_TARGET_X86_
480
+ #endif // NK_MAXSIM_ICELAKE_H