numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -78,15 +78,15 @@
78
78
  * but only execute once per point-pair. The polynomial trig approximations use FMA chains.
79
79
  * Note: ZMM sqrt is faster on Genoa (15c) than Ice Lake (19c) due to better 512-bit support.
80
80
  *
81
- * Intrinsic Instruction Ice Genoa
82
- * _mm256_sqrt_ps VSQRTPS (YMM, YMM) 12c @ p0 15c @ p01
83
- * _mm256_sqrt_pd VSQRTPD (YMM, YMM) 13c @ p0 21c @ p01
84
- * _mm512_sqrt_ps VSQRTPS (ZMM, ZMM) 19c @ p05 15c @ p01
85
- * _mm512_sqrt_pd VSQRTPD (ZMM, ZMM) 23c @ p05 21c @ p01
86
- * _mm256_div_ps VDIVPS (YMM, YMM, YMM) 11c @ p0 11c @ p01
87
- * _mm256_div_pd VDIVPD (YMM, YMM, YMM) 13c @ p0 13c @ p01
88
- * _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4c @ p01 4c @ p01
89
- * _mm256_fmadd_pd VFMADD231PD (YMM, YMM, YMM) 4c @ p01 4c @ p01
81
+ * Intrinsic Instruction Icelake Genoa
82
+ * _mm256_sqrt_ps VSQRTPS (YMM, YMM) 12cy @ p0 15cy @ p01
83
+ * _mm256_sqrt_pd VSQRTPD (YMM, YMM) 13cy @ p0 21cy @ p01
84
+ * _mm512_sqrt_ps VSQRTPS (ZMM, ZMM) 19cy @ p0+p0+p05 15cy @ p01
85
+ * _mm512_sqrt_pd VSQRTPD (ZMM, ZMM) 23cy @ p0+p0+p05 21cy @ p01
86
+ * _mm256_div_ps VDIVPS (YMM, YMM, YMM) 11cy @ p0 11cy @ p01
87
+ * _mm256_div_pd VDIVPD (YMM, YMM, YMM) 13cy @ p0 13cy @ p01
88
+ * _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4cy @ p01 4cy @ p01
89
+ * _mm256_fmadd_pd VFMADD231PD (YMM, YMM, YMM) 4cy @ p01 4cy @ p01
90
90
  *
91
91
  * @section arm_instructions Relevant ARM NEON/SVE Instructions
92
92
  *
@@ -94,21 +94,21 @@
94
94
  * acceptable since sqrt only appears once per distance calculation. FMA chains for trig
95
95
  * polynomial evaluation pipeline well across all 4 V-units.
96
96
  *
