numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294):
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,877 @@
1
+ /**
2
+ * @brief SIMD-accelerated MaxSim (ColBERT late-interaction) for Sapphire Rapids AMX.
3
+ * @file include/numkong/maxsim/sapphireamx.h
4
+ * @author Ash Vardanian
5
+ * @date March 7, 2026
6
+ *
7
+ * @sa include/numkong/maxsim.h
8
+ *
9
+ * bf16: fused AMX approach using TDPBF16PS for direct bf16 dot products,
10
+ * with per-tile column extraction for running argmax and angular distance finalization.
11
+ * Uses 4 accumulator tiles (TMM4-7) for 4-way document tile pipelining.
12
+ *
13
+ * f32/f16: coarse i8 screening via AMX TDPBSSD (signed i8 × signed i8 → i32)
14
+ * with 4-accumulator pipeline, then full-precision refinement with nk_dot_f32/nk_dot_f16.
15
+ *
16
+ * TMM register allocation (all 3 dtypes):
17
+ * - TMM0: query (A-side) — loaded once per depth step
18
+ * - TMM1: document (B-side) — reloaded 4× per depth step (one per doc tile)
19
+ * - TMM4: accumulator 0 (doc tile 0)
20
+ * - TMM5: accumulator 1 (doc tile 1)
21
+ * - TMM6: accumulator 2 (doc tile 2)
22
+ * - TMM7: accumulator 3 (doc tile 3)
23
+ * - TMM2, TMM3: unused
24
+ *
25
+ * BF16 packed layout:
26
+ * [Header 64B] [0-63B padding for 64B alignment]
27
+ * [A-side tiles: col_tiles × depth_tiles × 1KB]
28
+ * [B-side tiles: col_tiles × depth_tiles × 1KB]
29
+ * [inverse norms: n × f32]
30
+ *
31
+ * i8 packed layout (f32/f16):
32
+ * [Header 64B] [0-63B padding for 64B alignment]
33
+ * [i8 A-side tiles: col_tiles × depth_tiles × 1KB]
34
+ * [i8 B-side tiles: col_tiles × depth_tiles × 1KB]
35
+ * [originals 64B-aligned: n × original_stride]
36
+ * [inverse norms: n × f32]
37
+ *
38
+ * Intrinsic Instruction Notes
39
+ * _tile_dpbf16ps TDPBF16PS C += A × B (bf16 → f32), 16×16×32 MACs
40
+ * _tile_dpbssd TDPBSSD C += A × B (i8 × i8 → i32), 16×16×64 MACs
41
+ * _tile_loadd TILELOADD Load tile from memory
42
+ * _tile_stored TILESTORED Store tile to memory
43
+ * _tile_zero TILEZERO Zero a tile register
44
+ */
45
+ #ifndef NK_MAXSIM_SAPPHIREAMX_H
46
+ #define NK_MAXSIM_SAPPHIREAMX_H
47
+
48
+ #if NK_TARGET_X86_
49
+ #if NK_TARGET_SAPPHIREAMX
50
+
51
+ #include "numkong/types.h"
52
+ #include "numkong/dots/sapphireamx.h" // AMX tile types, configure, load, transpose
53
+ #include "numkong/dot.h" // `nk_dot_f32`, `nk_dot_f16`
54
+ #include "numkong/cast/haswell.h" // `nk_f16_to_f32_haswell`
55
+ #include "numkong/cast/serial.h" // `nk_bf16_to_f32_serial`
56
+ #include "numkong/scalar/haswell.h" // `nk_f32_rsqrt_haswell`
57
+
58
+ #if defined(__cplusplus)
59
+ extern "C" {
60
+ #endif
61
+
62
+ #if defined(__clang__)
63
+ #pragma clang attribute push( \
64
+ __attribute__((target( \
65
+ "avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,avx512vbmi,avx512bf16,avx512fp16,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8"))), \
66
+ apply_to = function)
67
+ #elif defined(__GNUC__)
68
+ #pragma GCC push_options
69
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "avx512vbmi", "avx512bf16", \
70
+ "avx512fp16", "f16c", "fma", "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8")
71
+ #endif
72
+
73
+ #pragma region i8 Header (for f32/f16 coarse+refine)
74
+
75
+ /**
76
+ * i8 packed buffer header for AMX coarse+refine MaxSim (64 bytes).
77
+ * Stores both A-side (row-major) and B-side (quad-interleaved) i8 tile formats,
78
+ * original f32/f16 vectors for full-precision refinement, and per-vector inverse norms.
79
+ */
80
+ typedef struct {
81
+ nk_u32_t column_tile_count; ///< ceil(n / 16) — number of vector-tile groups
82
+ nk_u32_t depth_tile_count; ///< ceil(depth / 64) — TDPBSSD processes 64 i8 per tile
83
+ nk_u32_t columns; ///< actual vector count
84
+ nk_u32_t depth; ///< actual depth (dimensions per vector)
85
+ nk_u32_t a_side_offset; ///< byte offset from buffer start to 64B-aligned A-side tiles
86
+ nk_u32_t b_side_offset; ///< byte offset from buffer start to i8 B-side tiles
87
+ nk_u32_t originals_offset; ///< byte offset from buffer start to original f32/f16 vectors
88
+ nk_u32_t original_stride_bytes; ///< 64B-aligned stride for originals
89
+ nk_u32_t norms_offset; ///< byte offset from buffer start to f32 inverse norms
90
+ nk_u32_t reserved[7]; ///< padding to 64 bytes
91
+ } nk_maxsim_sapphireamx_i8_header_t;
92
+
93
+ NK_STATIC_ASSERT(sizeof(nk_maxsim_sapphireamx_i8_header_t) == 64, nk_maxsim_sapphireamx_i8_header_must_be_64_bytes);
94
+
95
+ #pragma endregion
96
+
97
+ #pragma region Single Precision Floats
98
+
99
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
100
+ nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
101
+ nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
102
+ nk_size_t a_side_bytes = column_tile_count * depth_tile_count * 1024;
103
+ nk_size_t b_side_bytes = column_tile_count * depth_tile_count * 1024;
104
+ nk_size_t original_stride = nk_size_round_up_to_multiple_(depth * sizeof(nk_f32_t), 64);
105
+ nk_size_t originals_bytes = vector_count * original_stride;
106
+ nk_size_t norms_bytes = vector_count * sizeof(nk_f32_t);
107
+ return 64 + 63 + a_side_bytes + b_side_bytes + originals_bytes + norms_bytes;
108
+ }
109
+
110
+ NK_PUBLIC void nk_maxsim_pack_f32_sapphireamx( //
111
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
112
+
113
+ nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
114
+ nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
115
+ nk_size_t original_stride_bytes = nk_size_round_up_to_multiple_(depth * sizeof(nk_f32_t), 64);
116
+ nk_size_t a_side_total_bytes = column_tile_count * depth_tile_count * 1024;
117
+ nk_size_t b_side_total_bytes = column_tile_count * depth_tile_count * 1024;
118
+
119
+ // Set up header — compute 64B-aligned A-side offset
120
+ nk_maxsim_sapphireamx_i8_header_t *header = (nk_maxsim_sapphireamx_i8_header_t *)packed;
121
+ nk_u32_t a_side_offset = (nk_u32_t)(nk_size_round_up_to_multiple_((nk_size_t)((char *)packed + 64), 64) -
122
+ (nk_size_t)(char *)packed);
123
+ header->column_tile_count = (nk_u32_t)column_tile_count;
124
+ header->depth_tile_count = (nk_u32_t)depth_tile_count;
125
+ header->columns = (nk_u32_t)vector_count;
126
+ header->depth = (nk_u32_t)depth;
127
+ header->a_side_offset = a_side_offset;
128
+ header->b_side_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes);
129
+ header->originals_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes);
130
+ header->original_stride_bytes = (nk_u32_t)original_stride_bytes;
131
+ header->norms_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes +
132
+ vector_count * original_stride_bytes);
133
+ for (nk_size_t reserved_index = 0; reserved_index < 7; reserved_index++) header->reserved[reserved_index] = 0;
134
+
135
+ // Pointers to data regions (A-side is guaranteed 64B-aligned)
136
+ nk_i8_t *a_side_base = (nk_i8_t *)((char *)packed + a_side_offset);
137
+ char *b_side_base = (char *)packed + header->b_side_offset;
138
+ char *originals_base = (char *)packed + header->originals_offset;
139
+ nk_f32_t *inverse_norms = (nk_f32_t *)((char *)packed + header->norms_offset);
140
+
141
+ // Zero all A-side tiles (aligned stores — A-side offset is 64B-aligned)
142
+ {
143
+ __m512i zero_i32x16 = _mm512_setzero_si512();
144
+ for (nk_size_t byte_offset = 0; byte_offset < a_side_total_bytes; byte_offset += 64)
145
+ _mm512_store_si512((void *)(a_side_base + byte_offset), zero_i32x16);
146
+ }
147
+
148
+ // Quantize vectors and scatter into A-side tiles, copy originals, compute inverse norms
149
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
150
+ nk_f32_t const *source_vector = (nk_f32_t const *)((char const *)vectors + vector_index * stride);
151
+
152
+ // Pass 1: find absmax and norm_squared
153
+ nk_f32_t absmax_f32 = 0.0f;
154
+ nk_f32_t norm_squared_f32 = 0.0f;
155
+ for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
156
+ nk_f32_t element_f32 = source_vector[dimension_index];
157
+ nk_f32_t abs_element_f32 = nk_f32_abs_(element_f32);
158
+ if (abs_element_f32 > absmax_f32) absmax_f32 = abs_element_f32;
159
+ norm_squared_f32 += element_f32 * element_f32;
160
+ }
161
+
162
+ // Pass 2: quantize to i8 [-127,127] and scatter into A-side tile positions
163
+ nk_f32_t inverse_absmax_f32 = (absmax_f32 > 0.0f) ? (1.0f / absmax_f32) : 0.0f;
164
+ nk_size_t column_tile_index = vector_index / 16;
165
+ nk_size_t row_in_tile = vector_index % 16;
166
+
167
+ for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
168
+ nk_f32_t element_f32 = source_vector[dimension_index];
169
+ nk_f32_t scaled_f32 = element_f32 * inverse_absmax_f32 * 127.0f;
170
+ nk_i8_t quantized_i8 = (nk_i8_t)(scaled_f32 + (element_f32 > 0.0f ? 0.5f : -0.5f));
171
+
172
+ nk_size_t depth_tile_index = dimension_index / 64;
173
+ nk_size_t column_in_tile = dimension_index % 64;
174
+ nk_size_t tile_flat_index = column_tile_index * depth_tile_count + depth_tile_index;
175
+ a_side_base[tile_flat_index * 1024 + row_in_tile * 64 + column_in_tile] = quantized_i8;
176
+ }
177
+
178
+ // Store inverse norm
179
+ inverse_norms[vector_index] = (norm_squared_f32 > 0.0f) ? nk_f32_rsqrt_haswell(norm_squared_f32) : 0.0f;
180
+
181
+ // Copy original vector with 64B-aligned stride
182
+ char *destination_original = originals_base + vector_index * original_stride_bytes;
183
+ nk_copy_bytes_(destination_original, (char const *)source_vector, depth * sizeof(nk_f32_t));
184
+ for (nk_size_t byte_index = depth * sizeof(nk_f32_t); byte_index < original_stride_bytes; byte_index++)
185
+ destination_original[byte_index] = 0;
186
+ }
187
+
188
+ // Transpose each A-side tile to B-side (both are 64B-aligned via header padding)
189
+ for (nk_size_t tile_flat_index = 0; tile_flat_index < column_tile_count * depth_tile_count; tile_flat_index++) {
190
+ nk_dots_i8_a16x64_sapphireamx_t const *a_tile =
191
+ (nk_dots_i8_a16x64_sapphireamx_t const *)(a_side_base + tile_flat_index * 1024);
192
+ nk_dots_i8_b64x16_sapphireamx_t *b_tile = (nk_dots_i8_b64x16_sapphireamx_t *)(b_side_base +
193
+ tile_flat_index * 1024);
194
+ nk_dots_pack_i8_transposed_sapphireamx_(a_tile, b_tile);
195
+ }
196
+ }
197
+
198
+ NK_PUBLIC void nk_maxsim_packed_f32_sapphireamx( //
199
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
200
+ nk_size_t depth, nk_f64_t *result) {
201
+
202
+ nk_maxsim_sapphireamx_i8_header_t const *query_header = (nk_maxsim_sapphireamx_i8_header_t const *)query_packed;
203
+ nk_maxsim_sapphireamx_i8_header_t const *document_header =
204
+ (nk_maxsim_sapphireamx_i8_header_t const *)document_packed;
205
+
206
+ nk_size_t const depth_tile_count = query_header->depth_tile_count;
207
+ nk_size_t const query_tile_count = query_header->column_tile_count;
208
+ nk_size_t const document_tile_count = document_header->column_tile_count;
209
+
210
+ // Query loads from A-side (64B-aligned), documents from B-side
211
+ char const *query_a_side_base = (char const *)query_packed + query_header->a_side_offset;
212
+ char const *document_b_side_base = (char const *)document_packed + document_header->b_side_offset;
213
+
214
+ // Original vectors for refinement
215
+ char const *query_originals = (char const *)query_packed + query_header->originals_offset;
216
+ char const *document_originals = (char const *)document_packed + document_header->originals_offset;
217
+ nk_size_t const query_original_stride = query_header->original_stride_bytes;
218
+ nk_size_t const document_original_stride = document_header->original_stride_bytes;
219
+
220
+ nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
221
+ nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
222
+ document_header->norms_offset);
223
+
224
+ nk_amx_tile_configure_sapphireamx_();
225
+
226
+ // Gather indices for column extraction from 16×16 tile:
227
+ // tile_result[row][col] at i32 offset row*16 + col
228
+ __m512i const row_stride_indices_i32x16 = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192,
229
+ 208, 224, 240);
230
+
231
+ nk_f64_t total_angular_distance_f64 = 0.0;
232
+
233
+ for (nk_size_t query_tile_index = 0; query_tile_index < query_tile_count; query_tile_index++) {
234
+ nk_size_t query_row_start = query_tile_index * 16;
235
+ nk_size_t valid_queries = (query_row_start + 16 <= query_count) ? 16 : (query_count - query_row_start);
236
+
237
+ __m512i running_maximum_i32x16 = _mm512_set1_epi32(NK_I32_MIN);
238
+ __m512i running_argmax_i32x16 = _mm512_setzero_si512();
239
+
240
+ NK_ALIGN64 nk_i32_t tile_results_i32[4][16][16];
241
+ nk_size_t document_tile_index = 0;
242
+
243
+ // Fast path: 4 document tiles at a time
244
+ for (; document_tile_index + 4 <= document_tile_count; document_tile_index += 4) {
245
+ _tile_zero(4);
246
+ _tile_zero(5);
247
+ _tile_zero(6);
248
+ _tile_zero(7);
249
+
250
+ for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
251
+ nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
252
+
253
+ _tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
254
+
255
+ nk_size_t document_tile_flat_0 = (document_tile_index + 0) * depth_tile_count + depth_step_index;
256
+ nk_size_t document_tile_flat_1 = (document_tile_index + 1) * depth_tile_count + depth_step_index;
257
+ nk_size_t document_tile_flat_2 = (document_tile_index + 2) * depth_tile_count + depth_step_index;
258
+ nk_size_t document_tile_flat_3 = (document_tile_index + 3) * depth_tile_count + depth_step_index;
259
+
260
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_0 * 1024), 64);
261
+ _tile_dpbssd(4, 0, 1);
262
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_1 * 1024), 64);
263
+ _tile_dpbssd(5, 0, 1);
264
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_2 * 1024), 64);
265
+ _tile_dpbssd(6, 0, 1);
266
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_3 * 1024), 64);
267
+ _tile_dpbssd(7, 0, 1);
268
+ }
269
+
270
+ _tile_stored(4, tile_results_i32[0], 64);
271
+ _tile_stored(5, tile_results_i32[1], 64);
272
+ _tile_stored(6, tile_results_i32[2], 64);
273
+ _tile_stored(7, tile_results_i32[3], 64);
274
+
275
+ // Column extraction from 4 tiles
276
+ for (nk_size_t tile_offset = 0; tile_offset < 4; tile_offset++) {
277
+ nk_size_t document_column_start = (document_tile_index + tile_offset) * 16;
278
+ for (nk_size_t column_within_tile = 0; column_within_tile < 16; column_within_tile++) {
279
+ __m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
280
+ _mm512_set1_epi32((int)column_within_tile));
281
+ __m512i column_dots_i32x16 = _mm512_i32gather_epi32(gather_index_i32x16,
282
+ tile_results_i32[tile_offset], 4);
283
+ __mmask16 is_better_bx16 = _mm512_cmpgt_epi32_mask(column_dots_i32x16, running_maximum_i32x16);
284
+ running_maximum_i32x16 = _mm512_mask_mov_epi32(running_maximum_i32x16, is_better_bx16,
285
+ column_dots_i32x16);
286
+ running_argmax_i32x16 = _mm512_mask_mov_epi32(
287
+ running_argmax_i32x16, is_better_bx16,
288
+ _mm512_set1_epi32((int)(document_column_start + column_within_tile)));
289
+ }
290
+ }
291
+ }
292
+
293
+ // Remainder: 1 document tile at a time
294
+ for (; document_tile_index < document_tile_count; document_tile_index++) {
295
+ nk_size_t document_column_start = document_tile_index * 16;
296
+ nk_size_t valid_documents = (document_column_start + 16 <= document_count)
297
+ ? 16
298
+ : (document_count - document_column_start);
299
+
300
+ _tile_zero(4);
301
+
302
+ for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
303
+ nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
304
+ nk_size_t document_tile_flat_index = document_tile_index * depth_tile_count + depth_step_index;
305
+
306
+ _tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
307
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_index * 1024), 64);
308
+ _tile_dpbssd(4, 0, 1);
309
+ }
310
+
311
+ _tile_stored(4, tile_results_i32[0], 64);
312
+
313
+ for (nk_size_t column_within_tile = 0; column_within_tile < valid_documents; column_within_tile++) {
314
+ __m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
315
+ _mm512_set1_epi32((int)column_within_tile));
316
+ __m512i column_dots_i32x16 = _mm512_i32gather_epi32(gather_index_i32x16, tile_results_i32[0], 4);
317
+ __mmask16 is_better_bx16 = _mm512_cmpgt_epi32_mask(column_dots_i32x16, running_maximum_i32x16);
318
+ running_maximum_i32x16 = _mm512_mask_mov_epi32(running_maximum_i32x16, is_better_bx16,
319
+ column_dots_i32x16);
320
+ running_argmax_i32x16 = _mm512_mask_mov_epi32(
321
+ running_argmax_i32x16, is_better_bx16,
322
+ _mm512_set1_epi32((int)(document_column_start + column_within_tile)));
323
+ }
324
+ }
325
+
326
+ // Refinement: for each valid query, compute full-precision dot with best document
327
+ NK_ALIGN64 nk_i32_t best_document_indices_i32[16];
328
+ _mm512_store_si512(best_document_indices_i32, running_argmax_i32x16);
329
+
330
+ for (nk_size_t query_in_tile = 0; query_in_tile < valid_queries; query_in_tile++) {
331
+ nk_size_t query_index = query_row_start + query_in_tile;
332
+ nk_u32_t best_document_index = (nk_u32_t)best_document_indices_i32[query_in_tile];
333
+
334
+ nk_f64_t dot_result_f64;
335
+ nk_dot_f32((nk_f32_t const *)(query_originals + query_index * query_original_stride),
336
+ (nk_f32_t const *)(document_originals + best_document_index * document_original_stride), depth,
337
+ &dot_result_f64);
338
+
339
+ nk_f64_t cosine_f64 = dot_result_f64 * (nk_f64_t)query_inverse_norms[query_index] *
340
+ (nk_f64_t)document_inverse_norms[best_document_index];
341
+ nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
342
+ if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
343
+ total_angular_distance_f64 += angular_distance_f64;
344
+ }
345
+ }
346
+
347
+ *result = total_angular_distance_f64;
348
+ }
349
+
350
+ #pragma endregion
351
+
352
+ #pragma region Half Precision Floats
353
+
354
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
355
+ nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
356
+ nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
357
+ nk_size_t a_side_bytes = column_tile_count * depth_tile_count * 1024;
358
+ nk_size_t b_side_bytes = column_tile_count * depth_tile_count * 1024;
359
+ nk_size_t original_stride = nk_size_round_up_to_multiple_(depth * sizeof(nk_f16_t), 64);
360
+ nk_size_t originals_bytes = vector_count * original_stride;
361
+ nk_size_t norms_bytes = vector_count * sizeof(nk_f32_t);
362
+ return 64 + 63 + a_side_bytes + b_side_bytes + originals_bytes + norms_bytes;
363
+ }
364
+
365
+ NK_PUBLIC void nk_maxsim_pack_f16_sapphireamx( //
366
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
367
+
368
+ nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
369
+ nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
370
+ nk_size_t original_stride_bytes = nk_size_round_up_to_multiple_(depth * sizeof(nk_f16_t), 64);
371
+ nk_size_t a_side_total_bytes = column_tile_count * depth_tile_count * 1024;
372
+ nk_size_t b_side_total_bytes = column_tile_count * depth_tile_count * 1024;
373
+
374
+ // Set up header — compute 64B-aligned A-side offset
375
+ nk_maxsim_sapphireamx_i8_header_t *header = (nk_maxsim_sapphireamx_i8_header_t *)packed;
376
+ nk_u32_t a_side_offset = (nk_u32_t)(nk_size_round_up_to_multiple_((nk_size_t)((char *)packed + 64), 64) -
377
+ (nk_size_t)(char *)packed);
378
+ header->column_tile_count = (nk_u32_t)column_tile_count;
379
+ header->depth_tile_count = (nk_u32_t)depth_tile_count;
380
+ header->columns = (nk_u32_t)vector_count;
381
+ header->depth = (nk_u32_t)depth;
382
+ header->a_side_offset = a_side_offset;
383
+ header->b_side_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes);
384
+ header->originals_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes);
385
+ header->original_stride_bytes = (nk_u32_t)original_stride_bytes;
386
+ header->norms_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes +
387
+ vector_count * original_stride_bytes);
388
+ for (nk_size_t reserved_index = 0; reserved_index < 7; reserved_index++) header->reserved[reserved_index] = 0;
389
+
390
+ // Pointers to data regions (A-side is guaranteed 64B-aligned)
391
+ nk_i8_t *a_side_base = (nk_i8_t *)((char *)packed + a_side_offset);
392
+ char *b_side_base = (char *)packed + header->b_side_offset;
393
+ char *originals_base = (char *)packed + header->originals_offset;
394
+ nk_f32_t *inverse_norms = (nk_f32_t *)((char *)packed + header->norms_offset);
395
+
396
+ // Zero all A-side tiles (aligned stores — A-side offset is 64B-aligned)
397
+ {
398
+ __m512i zero_i32x16 = _mm512_setzero_si512();
399
+ for (nk_size_t byte_offset = 0; byte_offset < a_side_total_bytes; byte_offset += 64)
400
+ _mm512_store_si512((void *)(a_side_base + byte_offset), zero_i32x16);
401
+ }
402
+
403
+ // Quantize vectors and scatter into A-side tiles, copy originals, compute inverse norms
404
+ nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
405
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
406
+ nk_f16_t const *source_vector = vectors + vector_index * stride_elements;
407
+
408
+ // Pass 1: find absmax and norm_squared (convert f16 → f32)
409
+ nk_f32_t absmax_f32 = 0.0f;
410
+ nk_f32_t norm_squared_f32 = 0.0f;
411
+ for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
412
+ nk_f32_t element_f32;
413
+ nk_f16_to_f32_haswell(&source_vector[dimension_index], &element_f32);
414
+ nk_f32_t abs_element_f32 = nk_f32_abs_(element_f32);
415
+ if (abs_element_f32 > absmax_f32) absmax_f32 = abs_element_f32;
416
+ norm_squared_f32 += element_f32 * element_f32;
417
+ }
418
+
419
+ // Pass 2: quantize to i8 [-127,127] and scatter into A-side tile positions
420
+ nk_f32_t inverse_absmax_f32 = (absmax_f32 > 0.0f) ? (1.0f / absmax_f32) : 0.0f;
421
+ nk_size_t column_tile_index = vector_index / 16;
422
+ nk_size_t row_in_tile = vector_index % 16;
423
+
424
+ for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
425
+ nk_f32_t element_f32;
426
+ nk_f16_to_f32_haswell(&source_vector[dimension_index], &element_f32);
427
+ nk_f32_t scaled_f32 = element_f32 * inverse_absmax_f32 * 127.0f;
428
+ nk_i8_t quantized_i8 = (nk_i8_t)(scaled_f32 + (element_f32 > 0.0f ? 0.5f : -0.5f));
429
+
430
+ nk_size_t depth_tile_index = dimension_index / 64;
431
+ nk_size_t column_in_tile = dimension_index % 64;
432
+ nk_size_t tile_flat_index = column_tile_index * depth_tile_count + depth_tile_index;
433
+ a_side_base[tile_flat_index * 1024 + row_in_tile * 64 + column_in_tile] = quantized_i8;
434
+ }
435
+
436
+ // Store inverse norm
437
+ inverse_norms[vector_index] = (norm_squared_f32 > 0.0f) ? nk_f32_rsqrt_haswell(norm_squared_f32) : 0.0f;
438
+
439
+ // Copy original f16 vector with 64B-aligned stride
440
+ char *destination_original = originals_base + vector_index * original_stride_bytes;
441
+ nk_copy_bytes_(destination_original, (char const *)source_vector, depth * sizeof(nk_f16_t));
442
+ for (nk_size_t byte_index = depth * sizeof(nk_f16_t); byte_index < original_stride_bytes; byte_index++)
443
+ destination_original[byte_index] = 0;
444
+ }
445
+
446
+ // Transpose each A-side tile to B-side (both are 64B-aligned via header padding)
447
+ for (nk_size_t tile_flat_index = 0; tile_flat_index < column_tile_count * depth_tile_count; tile_flat_index++) {
448
+ nk_dots_i8_a16x64_sapphireamx_t const *a_tile =
449
+ (nk_dots_i8_a16x64_sapphireamx_t const *)(a_side_base + tile_flat_index * 1024);
450
+ nk_dots_i8_b64x16_sapphireamx_t *b_tile = (nk_dots_i8_b64x16_sapphireamx_t *)(b_side_base +
451
+ tile_flat_index * 1024);
452
+ nk_dots_pack_i8_transposed_sapphireamx_(a_tile, b_tile);
453
+ }
454
+ }
455
+
456
+ NK_PUBLIC void nk_maxsim_packed_f16_sapphireamx( //
457
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
458
+ nk_size_t depth, nk_f32_t *result) {
459
+
460
+ nk_maxsim_sapphireamx_i8_header_t const *query_header = (nk_maxsim_sapphireamx_i8_header_t const *)query_packed;
461
+ nk_maxsim_sapphireamx_i8_header_t const *document_header =
462
+ (nk_maxsim_sapphireamx_i8_header_t const *)document_packed;
463
+
464
+ nk_size_t const depth_tile_count = query_header->depth_tile_count;
465
+ nk_size_t const query_tile_count = query_header->column_tile_count;
466
+ nk_size_t const document_tile_count = document_header->column_tile_count;
467
+
468
+ // Query loads from A-side (64B-aligned), documents from B-side
469
+ char const *query_a_side_base = (char const *)query_packed + query_header->a_side_offset;
470
+ char const *document_b_side_base = (char const *)document_packed + document_header->b_side_offset;
471
+
472
+ // Original vectors for refinement
473
+ char const *query_originals = (char const *)query_packed + query_header->originals_offset;
474
+ char const *document_originals = (char const *)document_packed + document_header->originals_offset;
475
+ nk_size_t const query_original_stride = query_header->original_stride_bytes;
476
+ nk_size_t const document_original_stride = document_header->original_stride_bytes;
477
+
478
+ nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
479
+ nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
480
+ document_header->norms_offset);
481
+
482
+ nk_amx_tile_configure_sapphireamx_();
483
+
484
+ __m512i const row_stride_indices_i32x16 = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192,
485
+ 208, 224, 240);
486
+
487
+ nk_f64_t total_angular_distance_f64 = 0.0;
488
+
489
+ for (nk_size_t query_tile_index = 0; query_tile_index < query_tile_count; query_tile_index++) {
490
+ nk_size_t query_row_start = query_tile_index * 16;
491
+ nk_size_t valid_queries = (query_row_start + 16 <= query_count) ? 16 : (query_count - query_row_start);
492
+
493
+ __m512i running_maximum_i32x16 = _mm512_set1_epi32(NK_I32_MIN);
494
+ __m512i running_argmax_i32x16 = _mm512_setzero_si512();
495
+
496
+ NK_ALIGN64 nk_i32_t tile_results_i32[4][16][16];
497
+ nk_size_t document_tile_index = 0;
498
+
499
+ // Fast path: 4 document tiles at a time
500
+ for (; document_tile_index + 4 <= document_tile_count; document_tile_index += 4) {
501
+ _tile_zero(4);
502
+ _tile_zero(5);
503
+ _tile_zero(6);
504
+ _tile_zero(7);
505
+
506
+ for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
507
+ nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
508
+
509
+ _tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
510
+
511
+ nk_size_t document_tile_flat_0 = (document_tile_index + 0) * depth_tile_count + depth_step_index;
512
+ nk_size_t document_tile_flat_1 = (document_tile_index + 1) * depth_tile_count + depth_step_index;
513
+ nk_size_t document_tile_flat_2 = (document_tile_index + 2) * depth_tile_count + depth_step_index;
514
+ nk_size_t document_tile_flat_3 = (document_tile_index + 3) * depth_tile_count + depth_step_index;
515
+
516
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_0 * 1024), 64);
517
+ _tile_dpbssd(4, 0, 1);
518
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_1 * 1024), 64);
519
+ _tile_dpbssd(5, 0, 1);
520
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_2 * 1024), 64);
521
+ _tile_dpbssd(6, 0, 1);
522
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_3 * 1024), 64);
523
+ _tile_dpbssd(7, 0, 1);
524
+ }
525
+
526
+ _tile_stored(4, tile_results_i32[0], 64);
527
+ _tile_stored(5, tile_results_i32[1], 64);
528
+ _tile_stored(6, tile_results_i32[2], 64);
529
+ _tile_stored(7, tile_results_i32[3], 64);
530
+
531
+ for (nk_size_t tile_offset = 0; tile_offset < 4; tile_offset++) {
532
+ nk_size_t document_column_start = (document_tile_index + tile_offset) * 16;
533
+ for (nk_size_t column_within_tile = 0; column_within_tile < 16; column_within_tile++) {
534
+ __m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
535
+ _mm512_set1_epi32((int)column_within_tile));
536
+ __m512i column_dots_i32x16 = _mm512_i32gather_epi32(gather_index_i32x16,
537
+ tile_results_i32[tile_offset], 4);
538
+ __mmask16 is_better_bx16 = _mm512_cmpgt_epi32_mask(column_dots_i32x16, running_maximum_i32x16);
539
+ running_maximum_i32x16 = _mm512_mask_mov_epi32(running_maximum_i32x16, is_better_bx16,
540
+ column_dots_i32x16);
541
+ running_argmax_i32x16 = _mm512_mask_mov_epi32(
542
+ running_argmax_i32x16, is_better_bx16,
543
+ _mm512_set1_epi32((int)(document_column_start + column_within_tile)));
544
+ }
545
+ }
546
+ }
547
+
548
+ // Remainder: 1 document tile at a time
549
+ for (; document_tile_index < document_tile_count; document_tile_index++) {
550
+ nk_size_t document_column_start = document_tile_index * 16;
551
+ nk_size_t valid_documents = (document_column_start + 16 <= document_count)
552
+ ? 16
553
+ : (document_count - document_column_start);
554
+
555
+ _tile_zero(4);
556
+
557
+ for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
558
+ nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
559
+ nk_size_t document_tile_flat_index = document_tile_index * depth_tile_count + depth_step_index;
560
+
561
+ _tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
562
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_index * 1024), 64);
563
+ _tile_dpbssd(4, 0, 1);
564
+ }
565
+
566
+ _tile_stored(4, tile_results_i32[0], 64);
567
+
568
+ for (nk_size_t column_within_tile = 0; column_within_tile < valid_documents; column_within_tile++) {
569
+ __m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
570
+ _mm512_set1_epi32((int)column_within_tile));
571
+ __m512i column_dots_i32x16 = _mm512_i32gather_epi32(gather_index_i32x16, tile_results_i32[0], 4);
572
+ __mmask16 is_better_bx16 = _mm512_cmpgt_epi32_mask(column_dots_i32x16, running_maximum_i32x16);
573
+ running_maximum_i32x16 = _mm512_mask_mov_epi32(running_maximum_i32x16, is_better_bx16,
574
+ column_dots_i32x16);
575
+ running_argmax_i32x16 = _mm512_mask_mov_epi32(
576
+ running_argmax_i32x16, is_better_bx16,
577
+ _mm512_set1_epi32((int)(document_column_start + column_within_tile)));
578
+ }
579
+ }
580
+
581
+ // Refinement: for each valid query, compute full-precision dot with best document
582
+ NK_ALIGN64 nk_i32_t best_document_indices_i32[16];
583
+ _mm512_store_si512(best_document_indices_i32, running_argmax_i32x16);
584
+
585
+ for (nk_size_t query_in_tile = 0; query_in_tile < valid_queries; query_in_tile++) {
586
+ nk_size_t query_index = query_row_start + query_in_tile;
587
+ nk_u32_t best_document_index = (nk_u32_t)best_document_indices_i32[query_in_tile];
588
+
589
+ nk_f32_t dot_result_f32;
590
+ nk_dot_f16((nk_f16_t const *)(query_originals + query_index * query_original_stride),
591
+ (nk_f16_t const *)(document_originals + best_document_index * document_original_stride), depth,
592
+ &dot_result_f32);
593
+
594
+ nk_f32_t cosine_f32 = dot_result_f32 * query_inverse_norms[query_index] *
595
+ document_inverse_norms[best_document_index];
596
+ nk_f32_t angular_distance_f32 = 1.0f - cosine_f32;
597
+ if (angular_distance_f32 < 0.0f) angular_distance_f32 = 0.0f;
598
+ total_angular_distance_f64 += (nk_f64_t)angular_distance_f32;
599
+ }
600
+ }
601
+
602
+ *result = (nk_f32_t)total_angular_distance_f64;
603
+ }
604
+
605
+ #pragma endregion
606
+
607
+ #pragma region Brain Floats (Fused AMX)
608
+
609
+ /**
610
+ * BF16 packed buffer header for AMX fused MaxSim (64 bytes).
611
+ * Stores both A-side (row-major) and B-side (pair-interleaved) tile formats
612
+ * plus per-vector inverse norms for angular distance finalization.
613
+ */
614
+ typedef struct {
615
+ nk_u32_t column_tile_count; ///< ceil(n / 16) — number of row-tile groups
616
+ nk_u32_t depth_tile_count; ///< ceil(depth / 32) — BF16 TDPBF16PS depth granularity
617
+ nk_u32_t columns; ///< actual vector count
618
+ nk_u32_t depth; ///< actual depth (dimensions per vector)
619
+ nk_u32_t a_side_offset; ///< byte offset from buffer start to 64B-aligned A-side tiles
620
+ nk_u32_t b_side_offset; ///< byte offset from buffer start to B-side tiles
621
+ nk_u32_t norms_offset; ///< byte offset from buffer start to inverse norms (f32)
622
+ nk_u32_t reserved[9]; ///< padding to 64 bytes
623
+ } nk_maxsim_sapphireamx_bf16_header_t;
624
+
625
+ NK_STATIC_ASSERT(sizeof(nk_maxsim_sapphireamx_bf16_header_t) == 64, nk_maxsim_sapphireamx_bf16_header_must_be_64_bytes);
626
+
627
+ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
628
+ nk_size_t const tile_bytes = 1024; // 16 × 32 × 2B = 1KB per tile
629
+ nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
630
+ nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 32);
631
+ nk_size_t a_side_bytes = column_tile_count * depth_tile_count * tile_bytes;
632
+ nk_size_t b_side_bytes = column_tile_count * depth_tile_count * tile_bytes;
633
+ nk_size_t norms_bytes = vector_count * sizeof(nk_f32_t);
634
+ return sizeof(nk_maxsim_sapphireamx_bf16_header_t) + 63 + a_side_bytes + b_side_bytes + norms_bytes;
635
+ }
636
+
637
+ NK_PUBLIC void nk_maxsim_pack_bf16_sapphireamx( //
638
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
639
+
640
+ nk_size_t const tile_bytes = 1024;
641
+ nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
642
+ nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
643
+ nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 32);
644
+
645
+ // Set up header — compute 64B-aligned A-side offset
646
+ nk_maxsim_sapphireamx_bf16_header_t *header = (nk_maxsim_sapphireamx_bf16_header_t *)packed;
647
+ nk_u32_t a_side_offset = (nk_u32_t)(nk_size_round_up_to_multiple_(
648
+ (nk_size_t)((char *)packed + sizeof(nk_maxsim_sapphireamx_bf16_header_t)),
649
+ 64) -
650
+ (nk_size_t)(char *)packed);
651
+ header->column_tile_count = (nk_u32_t)column_tile_count;
652
+ header->depth_tile_count = (nk_u32_t)depth_tile_count;
653
+ header->columns = (nk_u32_t)vector_count;
654
+ header->depth = (nk_u32_t)depth;
655
+ header->a_side_offset = a_side_offset;
656
+
657
+ nk_size_t a_side_total_bytes = column_tile_count * depth_tile_count * tile_bytes;
658
+ nk_size_t b_side_total_bytes = column_tile_count * depth_tile_count * tile_bytes;
659
+ header->b_side_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes);
660
+ header->norms_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes);
661
+ for (nk_size_t reserved_index = 0; reserved_index < 9; reserved_index++) header->reserved[reserved_index] = 0;
662
+
663
+ // Pointers to data regions (A-side is guaranteed 64B-aligned)
664
+ char *a_side_base = (char *)packed + a_side_offset;
665
+ char *b_side_base = (char *)packed + header->b_side_offset;
666
+ nk_f32_t *inverse_norms = (nk_f32_t *)((char *)packed + header->norms_offset);
667
+
668
+ // Pack tiles: for each column tile × depth tile, store both A-side and B-side
669
+ for (nk_size_t column_tile_index = 0; column_tile_index < column_tile_count; column_tile_index++) {
670
+ nk_size_t row_start = column_tile_index * 16;
671
+ nk_size_t valid_rows = (row_start + 16 <= vector_count) ? 16 : (vector_count - row_start);
672
+
673
+ for (nk_size_t depth_tile_index = 0; depth_tile_index < depth_tile_count; depth_tile_index++) {
674
+ nk_size_t depth_start = depth_tile_index * 32;
675
+ nk_size_t valid_columns = (depth_start + 32 <= depth) ? 32 : (depth - depth_start);
676
+
677
+ nk_size_t tile_flat_index = column_tile_index * depth_tile_count + depth_tile_index;
678
+
679
+ // Load source vectors into A-side tile (row-major, zero-padded)
680
+ nk_dots_bf16_a16x32_sapphireamx_t a_tile;
681
+ nk_dots_bf16_load_a_sapphireamx_(&a_tile, vectors + row_start * stride_elements + depth_start,
682
+ stride_elements, valid_rows, valid_columns);
683
+
684
+ // Store A-side tile to packed buffer
685
+ nk_copy_bytes_(a_side_base + tile_flat_index * tile_bytes, &a_tile, tile_bytes);
686
+
687
+ // Transpose to B-side tile (pair-interleaved) and store
688
+ nk_dots_bf16_b32x16_sapphireamx_t b_tile;
689
+ nk_dots_pack_bf16_transposed_sapphireamx_(&a_tile, &b_tile);
690
+ nk_copy_bytes_(b_side_base + tile_flat_index * tile_bytes, &b_tile, tile_bytes);
691
+ }
692
+ }
693
+
694
+ // Compute inverse norms for each vector
695
+ for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
696
+ nk_bf16_t const *source_vector = vectors + vector_index * stride_elements;
697
+ nk_f32_t norm_squared_f32 = 0.0f;
698
+ for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
699
+ nk_f32_t element_f32;
700
+ nk_bf16_to_f32_serial(&source_vector[dimension_index], &element_f32);
701
+ norm_squared_f32 += element_f32 * element_f32;
702
+ }
703
+ inverse_norms[vector_index] = (norm_squared_f32 > 0.0f) ? nk_f32_rsqrt_haswell(norm_squared_f32) : 0.0f;
704
+ }
705
+ }
706
+
707
+ /**
708
+ * BF16 fused AMX compute: TDPBF16PS tile multiply + column extraction + angular finalization.
709
+ *
710
+ * For each group of 16 queries, processes all document tiles via AMX TDPBF16PS.
711
+ * Fast path uses 4 accumulators (TMM4-7) for 4-way document tile pipelining.
712
+ * Column extraction from the 16×16 f32 accumulator tiles uses AVX-512 gather
713
+ * to build per-document dot product vectors, then element-wise max tracks the
714
+ * running best document per query.
715
+ */
716
+ NK_PUBLIC void nk_maxsim_packed_bf16_sapphireamx( //
717
+ void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
718
+ nk_size_t depth, nk_f32_t *result) {
719
+
720
+ nk_unused_(depth); // tile counts from header encode depth
721
+
722
+ nk_maxsim_sapphireamx_bf16_header_t const *query_header = (nk_maxsim_sapphireamx_bf16_header_t const *)query_packed;
723
+ nk_maxsim_sapphireamx_bf16_header_t const *document_header =
724
+ (nk_maxsim_sapphireamx_bf16_header_t const *)document_packed;
725
+
726
+ nk_size_t const depth_tile_count = query_header->depth_tile_count;
727
+ nk_size_t const query_column_tile_count = query_header->column_tile_count;
728
+ nk_size_t const document_column_tile_count = document_header->column_tile_count;
729
+
730
+ // Query loads from A-side tiles (64B-aligned), documents from B-side tiles
731
+ char const *query_a_side_base = (char const *)query_packed + query_header->a_side_offset;
732
+ char const *document_b_side_base = (char const *)document_packed + document_header->b_side_offset;
733
+
734
+ nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
735
+ nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
736
+ document_header->norms_offset);
737
+
738
+ nk_amx_tile_configure_sapphireamx_();
739
+
740
+ // Gather indices for column extraction from 16×16 f32 tile:
741
+ // tile_result[row][col] is at f32 offset row*16 + col
742
+ __m512i const row_stride_indices_i32x16 = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192,
743
+ 208, 224, 240);
744
+
745
+ nk_f64_t total_angular_distance_f64 = 0.0;
746
+
747
+ for (nk_size_t query_tile_index = 0; query_tile_index < query_column_tile_count; query_tile_index++) {
748
+ nk_size_t query_row_start = query_tile_index * 16;
749
+ nk_size_t valid_queries = (query_row_start + 16 <= query_count) ? 16 : (query_count - query_row_start);
750
+ __mmask16 valid_query_mask_bx16 = (valid_queries >= 16) ? (__mmask16)0xFFFF
751
+ : (__mmask16)((1u << valid_queries) - 1);
752
+
753
+ __m512 running_maximum_f32x16 = _mm512_set1_ps(NK_F32_MIN);
754
+ __m512i running_argmax_i32x16 = _mm512_setzero_si512();
755
+
756
+ NK_ALIGN64 nk_f32_t tile_results_f32[4][16][16];
757
+ nk_size_t document_tile_index = 0;
758
+
759
+ // Fast path: 4 document tiles at a time using TMM4-7
760
+ for (; document_tile_index + 4 <= document_column_tile_count; document_tile_index += 4) {
761
+ _tile_zero(4);
762
+ _tile_zero(5);
763
+ _tile_zero(6);
764
+ _tile_zero(7);
765
+
766
+ for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
767
+ nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
768
+
769
+ _tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
770
+
771
+ nk_size_t document_tile_flat_0 = (document_tile_index + 0) * depth_tile_count + depth_step_index;
772
+ nk_size_t document_tile_flat_1 = (document_tile_index + 1) * depth_tile_count + depth_step_index;
773
+ nk_size_t document_tile_flat_2 = (document_tile_index + 2) * depth_tile_count + depth_step_index;
774
+ nk_size_t document_tile_flat_3 = (document_tile_index + 3) * depth_tile_count + depth_step_index;
775
+
776
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_0 * 1024), 64);
777
+ _tile_dpbf16ps(4, 0, 1);
778
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_1 * 1024), 64);
779
+ _tile_dpbf16ps(5, 0, 1);
780
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_2 * 1024), 64);
781
+ _tile_dpbf16ps(6, 0, 1);
782
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_3 * 1024), 64);
783
+ _tile_dpbf16ps(7, 0, 1);
784
+ }
785
+
786
+ _tile_stored(4, tile_results_f32[0], 64);
787
+ _tile_stored(5, tile_results_f32[1], 64);
788
+ _tile_stored(6, tile_results_f32[2], 64);
789
+ _tile_stored(7, tile_results_f32[3], 64);
790
+
791
+ // Column extraction from 4 tiles
792
+ for (nk_size_t tile_offset = 0; tile_offset < 4; tile_offset++) {
793
+ nk_size_t document_column_start = (document_tile_index + tile_offset) * 16;
794
+ for (nk_size_t column_within_tile = 0; column_within_tile < 16; column_within_tile++) {
795
+ __m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
796
+ _mm512_set1_epi32((int)column_within_tile));
797
+ __m512 column_dots_f32x16 = _mm512_i32gather_ps(gather_index_i32x16,
798
+ (float const *)tile_results_f32[tile_offset], 4);
799
+ __mmask16 is_better_bx16 = _mm512_cmp_ps_mask(column_dots_f32x16, running_maximum_f32x16,
800
+ _CMP_GT_OQ);
801
+ running_maximum_f32x16 = _mm512_mask_mov_ps(running_maximum_f32x16, is_better_bx16,
802
+ column_dots_f32x16);
803
+ running_argmax_i32x16 = _mm512_mask_mov_epi32(
804
+ running_argmax_i32x16, is_better_bx16,
805
+ _mm512_set1_epi32((int)(document_column_start + column_within_tile)));
806
+ }
807
+ }
808
+ }
809
+
810
+ // Remainder: 1 document tile at a time using TMM4 only
811
+ for (; document_tile_index < document_column_tile_count; document_tile_index++) {
812
+ nk_size_t document_column_start = document_tile_index * 16;
813
+ nk_size_t valid_documents = (document_column_start + 16 <= document_count)
814
+ ? 16
815
+ : (document_count - document_column_start);
816
+
817
+ _tile_zero(4);
818
+
819
+ for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
820
+ nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
821
+ nk_size_t document_tile_flat_index = document_tile_index * depth_tile_count + depth_step_index;
822
+
823
+ _tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
824
+ _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_index * 1024), 64);
825
+ _tile_dpbf16ps(4, 0, 1);
826
+ }
827
+
828
+ _tile_stored(4, tile_results_f32[0], 64);
829
+
830
+ for (nk_size_t column_within_tile = 0; column_within_tile < valid_documents; column_within_tile++) {
831
+ __m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
832
+ _mm512_set1_epi32((int)column_within_tile));
833
+ __m512 column_dots_f32x16 = _mm512_i32gather_ps(gather_index_i32x16, (float const *)tile_results_f32[0],
834
+ 4);
835
+ __mmask16 is_better_bx16 = _mm512_cmp_ps_mask(column_dots_f32x16, running_maximum_f32x16, _CMP_GT_OQ);
836
+ running_maximum_f32x16 = _mm512_mask_mov_ps(running_maximum_f32x16, is_better_bx16, column_dots_f32x16);
837
+ running_argmax_i32x16 = _mm512_mask_mov_epi32(
838
+ running_argmax_i32x16, is_better_bx16,
839
+ _mm512_set1_epi32((int)(document_column_start + column_within_tile)));
840
+ }
841
+ }
842
+
843
+ // Angular distance finalization using AVX-512
844
+ __m512 query_inverse_norms_f32x16 = _mm512_maskz_loadu_ps(valid_query_mask_bx16,
845
+ query_inverse_norms + query_row_start);
846
+ __m512 document_inverse_norms_f32x16 = _mm512_i32gather_ps(running_argmax_i32x16, document_inverse_norms, 4);
847
+
848
+ // cosine = dot × inv_norm_q × inv_norm_d
849
+ __m512 cosine_f32x16 = _mm512_mul_ps(_mm512_mul_ps(running_maximum_f32x16, query_inverse_norms_f32x16),
850
+ document_inverse_norms_f32x16);
851
+
852
+ // angular = max(1 - cosine, 0), masked to valid queries only
853
+ __m512 angular_distance_f32x16 = _mm512_max_ps(_mm512_sub_ps(_mm512_set1_ps(1.0f), cosine_f32x16),
854
+ _mm512_setzero_ps());
855
+ angular_distance_f32x16 = _mm512_maskz_mov_ps(valid_query_mask_bx16, angular_distance_f32x16);
856
+
857
+ total_angular_distance_f64 += (nk_f64_t)_mm512_reduce_add_ps(angular_distance_f32x16);
858
+ }
859
+
860
+ *result = (nk_f32_t)total_angular_distance_f64;
861
+ }
862
+
863
+ #pragma endregion
864
+
865
+ #if defined(__clang__)
866
+ #pragma clang attribute pop
867
+ #elif defined(__GNUC__)
868
+ #pragma GCC pop_options
869
+ #endif
870
+
871
+ #if defined(__cplusplus)
872
+ } // extern "C"
873
+ #endif
874
+
875
+ #endif // NK_TARGET_SAPPHIREAMX
876
+ #endif // NK_TARGET_X86_
877
+ #endif // NK_MAXSIM_SAPPHIREAMX_H