97
- * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
98
- * vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
99
- * vfmaq_f64 FMLA.D (vec) 4c @ V0123 4c @ V0123 4c @ V0123
100
- * vsqrtq_f32 FSQRT.S (vec) 10c @ V02 10c @ V02 9c @ V02
101
- * vsqrtq_f64 FSQRT.D (vec) 13c @ V02 16c @ V02 16c @ V02
97
+ * Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
98
+ * vfmaq_f32 FMLA.S (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
99
+ * vfmaq_f64 FMLA.D (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
100
+ * vsqrtq_f32 FSQRT.S (vec) 10cy @ V02 10cy @ V02 9cy @ V02
101
+ * vsqrtq_f64 FSQRT.D (vec) 13cy @ V02 16cy @ V02 16cy @ V02
102
102
  *
103
103
  * @section references References
104
104
  *
105
- * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
105
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
106
106
  * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
107
107
  * - Earth Ellipsoid: https://en.wikipedia.org/wiki/Earth_ellipsoid
108
108
  * - Oblate Spheroid Geodesic: https://mathworld.wolfram.com/OblateSpheroidGeodesic.html
109
- * - Staging experiments: https://github.com/ashvardanian/HaversineMathKong
110
109
  * - Speeding up atan2f by 50x: https://mazzo.li/posts/vectorized-atan2.html
111
- * - Simplifying the GNU C Sine Function: https://www.awelm.com/posts/simplifying-the-gnu-c-sine-function/
110
+ * - Simplifying the GNU C Sine Function:
111
+ * https://web.archive.org/web/20230605051610/https://www.awelm.com/posts/simplifying-the-gnu-c-sine-function/
112
112
  *
113
113
  */
114
114
  #ifndef NK_GEOSPATIAL_H
@@ -26,7 +26,7 @@
26
26
 
27
27
  namespace ashvardanian::numkong {
28
28
 
29
- #pragma region - Packing Utilities
29
+ #pragma region Packing Utilities
30
30
 
31
31
  /**
32
32
  * @brief Estimates the memory requirements for packed B matrix.
@@ -155,9 +155,9 @@ NK_PUBLIC void maxsim_pack(typename in_type_::raw_t const *vectors, std::size_t
155
155
  }
156
156
  }
157
157
 
158
- #pragma endregion - Packing Utilities
158
+ #pragma endregion Packing Utilities
159
159
 
160
- #pragma region - Packed Containers
160
+ #pragma region Packed Containers
161
161
 
162
162
  /**
163
163
  * @brief Owning, move-only, pre-packed matrix for efficient GEMM.
@@ -329,7 +329,7 @@ class packed_maxsim {
329
329
  std::size_t size_bytes() const noexcept { return size_bytes_; }
330
330
  };
331
331
 
332
- #pragma endregion - Packed Containers
332
+ #pragma endregion Packed Containers
333
333
 
334
334
  } // namespace ashvardanian::numkong
335
335
 
@@ -4,21 +4,21 @@ NumKong implements ColBERT-style late-interaction scoring: the MaxSim score sums
4
4
 
5
5
  MaxSim score:
6
6
 
7
- ```math
7
+ $$
8
8
  \text{MaxSim}(Q, D) = \sum_{i=0}^{m-1} \min_{j=0}^{n-1} \text{angular}(q_i, d_j)
9
- ```
9
+ $$
10
10
 
11
11
  Coarse screening finds the best document via i8 dot products as a proxy for argmin angular:
12
12
 
13
- ```math
13
+ $$
14
14
  j^* = \arg\max_j \text{dot}_{\text{i8}}(q_i, d_j)
15
- ```
15
+ $$
16
16
 
17
17
  Full-precision refinement:
18
18
 
19
- ```math
19
+ $$
20
20
  \text{angular}(q_i, d_{j^*}) = 1 - \frac{\text{dot}(q_i, d_{j^*})}{\|q_i\| \cdot \|d_{j^*}\|}
21
- ```
21
+ $$
22
22
 
23
23
  Reformulating as Python pseudocode:
24
24
 
@@ -46,7 +46,7 @@ def maxsim(queries: np.ndarray, documents: np.ndarray) -> float:
46
46
 
47
47
  ## Optimizations
48
48
 
49
- ### Dual Pre-Packing Advantage
49
+ ### Dual Pre-Packing
50
50
 
51
51
  `nk_maxsim_packed_bf16_sme`, `nk_maxsim_packed_f32_sme` benefit from having _both_ query and document matrices pre-packed into identical contiguous formats, unlike the `nk_dots_packed_*` family where only B is pre-packed and A is accessed with arbitrary stride.
52
52
  In the dots GEMM, one ZA tile must be reserved for A-side staging (loading unpacked A rows into the tile array), leaving 3 ZA tiles for accumulation.
@@ -172,16 +172,16 @@ Measured with Wasmtime v42 (Cranelift backend).
172
172
 
173
173
  #### WASM
174
174
 
175
- Measured with Wasmtime v42 (Cranelift backend).
175
+ Measured with Wasmtime v43 (Cranelift backend).
176
176
 
177
177
  | Kernel | 256³ | 1024³ | 4096³ |
178
178
  | :---------------------------------- | -----------------------: | -----------------------: | -----------------------: |
179
179
  | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
180
- | `nk_maxsim_packed_f32_serial` | 9.22 gso/s, 46.8K ulp | 10.1 gso/s, 46.8K ulp | 10.5 gso/s, 46.8K ulp |
181
- | `nk_maxsim_packed_f32_v128relaxed` | 28.9 gso/s, 46.0K ulp | 31.2 gso/s, 46.0K ulp | 32.0 gso/s, 46.0K ulp |
180
+ | `nk_maxsim_packed_f32_serial` | 33.7 gso/s, 46.8K ulp | 35.0 gso/s, 46.8K ulp | 35.8 gso/s, 46.8K ulp |
181
+ | `nk_maxsim_packed_f32_v128relaxed` | 88.5 gso/s, 46.0K ulp | 98.1 gso/s, 46.0K ulp | 82.7 gso/s, 46.0K ulp |
182
182
  | __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
183
- | `nk_maxsim_packed_bf16_serial` | 8.95 gso/s, 49.2K ulp | 10.1 gso/s, 49.2K ulp | 10.0 gso/s, 49.2K ulp |
184
- | `nk_maxsim_packed_bf16_v128relaxed` | 29.6 gso/s, 49.4K ulp | 31.9 gso/s, 49.4K ulp | 31.6 gso/s, 49.4K ulp |
183
+ | `nk_maxsim_packed_bf16_serial` | 34.4 gso/s, 49.2K ulp | 35.1 gso/s, 49.2K ulp | 35.7 gso/s, 49.2K ulp |
184
+ | `nk_maxsim_packed_bf16_v128relaxed` | 92.3 gso/s, 49.4K ulp | 100 gso/s, 49.4K ulp | 83.2 gso/s, 49.4K ulp |
185
185
  | __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
186
- | `nk_maxsim_packed_f16_serial` | 9.21 gso/s, 49.5K ulp | 10.3 gso/s, 49.5K ulp | 10.6 gso/s, 49.5K ulp |
187
- | `nk_maxsim_packed_f16_v128relaxed` | 27.2 gso/s, 49.3K ulp | 33.7 gso/s, 49.3K ulp | 31.5 gso/s, 49.3K ulp |
186
+ | `nk_maxsim_packed_f16_serial` | 33.8 gso/s, 49.5K ulp | 35.0 gso/s, 49.5K ulp | 35.7 gso/s, 49.5K ulp |
187
+ | `nk_maxsim_packed_f16_v128relaxed` | 87.0 gso/s, 49.3K ulp | 95.8 gso/s, 49.3K ulp | 82.3 gso/s, 49.3K ulp |
@@ -57,7 +57,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_alder(nk_size_t vector_count, nk_s
57
57
  }
58
58
 
59
59
  NK_PUBLIC void nk_maxsim_pack_bf16_alder( //
60
- nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
60
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
61
61
 
62
62
  nk_size_t const element_bytes = sizeof(nk_bf16_t);
63
63
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
@@ -69,7 +69,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_alder( //
69
69
  nk_size_t const original_stride = header->original_stride_bytes;
70
70
 
71
71
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
72
- char const *source_row = (char const *)vectors + vector_index * stride;
72
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
73
73
  nk_f32_t norm_sq;
74
74
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
75
75
  (nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
@@ -83,7 +83,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_alder( //
83
83
  }
84
84
 
85
85
  NK_PUBLIC void nk_maxsim_pack_f32_alder( //
86
- nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
86
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
87
87
 
88
88
  nk_size_t const element_bytes = sizeof(nk_f32_t);
89
89
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
@@ -95,7 +95,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_alder( //
95
95
  nk_size_t const original_stride = header->original_stride_bytes;
96
96
 
97
97
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
98
- char const *source_row = (char const *)vectors + vector_index * stride;
98
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
99
99
  nk_f32_t norm_sq;
100
100
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
101
101
  &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
@@ -108,7 +108,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_alder( //
108
108
  }
109
109
 
110
110
  NK_PUBLIC void nk_maxsim_pack_f16_alder( //
111
- nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
111
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
112
112
 
113
113
  nk_size_t const element_bytes = sizeof(nk_f16_t);
114
114
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
@@ -120,7 +120,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_alder( //
120
120
  nk_size_t const original_stride = header->original_stride_bytes;
121
121
 
122
122
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
123
- char const *source_row = (char const *)vectors + vector_index * stride;
123
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
124
124
  nk_f32_t norm_sq;
125
125
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
126
126
  (nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
@@ -9,8 +9,8 @@
9
9
  * Uses AVX-512 VNNI (VPDPBUSD) for coarse i8 screening via icelake.h, and VDPBF16PS for bf16 refinement.
10
10
  * f32/f16 MaxSim variants live in icelake.h — this file only provides bf16 pack and compute.
11
11
  *
12
- * Intrinsic Instruction Genoa (Zen4)
13
- * _mm512_dpbf16_ps VDPBF16PS 6cy @ p01 (512-bit)
12
+ * Intrinsic Instruction Genoa
13
+ * _mm512_dpbf16_ps VDPBF16PS 6cy @ p01
14
14
  */
15
15
  #ifndef NK_MAXSIM_GENOA_H
16
16
  #define NK_MAXSIM_GENOA_H
@@ -41,7 +41,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_genoa(nk_size_t vector_count, nk_
41
41
  }
42
42
 
43
43
  NK_PUBLIC void nk_maxsim_pack_bf16_genoa( //
44
- nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
44
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
45
45
 
46
46
  nk_size_t const element_bytes = sizeof(nk_bf16_t);
47
47
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
@@ -53,7 +53,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_genoa( //
53
53
  nk_size_t const original_stride = header->original_stride_bytes;
54
54
 
55
55
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
56
- char const *source_row = (char const *)vectors + vector_index * stride;
56
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
57
57
  nk_f32_t norm_sq;
58
58
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
59
59
  (nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
@@ -49,7 +49,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_haswell(nk_size_t vector_count, nk
49
49
  }
50
50
 
51
51
  NK_PUBLIC void nk_maxsim_pack_bf16_haswell( //
52
- nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
52
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
53
53
 
54
54
  nk_size_t const element_bytes = sizeof(nk_bf16_t);
55
55
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
@@ -61,7 +61,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_haswell( //
61
61
  nk_size_t const original_stride = header->original_stride_bytes;
62
62
 
63
63
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
64
- char const *source_row = (char const *)vectors + vector_index * stride;
64
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
65
65
  nk_f32_t norm_sq;
66
66
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f,
67
67
  (nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
@@ -75,7 +75,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_haswell( //
75
75
  }
76
76
 
77
77
  NK_PUBLIC void nk_maxsim_pack_f32_haswell( //
78
- nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
78
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
79
79
 
80
80
  nk_size_t const element_bytes = sizeof(nk_f32_t);
81
81
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
@@ -87,7 +87,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_haswell( //
87
87
  nk_size_t const original_stride = header->original_stride_bytes;
88
88
 
89
89
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
90
- char const *source_row = (char const *)vectors + vector_index * stride;
90
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
91
91
  nk_f32_t norm_sq;
92
92
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f, nk_f32_to_f32_,
93
93
  &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
@@ -100,7 +100,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_haswell( //
100
100
  }
101
101
 
102
102
  NK_PUBLIC void nk_maxsim_pack_f16_haswell( //
103
- nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
103
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
104
104
 
105
105
  nk_size_t const element_bytes = sizeof(nk_f16_t);
106
106
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
@@ -112,7 +112,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_haswell( //
112
112
  nk_size_t const original_stride = header->original_stride_bytes;
113
113
 
114
114
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
115
- char const *source_row = (char const *)vectors + vector_index * stride;
115
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
116
116
  nk_f32_t norm_sq;
117
117
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f,
118
118
  (nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
@@ -9,15 +9,15 @@
9
9
  * Uses AVX-512 VNNI (VPDPBUSD) for coarse i8 screening. The coarse argmax kernel and reduce helper
10
10
  * are shared with genoa.h — genoa.h imports them from this file for its bf16 compute path.
11
11
  *
12
- * VPDPBUSD computes 4 groups of (u8 x i8) -> i32 per 128-bit lane, processing 64 i8 pairs
12
+ * VPDPBUSD computes 4 groups of (u8 × i8) → i32 per 128-bit lane, processing 64 i8 pairs
13
13
  * per ZMM register operation. Bias correction via XOR with 0x80 converts signed queries
14
14
  * to unsigned, then subtracts 128 * sum(document_i8) after the depth loop.
15
15
  *
16
- * 4x4 register tiling: 4 queries x 4 documents = 16 ZMM accumulators per depth loop.
16
+ * 4x4 register tiling: 4 queries × 4 documents = 16 ZMM accumulators per depth loop.
17
17
  * Each document load is amortized across 4 VPDPBUSDs, and each query load across 4 documents.
18
18
  *
19
- * Intrinsic Instruction Icelake Genoa (Zen4)
20
- * _mm512_dpbusd_epi32 VPDPBUSD 5cy @ p0 4cy @ p01 (512-bit)
19
+ * Intrinsic Instruction Icelake Genoa
20
+ * _mm512_dpbusd_epi32 VPDPBUSD 5cy @ p0 4cy @ p01
21
21
  */
22
22
  #ifndef NK_MAXSIM_ICELAKE_H
23
23
  #define NK_MAXSIM_ICELAKE_H
@@ -44,14 +44,14 @@ extern "C" {
44
44
  #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "f16c", "fma", "bmi", "bmi2")
45
45
  #endif
46
46
 
47
- #pragma region Single Precision Floats
47
+ #pragma region F32 Floats
48
48
 
49
49
  NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_icelake(nk_size_t vector_count, nk_size_t depth) {
50
50
  return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f32_t), 64);
51
51
  }
52
52
 
53
53
  NK_PUBLIC void nk_maxsim_pack_f32_icelake( //
54
- nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
54
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
55
55
 
56
56
  nk_size_t const element_bytes = sizeof(nk_f32_t);
57
57
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
@@ -63,7 +63,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_icelake( //
63
63
  nk_size_t const original_stride = header->original_stride_bytes;
64
64
 
65
65
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
66
- char const *source_row = (char const *)vectors + vector_index * stride;
66
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
67
67
  nk_f32_t norm_sq;
68
68
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
69
69
  &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
@@ -75,16 +75,16 @@ NK_PUBLIC void nk_maxsim_pack_f32_icelake( //
75
75
  }
76
76
  }
77
77
 
78
- #pragma endregion
78
+ #pragma endregion F32 Floats
79
79
 
80
- #pragma region Half Precision Floats
80
+ #pragma region F16 Floats
81
81
 
82
82
  NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_icelake(nk_size_t vector_count, nk_size_t depth) {
83
83
  return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f16_t), 64);
84
84
  }
85
85
 
86
86
  NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
87
- nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
87
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
88
88
 
89
89
  nk_size_t const element_bytes = sizeof(nk_f16_t);
90
90
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
@@ -96,7 +96,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
96
96
  nk_size_t const original_stride = header->original_stride_bytes;
97
97
 
98
98
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
99
- char const *source_row = (char const *)vectors + vector_index * stride;
99
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
100
100
  nk_f32_t norm_sq;
101
101
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
102
102
  (nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
@@ -109,7 +109,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
109
109
  }
110
110
  }
111
111
 
112
- #pragma endregion
112
+ #pragma endregion F16 Floats
113
113
 
114
114
  #pragma region Coarse Argmax
115
115
 
@@ -117,7 +117,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
117
117
  NK_INTERNAL __m128i nk_maxsim_reduce_i32x16x4_icelake_( //
118
118
  __m512i accumulator_a_i32x16, __m512i accumulator_b_i32x16, //
119
119
  __m512i accumulator_c_i32x16, __m512i accumulator_d_i32x16) {
120
- // Step 1: 16 -> 8 (extract high 256-bit half and add to low half)
120
+ // Step 1: 16 → 8 (extract high 256-bit half and add to low half)
121
121
  __m256i sum_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_a_i32x16),
122
122
  _mm512_extracti32x8_epi32(accumulator_a_i32x16, 1));
123
123
  __m256i sum_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_b_i32x16),
@@ -126,12 +126,12 @@ NK_INTERNAL __m128i nk_maxsim_reduce_i32x16x4_icelake_( //
126
126
  _mm512_extracti32x8_epi32(accumulator_c_i32x16, 1));
127
127
  __m256i sum_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_d_i32x16),
128
128
  _mm512_extracti32x8_epi32(accumulator_d_i32x16, 1));
129
- // Step 2: 8 -> 4 (extract high 128-bit half and add to low half)
129
+ // Step 2: 8 → 4 (extract high 128-bit half and add to low half)
130
130
  __m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_a_i32x8), _mm256_extracti128_si256(sum_a_i32x8, 1));
131
131
  __m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_b_i32x8), _mm256_extracti128_si256(sum_b_i32x8, 1));
132
132
  __m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_c_i32x8), _mm256_extracti128_si256(sum_c_i32x8, 1));
133
133
  __m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
134
- // Step 3: 4x4 transpose + reduce -> [sum_a, sum_b, sum_c, sum_d]
134
+ // Step 3: 4x4 transpose + reduce → [sum_a, sum_b, sum_c, sum_d]
135
135
  __m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
136
136
  __m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
137
137
  __m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
@@ -258,7 +258,7 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_icelake_( //
258
258
  query_2_coarse_dots_i32x4 = _mm_sub_epi32(query_2_coarse_dots_i32x4, bias_correction_i32x4);
259
259
  query_3_coarse_dots_i32x4 = _mm_sub_epi32(query_3_coarse_dots_i32x4, bias_correction_i32x4);
260
260
 
261
- // 4x4 transpose: [query][doc] -> [doc][query] for vectorized argmax
261
+ // 4x4 transpose: [query][doc] → [doc][query] for vectorized argmax
262
262
  __m128i transpose_queries_01_low_i32x4 = _mm_unpacklo_epi32(query_0_coarse_dots_i32x4,
263
263
  query_1_coarse_dots_i32x4);
264
264
  __m128i transpose_queries_23_low_i32x4 = _mm_unpacklo_epi32(query_2_coarse_dots_i32x4,
@@ -390,7 +390,7 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_icelake_( //
390
390
  }
391
391
  }
392
392
 
393
- #pragma endregion
393
+ #pragma endregion Coarse Argmax
394
394
 
395
395
  #pragma region Compute Functions
396
396
 
@@ -463,7 +463,7 @@ NK_PUBLIC void nk_maxsim_packed_f16_icelake( //
463
463
  *result = (nk_f32_t)total_angular_distance;
464
464
  }
465
465
 
466
- #pragma endregion
466
+ #pragma endregion Compute Functions
467
467
 
468
468
  #if defined(__clang__)
469
469
  #pragma clang attribute pop
@@ -46,7 +46,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_neonsdot(nk_size_t vector_count, n
46
46
  }
47
47
 
48
48
  NK_PUBLIC void nk_maxsim_pack_bf16_neonsdot( //
49
- nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
49
+ nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
50
50
 
51
51
  nk_size_t const element_bytes = sizeof(nk_bf16_t);
52
52
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 16, element_bytes);
@@ -58,7 +58,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_neonsdot( //
58
58
  nk_size_t const original_stride = header->original_stride_bytes;
59
59
 
60
60
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
61
- char const *source_row = (char const *)vectors + vector_index * stride;
61
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
62
62
  nk_f32_t norm_sq;
63
63
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
64
64
  (nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
@@ -72,7 +72,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_neonsdot( //
72
72
  }
73
73
 
74
74
  NK_PUBLIC void nk_maxsim_pack_f32_neonsdot( //
75
- nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
75
+ nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
76
76
 
77
77
  nk_size_t const element_bytes = sizeof(nk_f32_t);
78
78
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 16, element_bytes);
@@ -84,7 +84,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_neonsdot( //
84
84
  nk_size_t const original_stride = header->original_stride_bytes;
85
85
 
86
86
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
87
- char const *source_row = (char const *)vectors + vector_index * stride;
87
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
88
88
  nk_f32_t norm_sq;
89
89
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
90
90
  &quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
@@ -97,7 +97,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_neonsdot( //
97
97
  }
98
98
 
99
99
  NK_PUBLIC void nk_maxsim_pack_f16_neonsdot( //
100
- nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
100
+ nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
101
101
 
102
102
  nk_size_t const element_bytes = sizeof(nk_f16_t);
103
103
  nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 16, element_bytes);
@@ -109,7 +109,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_neonsdot( //
109
109
  nk_size_t const original_stride = header->original_stride_bytes;
110
110
 
111
111
  for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
112
- char const *source_row = (char const *)vectors + vector_index * stride;
112
+ char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
113
113
  nk_f32_t norm_sq;
114
114
  nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
115
115
  (nk_maxsim_to_f32_t)nk_f16_to_f32_neon,
@@ -149,39 +149,39 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_neonsdot_(
149
149
  // Depth loop: 16 bytes per step
150
150
  for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 16) {
151
151
  int8x16_t query_i8x16_0 = vld1q_s8(
152
- (int8_t const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index));
152
+ (nk_i8_t const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index));
153
153
  int8x16_t query_i8x16_1 = vld1q_s8(
154
- (int8_t const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index));
154
+ (nk_i8_t const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index));
155
155
  int8x16_t query_i8x16_2 = vld1q_s8(
156
- (int8_t const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index));
156
+ (nk_i8_t const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index));
157
157
  int8x16_t query_i8x16_3 = vld1q_s8(
158
- (int8_t const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index));
158
+ (nk_i8_t const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index));
159
159
 
160
160
  int8x16_t document_i8x16;
161
161
 
162
162
  document_i8x16 = vld1q_s8(
163
- (int8_t const *)(document_i8 + (document_block_start_index + 0) * depth_i8_padded + depth_index));
163
+ (nk_i8_t const *)(document_i8 + (document_block_start_index + 0) * depth_i8_padded + depth_index));
164
164
  accumulator_tiles_i32x4[0][0] = vdotq_s32(accumulator_tiles_i32x4[0][0], query_i8x16_0, document_i8x16);
165
165
  accumulator_tiles_i32x4[1][0] = vdotq_s32(accumulator_tiles_i32x4[1][0], query_i8x16_1, document_i8x16);
166
166
  accumulator_tiles_i32x4[2][0] = vdotq_s32(accumulator_tiles_i32x4[2][0], query_i8x16_2, document_i8x16);
167
167
  accumulator_tiles_i32x4[3][0] = vdotq_s32(accumulator_tiles_i32x4[3][0], query_i8x16_3, document_i8x16);
168
168
 
169
169
  document_i8x16 = vld1q_s8(
170
- (int8_t const *)(document_i8 + (document_block_start_index + 1) * depth_i8_padded + depth_index));
170
+ (nk_i8_t const *)(document_i8 + (document_block_start_index + 1) * depth_i8_padded + depth_index));
171
171
  accumulator_tiles_i32x4[0][1] = vdotq_s32(accumulator_tiles_i32x4[0][1], query_i8x16_0, document_i8x16);
172
172
  accumulator_tiles_i32x4[1][1] = vdotq_s32(accumulator_tiles_i32x4[1][1], query_i8x16_1, document_i8x16);
173
173
  accumulator_tiles_i32x4[2][1] = vdotq_s32(accumulator_tiles_i32x4[2][1], query_i8x16_2, document_i8x16);
174
174
  accumulator_tiles_i32x4[3][1] = vdotq_s32(accumulator_tiles_i32x4[3][1], query_i8x16_3, document_i8x16);
175
175
 
176
176
  document_i8x16 = vld1q_s8(
177
- (int8_t const *)(document_i8 + (document_block_start_index + 2) * depth_i8_padded + depth_index));
177
+ (nk_i8_t const *)(document_i8 + (document_block_start_index + 2) * depth_i8_padded + depth_index));
178
178
  accumulator_tiles_i32x4[0][2] = vdotq_s32(accumulator_tiles_i32x4[0][2], query_i8x16_0, document_i8x16);
179
179
  accumulator_tiles_i32x4[1][2] = vdotq_s32(accumulator_tiles_i32x4[1][2], query_i8x16_1, document_i8x16);
180
180
  accumulator_tiles_i32x4[2][2] = vdotq_s32(accumulator_tiles_i32x4[2][2], query_i8x16_2, document_i8x16);
181
181
  accumulator_tiles_i32x4[3][2] = vdotq_s32(accumulator_tiles_i32x4[3][2], query_i8x16_3, document_i8x16);
182
182
 
183
183
  document_i8x16 = vld1q_s8(
184
- (int8_t const *)(document_i8 + (document_block_start_index + 3) * depth_i8_padded + depth_index));
184
+ (nk_i8_t const *)(document_i8 + (document_block_start_index + 3) * depth_i8_padded + depth_index));
185
185
  accumulator_tiles_i32x4[0][3] = vdotq_s32(accumulator_tiles_i32x4[0][3], query_i8x16_0, document_i8x16);
186
186
  accumulator_tiles_i32x4[1][3] = vdotq_s32(accumulator_tiles_i32x4[1][3], query_i8x16_1, document_i8x16);
187
187
  accumulator_tiles_i32x4[2][3] = vdotq_s32(accumulator_tiles_i32x4[2][3], query_i8x16_2, document_i8x16);
@@ -211,27 +211,27 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_neonsdot_(
211
211
  int32x4_t accumulator_i32x4_3 = vdupq_n_s32(0);
212
212
 
213
213
  for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 16) {
214
- int8x16_t document_i8x16 = vld1q_s8((int8_t const *)(document_i8_row + depth_index));
214
+ int8x16_t document_i8x16 = vld1q_s8((nk_i8_t const *)(document_i8_row + depth_index));
215
215
 
216
216
  accumulator_i32x4_0 = vdotq_s32(
217
217
  accumulator_i32x4_0,
218
218
  vld1q_s8(
219
- (int8_t const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
219
+ (nk_i8_t const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
220
220
  document_i8x16);
221
221
  accumulator_i32x4_1 = vdotq_s32(
222
222
  accumulator_i32x4_1,
223
223
  vld1q_s8(
224
- (int8_t const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
224
+ (nk_i8_t const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
225
225
  document_i8x16);
226
226
  accumulator_i32x4_2 = vdotq_s32(
227
227
  accumulator_i32x4_2,
228
228
  vld1q_s8(
229
- (int8_t const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
229
+ (nk_i8_t const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
230
230
  document_i8x16);
231
231
  accumulator_i32x4_3 = vdotq_s32(
232
232
  accumulator_i32x4_3,
233
233
  vld1q_s8(
234
- (int8_t const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
234
+ (nk_i8_t const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
235
235
  document_i8x16);
236
236
  }
237
237
 
@@ -260,8 +260,8 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_neonsdot_(
260
260
  int32x4_t accumulator_i32x4 = vdupq_n_s32(0);
261
261
 
262
262
  for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 16) {
263
- int8x16_t query_i8x16 = vld1q_s8((int8_t const *)(query_i8_row + depth_index));
264
- int8x16_t document_i8x16 = vld1q_s8((int8_t const *)(document_i8_row + depth_index));
263
+ int8x16_t query_i8x16 = vld1q_s8((nk_i8_t const *)(query_i8_row + depth_index));
264
+ int8x16_t document_i8x16 = vld1q_s8((nk_i8_t const *)(document_i8_row + depth_index));
265
265
  accumulator_i32x4 = vdotq_s32(accumulator_i32x4, query_i8x16, document_i8x16);
266
266
  }
267
267