numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,877 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated MaxSim (ColBERT late-interaction) for Sapphire Rapids AMX.
|
|
3
|
+
* @file include/numkong/maxsim/sapphireamx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 7, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/maxsim.h
|
|
8
|
+
*
|
|
9
|
+
* bf16: fused AMX approach using TDPBF16PS for direct bf16 dot products,
|
|
10
|
+
* with per-tile column extraction for running argmax and angular distance finalization.
|
|
11
|
+
* Uses 4 accumulator tiles (TMM4-7) for 4-way document tile pipelining.
|
|
12
|
+
*
|
|
13
|
+
* f32/f16: coarse i8 screening via AMX TDPBSSD (signed i8 × signed i8 → i32)
|
|
14
|
+
* with 4-accumulator pipeline, then full-precision refinement with nk_dot_f32/nk_dot_f16.
|
|
15
|
+
*
|
|
16
|
+
* TMM register allocation (all 3 dtypes):
|
|
17
|
+
* - TMM0: query (A-side) — loaded once per depth step
|
|
18
|
+
* - TMM1: document (B-side) — reloaded 4× per depth step (one per doc tile)
|
|
19
|
+
* - TMM4: accumulator 0 (doc tile 0)
|
|
20
|
+
* - TMM5: accumulator 1 (doc tile 1)
|
|
21
|
+
* - TMM6: accumulator 2 (doc tile 2)
|
|
22
|
+
* - TMM7: accumulator 3 (doc tile 3)
|
|
23
|
+
* - TMM2, TMM3: unused
|
|
24
|
+
*
|
|
25
|
+
* BF16 packed layout:
|
|
26
|
+
* [Header 64B] [0-63B padding for 64B alignment]
|
|
27
|
+
* [A-side tiles: col_tiles × depth_tiles × 1KB]
|
|
28
|
+
* [B-side tiles: col_tiles × depth_tiles × 1KB]
|
|
29
|
+
* [inverse norms: n × f32]
|
|
30
|
+
*
|
|
31
|
+
* i8 packed layout (f32/f16):
|
|
32
|
+
* [Header 64B] [0-63B padding for 64B alignment]
|
|
33
|
+
* [i8 A-side tiles: col_tiles × depth_tiles × 1KB]
|
|
34
|
+
* [i8 B-side tiles: col_tiles × depth_tiles × 1KB]
|
|
35
|
+
* [originals 64B-aligned: n × original_stride]
|
|
36
|
+
* [inverse norms: n × f32]
|
|
37
|
+
*
|
|
38
|
+
* Intrinsic Instruction Notes
|
|
39
|
+
* _tile_dpbf16ps TDPBF16PS C += A × B (bf16 → f32), 16×16×32 MACs
|
|
40
|
+
* _tile_dpbssd TDPBSSD C += A × B (i8 × i8 → i32), 16×16×64 MACs
|
|
41
|
+
* _tile_loadd TILELOADD Load tile from memory
|
|
42
|
+
* _tile_stored TILESTORED Store tile to memory
|
|
43
|
+
* _tile_zero TILEZERO Zero a tile register
|
|
44
|
+
*/
|
|
45
|
+
#ifndef NK_MAXSIM_SAPPHIREAMX_H
|
|
46
|
+
#define NK_MAXSIM_SAPPHIREAMX_H
|
|
47
|
+
|
|
48
|
+
#if NK_TARGET_X86_
|
|
49
|
+
#if NK_TARGET_SAPPHIREAMX
|
|
50
|
+
|
|
51
|
+
#include "numkong/types.h"
|
|
52
|
+
#include "numkong/dots/sapphireamx.h" // AMX tile types, configure, load, transpose
|
|
53
|
+
#include "numkong/dot.h" // `nk_dot_f32`, `nk_dot_f16`
|
|
54
|
+
#include "numkong/cast/haswell.h" // `nk_f16_to_f32_haswell`
|
|
55
|
+
#include "numkong/cast/serial.h" // `nk_bf16_to_f32_serial`
|
|
56
|
+
#include "numkong/scalar/haswell.h" // `nk_f32_rsqrt_haswell`
|
|
57
|
+
|
|
58
|
+
#if defined(__cplusplus)
|
|
59
|
+
extern "C" {
|
|
60
|
+
#endif
|
|
61
|
+
|
|
62
|
+
#if defined(__clang__)
|
|
63
|
+
#pragma clang attribute push( \
|
|
64
|
+
__attribute__((target( \
|
|
65
|
+
"avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512vnni,avx512vbmi,avx512bf16,avx512fp16,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8"))), \
|
|
66
|
+
apply_to = function)
|
|
67
|
+
#elif defined(__GNUC__)
|
|
68
|
+
#pragma GCC push_options
|
|
69
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "avx512vbmi", "avx512bf16", \
|
|
70
|
+
"avx512fp16", "f16c", "fma", "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8")
|
|
71
|
+
#endif
|
|
72
|
+
|
|
73
|
+
#pragma region i8 Header (for f32/f16 coarse+refine)
|
|
74
|
+
|
|
75
|
+
/**
|
|
76
|
+
* i8 packed buffer header for AMX coarse+refine MaxSim (64 bytes).
|
|
77
|
+
* Stores both A-side (row-major) and B-side (quad-interleaved) i8 tile formats,
|
|
78
|
+
* original f32/f16 vectors for full-precision refinement, and per-vector inverse norms.
|
|
79
|
+
*/
|
|
80
|
+
typedef struct {
|
|
81
|
+
nk_u32_t column_tile_count; ///< ceil(n / 16) — number of vector-tile groups
|
|
82
|
+
nk_u32_t depth_tile_count; ///< ceil(depth / 64) — TDPBSSD processes 64 i8 per tile
|
|
83
|
+
nk_u32_t columns; ///< actual vector count
|
|
84
|
+
nk_u32_t depth; ///< actual depth (dimensions per vector)
|
|
85
|
+
nk_u32_t a_side_offset; ///< byte offset from buffer start to 64B-aligned A-side tiles
|
|
86
|
+
nk_u32_t b_side_offset; ///< byte offset from buffer start to i8 B-side tiles
|
|
87
|
+
nk_u32_t originals_offset; ///< byte offset from buffer start to original f32/f16 vectors
|
|
88
|
+
nk_u32_t original_stride_bytes; ///< 64B-aligned stride for originals
|
|
89
|
+
nk_u32_t norms_offset; ///< byte offset from buffer start to f32 inverse norms
|
|
90
|
+
nk_u32_t reserved[7]; ///< padding to 64 bytes
|
|
91
|
+
} nk_maxsim_sapphireamx_i8_header_t;
|
|
92
|
+
|
|
93
|
+
NK_STATIC_ASSERT(sizeof(nk_maxsim_sapphireamx_i8_header_t) == 64, nk_maxsim_sapphireamx_i8_header_must_be_64_bytes);
|
|
94
|
+
|
|
95
|
+
#pragma endregion
|
|
96
|
+
|
|
97
|
+
#pragma region Single Precision Floats
|
|
98
|
+
|
|
99
|
+
/**
 *  @brief Bytes required by `nk_maxsim_pack_f32_sapphireamx` for `vector_count` vectors of `depth` f32 values.
 *
 *  The total is the sum of:
 *  - a 64-byte header plus up to 63 bytes of slack for aligning the tiles;
 *  - two equally sized i8 tile regions (A-side and B-side), 1 KB per 16-vector × 64-dimension tile;
 *  - the original vectors, each row rounded up to a 64-byte multiple;
 *  - one f32 inverse norm per vector.
 */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
    nk_size_t const bytes_per_tile = 1024;
    nk_size_t const tiles_per_side = nk_size_divide_round_up_(vector_count, 16) *
                                     nk_size_divide_round_up_(depth, 64);
    nk_size_t const bytes_per_original = nk_size_round_up_to_multiple_(depth * sizeof(nk_f32_t), 64);
    nk_size_t const header_with_alignment_slack = 64 + 63;
    // Each vector contributes one padded original row and one f32 inverse norm.
    nk_size_t const bytes_per_vector_tail = bytes_per_original + sizeof(nk_f32_t);
    return header_with_alignment_slack + 2 * tiles_per_side * bytes_per_tile + vector_count * bytes_per_vector_tail;
}
|
|
109
|
+
|
|
110
|
+
NK_PUBLIC void nk_maxsim_pack_f32_sapphireamx( //
|
|
111
|
+
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
|
|
112
|
+
|
|
113
|
+
nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
|
|
114
|
+
nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
|
|
115
|
+
nk_size_t original_stride_bytes = nk_size_round_up_to_multiple_(depth * sizeof(nk_f32_t), 64);
|
|
116
|
+
nk_size_t a_side_total_bytes = column_tile_count * depth_tile_count * 1024;
|
|
117
|
+
nk_size_t b_side_total_bytes = column_tile_count * depth_tile_count * 1024;
|
|
118
|
+
|
|
119
|
+
// Set up header — compute 64B-aligned A-side offset
|
|
120
|
+
nk_maxsim_sapphireamx_i8_header_t *header = (nk_maxsim_sapphireamx_i8_header_t *)packed;
|
|
121
|
+
nk_u32_t a_side_offset = (nk_u32_t)(nk_size_round_up_to_multiple_((nk_size_t)((char *)packed + 64), 64) -
|
|
122
|
+
(nk_size_t)(char *)packed);
|
|
123
|
+
header->column_tile_count = (nk_u32_t)column_tile_count;
|
|
124
|
+
header->depth_tile_count = (nk_u32_t)depth_tile_count;
|
|
125
|
+
header->columns = (nk_u32_t)vector_count;
|
|
126
|
+
header->depth = (nk_u32_t)depth;
|
|
127
|
+
header->a_side_offset = a_side_offset;
|
|
128
|
+
header->b_side_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes);
|
|
129
|
+
header->originals_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes);
|
|
130
|
+
header->original_stride_bytes = (nk_u32_t)original_stride_bytes;
|
|
131
|
+
header->norms_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes +
|
|
132
|
+
vector_count * original_stride_bytes);
|
|
133
|
+
for (nk_size_t reserved_index = 0; reserved_index < 7; reserved_index++) header->reserved[reserved_index] = 0;
|
|
134
|
+
|
|
135
|
+
// Pointers to data regions (A-side is guaranteed 64B-aligned)
|
|
136
|
+
nk_i8_t *a_side_base = (nk_i8_t *)((char *)packed + a_side_offset);
|
|
137
|
+
char *b_side_base = (char *)packed + header->b_side_offset;
|
|
138
|
+
char *originals_base = (char *)packed + header->originals_offset;
|
|
139
|
+
nk_f32_t *inverse_norms = (nk_f32_t *)((char *)packed + header->norms_offset);
|
|
140
|
+
|
|
141
|
+
// Zero all A-side tiles (aligned stores — A-side offset is 64B-aligned)
|
|
142
|
+
{
|
|
143
|
+
__m512i zero_i32x16 = _mm512_setzero_si512();
|
|
144
|
+
for (nk_size_t byte_offset = 0; byte_offset < a_side_total_bytes; byte_offset += 64)
|
|
145
|
+
_mm512_store_si512((void *)(a_side_base + byte_offset), zero_i32x16);
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
// Quantize vectors and scatter into A-side tiles, copy originals, compute inverse norms
|
|
149
|
+
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
150
|
+
nk_f32_t const *source_vector = (nk_f32_t const *)((char const *)vectors + vector_index * stride);
|
|
151
|
+
|
|
152
|
+
// Pass 1: find absmax and norm_squared
|
|
153
|
+
nk_f32_t absmax_f32 = 0.0f;
|
|
154
|
+
nk_f32_t norm_squared_f32 = 0.0f;
|
|
155
|
+
for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
|
|
156
|
+
nk_f32_t element_f32 = source_vector[dimension_index];
|
|
157
|
+
nk_f32_t abs_element_f32 = nk_f32_abs_(element_f32);
|
|
158
|
+
if (abs_element_f32 > absmax_f32) absmax_f32 = abs_element_f32;
|
|
159
|
+
norm_squared_f32 += element_f32 * element_f32;
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
// Pass 2: quantize to i8 [-127,127] and scatter into A-side tile positions
|
|
163
|
+
nk_f32_t inverse_absmax_f32 = (absmax_f32 > 0.0f) ? (1.0f / absmax_f32) : 0.0f;
|
|
164
|
+
nk_size_t column_tile_index = vector_index / 16;
|
|
165
|
+
nk_size_t row_in_tile = vector_index % 16;
|
|
166
|
+
|
|
167
|
+
for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
|
|
168
|
+
nk_f32_t element_f32 = source_vector[dimension_index];
|
|
169
|
+
nk_f32_t scaled_f32 = element_f32 * inverse_absmax_f32 * 127.0f;
|
|
170
|
+
nk_i8_t quantized_i8 = (nk_i8_t)(scaled_f32 + (element_f32 > 0.0f ? 0.5f : -0.5f));
|
|
171
|
+
|
|
172
|
+
nk_size_t depth_tile_index = dimension_index / 64;
|
|
173
|
+
nk_size_t column_in_tile = dimension_index % 64;
|
|
174
|
+
nk_size_t tile_flat_index = column_tile_index * depth_tile_count + depth_tile_index;
|
|
175
|
+
a_side_base[tile_flat_index * 1024 + row_in_tile * 64 + column_in_tile] = quantized_i8;
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
// Store inverse norm
|
|
179
|
+
inverse_norms[vector_index] = (norm_squared_f32 > 0.0f) ? nk_f32_rsqrt_haswell(norm_squared_f32) : 0.0f;
|
|
180
|
+
|
|
181
|
+
// Copy original vector with 64B-aligned stride
|
|
182
|
+
char *destination_original = originals_base + vector_index * original_stride_bytes;
|
|
183
|
+
nk_copy_bytes_(destination_original, (char const *)source_vector, depth * sizeof(nk_f32_t));
|
|
184
|
+
for (nk_size_t byte_index = depth * sizeof(nk_f32_t); byte_index < original_stride_bytes; byte_index++)
|
|
185
|
+
destination_original[byte_index] = 0;
|
|
186
|
+
}
|
|
187
|
+
|
|
188
|
+
// Transpose each A-side tile to B-side (both are 64B-aligned via header padding)
|
|
189
|
+
for (nk_size_t tile_flat_index = 0; tile_flat_index < column_tile_count * depth_tile_count; tile_flat_index++) {
|
|
190
|
+
nk_dots_i8_a16x64_sapphireamx_t const *a_tile =
|
|
191
|
+
(nk_dots_i8_a16x64_sapphireamx_t const *)(a_side_base + tile_flat_index * 1024);
|
|
192
|
+
nk_dots_i8_b64x16_sapphireamx_t *b_tile = (nk_dots_i8_b64x16_sapphireamx_t *)(b_side_base +
|
|
193
|
+
tile_flat_index * 1024);
|
|
194
|
+
nk_dots_pack_i8_transposed_sapphireamx_(a_tile, b_tile);
|
|
195
|
+
}
|
|
196
|
+
}
|
|
197
|
+
|
|
198
|
+
NK_PUBLIC void nk_maxsim_packed_f32_sapphireamx( //
|
|
199
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
200
|
+
nk_size_t depth, nk_f64_t *result) {
|
|
201
|
+
|
|
202
|
+
nk_maxsim_sapphireamx_i8_header_t const *query_header = (nk_maxsim_sapphireamx_i8_header_t const *)query_packed;
|
|
203
|
+
nk_maxsim_sapphireamx_i8_header_t const *document_header =
|
|
204
|
+
(nk_maxsim_sapphireamx_i8_header_t const *)document_packed;
|
|
205
|
+
|
|
206
|
+
nk_size_t const depth_tile_count = query_header->depth_tile_count;
|
|
207
|
+
nk_size_t const query_tile_count = query_header->column_tile_count;
|
|
208
|
+
nk_size_t const document_tile_count = document_header->column_tile_count;
|
|
209
|
+
|
|
210
|
+
// Query loads from A-side (64B-aligned), documents from B-side
|
|
211
|
+
char const *query_a_side_base = (char const *)query_packed + query_header->a_side_offset;
|
|
212
|
+
char const *document_b_side_base = (char const *)document_packed + document_header->b_side_offset;
|
|
213
|
+
|
|
214
|
+
// Original vectors for refinement
|
|
215
|
+
char const *query_originals = (char const *)query_packed + query_header->originals_offset;
|
|
216
|
+
char const *document_originals = (char const *)document_packed + document_header->originals_offset;
|
|
217
|
+
nk_size_t const query_original_stride = query_header->original_stride_bytes;
|
|
218
|
+
nk_size_t const document_original_stride = document_header->original_stride_bytes;
|
|
219
|
+
|
|
220
|
+
nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
|
|
221
|
+
nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
|
|
222
|
+
document_header->norms_offset);
|
|
223
|
+
|
|
224
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
225
|
+
|
|
226
|
+
// Gather indices for column extraction from 16×16 tile:
|
|
227
|
+
// tile_result[row][col] at i32 offset row*16 + col
|
|
228
|
+
__m512i const row_stride_indices_i32x16 = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192,
|
|
229
|
+
208, 224, 240);
|
|
230
|
+
|
|
231
|
+
nk_f64_t total_angular_distance_f64 = 0.0;
|
|
232
|
+
|
|
233
|
+
for (nk_size_t query_tile_index = 0; query_tile_index < query_tile_count; query_tile_index++) {
|
|
234
|
+
nk_size_t query_row_start = query_tile_index * 16;
|
|
235
|
+
nk_size_t valid_queries = (query_row_start + 16 <= query_count) ? 16 : (query_count - query_row_start);
|
|
236
|
+
|
|
237
|
+
__m512i running_maximum_i32x16 = _mm512_set1_epi32(NK_I32_MIN);
|
|
238
|
+
__m512i running_argmax_i32x16 = _mm512_setzero_si512();
|
|
239
|
+
|
|
240
|
+
NK_ALIGN64 nk_i32_t tile_results_i32[4][16][16];
|
|
241
|
+
nk_size_t document_tile_index = 0;
|
|
242
|
+
|
|
243
|
+
// Fast path: 4 document tiles at a time
|
|
244
|
+
for (; document_tile_index + 4 <= document_tile_count; document_tile_index += 4) {
|
|
245
|
+
_tile_zero(4);
|
|
246
|
+
_tile_zero(5);
|
|
247
|
+
_tile_zero(6);
|
|
248
|
+
_tile_zero(7);
|
|
249
|
+
|
|
250
|
+
for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
|
|
251
|
+
nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
|
|
252
|
+
|
|
253
|
+
_tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
|
|
254
|
+
|
|
255
|
+
nk_size_t document_tile_flat_0 = (document_tile_index + 0) * depth_tile_count + depth_step_index;
|
|
256
|
+
nk_size_t document_tile_flat_1 = (document_tile_index + 1) * depth_tile_count + depth_step_index;
|
|
257
|
+
nk_size_t document_tile_flat_2 = (document_tile_index + 2) * depth_tile_count + depth_step_index;
|
|
258
|
+
nk_size_t document_tile_flat_3 = (document_tile_index + 3) * depth_tile_count + depth_step_index;
|
|
259
|
+
|
|
260
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_0 * 1024), 64);
|
|
261
|
+
_tile_dpbssd(4, 0, 1);
|
|
262
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_1 * 1024), 64);
|
|
263
|
+
_tile_dpbssd(5, 0, 1);
|
|
264
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_2 * 1024), 64);
|
|
265
|
+
_tile_dpbssd(6, 0, 1);
|
|
266
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_3 * 1024), 64);
|
|
267
|
+
_tile_dpbssd(7, 0, 1);
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
_tile_stored(4, tile_results_i32[0], 64);
|
|
271
|
+
_tile_stored(5, tile_results_i32[1], 64);
|
|
272
|
+
_tile_stored(6, tile_results_i32[2], 64);
|
|
273
|
+
_tile_stored(7, tile_results_i32[3], 64);
|
|
274
|
+
|
|
275
|
+
// Column extraction from 4 tiles
|
|
276
|
+
for (nk_size_t tile_offset = 0; tile_offset < 4; tile_offset++) {
|
|
277
|
+
nk_size_t document_column_start = (document_tile_index + tile_offset) * 16;
|
|
278
|
+
for (nk_size_t column_within_tile = 0; column_within_tile < 16; column_within_tile++) {
|
|
279
|
+
__m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
|
|
280
|
+
_mm512_set1_epi32((int)column_within_tile));
|
|
281
|
+
__m512i column_dots_i32x16 = _mm512_i32gather_epi32(gather_index_i32x16,
|
|
282
|
+
tile_results_i32[tile_offset], 4);
|
|
283
|
+
__mmask16 is_better_bx16 = _mm512_cmpgt_epi32_mask(column_dots_i32x16, running_maximum_i32x16);
|
|
284
|
+
running_maximum_i32x16 = _mm512_mask_mov_epi32(running_maximum_i32x16, is_better_bx16,
|
|
285
|
+
column_dots_i32x16);
|
|
286
|
+
running_argmax_i32x16 = _mm512_mask_mov_epi32(
|
|
287
|
+
running_argmax_i32x16, is_better_bx16,
|
|
288
|
+
_mm512_set1_epi32((int)(document_column_start + column_within_tile)));
|
|
289
|
+
}
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
|
|
293
|
+
// Remainder: 1 document tile at a time
|
|
294
|
+
for (; document_tile_index < document_tile_count; document_tile_index++) {
|
|
295
|
+
nk_size_t document_column_start = document_tile_index * 16;
|
|
296
|
+
nk_size_t valid_documents = (document_column_start + 16 <= document_count)
|
|
297
|
+
? 16
|
|
298
|
+
: (document_count - document_column_start);
|
|
299
|
+
|
|
300
|
+
_tile_zero(4);
|
|
301
|
+
|
|
302
|
+
for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
|
|
303
|
+
nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
|
|
304
|
+
nk_size_t document_tile_flat_index = document_tile_index * depth_tile_count + depth_step_index;
|
|
305
|
+
|
|
306
|
+
_tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
|
|
307
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_index * 1024), 64);
|
|
308
|
+
_tile_dpbssd(4, 0, 1);
|
|
309
|
+
}
|
|
310
|
+
|
|
311
|
+
_tile_stored(4, tile_results_i32[0], 64);
|
|
312
|
+
|
|
313
|
+
for (nk_size_t column_within_tile = 0; column_within_tile < valid_documents; column_within_tile++) {
|
|
314
|
+
__m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
|
|
315
|
+
_mm512_set1_epi32((int)column_within_tile));
|
|
316
|
+
__m512i column_dots_i32x16 = _mm512_i32gather_epi32(gather_index_i32x16, tile_results_i32[0], 4);
|
|
317
|
+
__mmask16 is_better_bx16 = _mm512_cmpgt_epi32_mask(column_dots_i32x16, running_maximum_i32x16);
|
|
318
|
+
running_maximum_i32x16 = _mm512_mask_mov_epi32(running_maximum_i32x16, is_better_bx16,
|
|
319
|
+
column_dots_i32x16);
|
|
320
|
+
running_argmax_i32x16 = _mm512_mask_mov_epi32(
|
|
321
|
+
running_argmax_i32x16, is_better_bx16,
|
|
322
|
+
_mm512_set1_epi32((int)(document_column_start + column_within_tile)));
|
|
323
|
+
}
|
|
324
|
+
}
|
|
325
|
+
|
|
326
|
+
// Refinement: for each valid query, compute full-precision dot with best document
|
|
327
|
+
NK_ALIGN64 nk_i32_t best_document_indices_i32[16];
|
|
328
|
+
_mm512_store_si512(best_document_indices_i32, running_argmax_i32x16);
|
|
329
|
+
|
|
330
|
+
for (nk_size_t query_in_tile = 0; query_in_tile < valid_queries; query_in_tile++) {
|
|
331
|
+
nk_size_t query_index = query_row_start + query_in_tile;
|
|
332
|
+
nk_u32_t best_document_index = (nk_u32_t)best_document_indices_i32[query_in_tile];
|
|
333
|
+
|
|
334
|
+
nk_f64_t dot_result_f64;
|
|
335
|
+
nk_dot_f32((nk_f32_t const *)(query_originals + query_index * query_original_stride),
|
|
336
|
+
(nk_f32_t const *)(document_originals + best_document_index * document_original_stride), depth,
|
|
337
|
+
&dot_result_f64);
|
|
338
|
+
|
|
339
|
+
nk_f64_t cosine_f64 = dot_result_f64 * (nk_f64_t)query_inverse_norms[query_index] *
|
|
340
|
+
(nk_f64_t)document_inverse_norms[best_document_index];
|
|
341
|
+
nk_f64_t angular_distance_f64 = 1.0 - cosine_f64;
|
|
342
|
+
if (angular_distance_f64 < 0.0) angular_distance_f64 = 0.0;
|
|
343
|
+
total_angular_distance_f64 += angular_distance_f64;
|
|
344
|
+
}
|
|
345
|
+
}
|
|
346
|
+
|
|
347
|
+
*result = total_angular_distance_f64;
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
#pragma endregion
|
|
351
|
+
|
|
352
|
+
#pragma region Half Precision Floats
|
|
353
|
+
|
|
354
|
+
/**
 * Bytes required by `nk_maxsim_pack_f16_sapphireamx` for `vector_count` vectors of `depth` f16 values.
 *
 * The packed buffer holds, in order:
 * - a 64-byte header;
 * - up to 63 bytes of slack, used to align the A-side to 64 bytes;
 * - A-side and B-side i8 tiles of 1 KiB each (16 vectors × 64 depth lanes);
 * - the original f16 vectors on a 64-byte-rounded stride;
 * - one f32 inverse norm per vector.
 */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
    nk_size_t const tiles_per_side = nk_size_divide_round_up_(vector_count, 16) * nk_size_divide_round_up_(depth, 64);
    nk_size_t const originals_row_bytes = nk_size_round_up_to_multiple_(depth * sizeof(nk_f16_t), 64);

    nk_size_t total_bytes = 64 + 63;                    // header + worst-case alignment padding before the A-side
    total_bytes += 2 * tiles_per_side * 1024;          // A-side and B-side tiles share the same tile count
    total_bytes += vector_count * originals_row_bytes; // original f16 vectors for refinement
    total_bytes += vector_count * sizeof(nk_f32_t);    // per-vector inverse norms
    return total_bytes;
}
|
|
364
|
+
|
|
365
|
+
NK_PUBLIC void nk_maxsim_pack_f16_sapphireamx( //
|
|
366
|
+
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
|
|
367
|
+
|
|
368
|
+
nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
|
|
369
|
+
nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
|
|
370
|
+
nk_size_t original_stride_bytes = nk_size_round_up_to_multiple_(depth * sizeof(nk_f16_t), 64);
|
|
371
|
+
nk_size_t a_side_total_bytes = column_tile_count * depth_tile_count * 1024;
|
|
372
|
+
nk_size_t b_side_total_bytes = column_tile_count * depth_tile_count * 1024;
|
|
373
|
+
|
|
374
|
+
// Set up header — compute 64B-aligned A-side offset
|
|
375
|
+
nk_maxsim_sapphireamx_i8_header_t *header = (nk_maxsim_sapphireamx_i8_header_t *)packed;
|
|
376
|
+
nk_u32_t a_side_offset = (nk_u32_t)(nk_size_round_up_to_multiple_((nk_size_t)((char *)packed + 64), 64) -
|
|
377
|
+
(nk_size_t)(char *)packed);
|
|
378
|
+
header->column_tile_count = (nk_u32_t)column_tile_count;
|
|
379
|
+
header->depth_tile_count = (nk_u32_t)depth_tile_count;
|
|
380
|
+
header->columns = (nk_u32_t)vector_count;
|
|
381
|
+
header->depth = (nk_u32_t)depth;
|
|
382
|
+
header->a_side_offset = a_side_offset;
|
|
383
|
+
header->b_side_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes);
|
|
384
|
+
header->originals_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes);
|
|
385
|
+
header->original_stride_bytes = (nk_u32_t)original_stride_bytes;
|
|
386
|
+
header->norms_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes +
|
|
387
|
+
vector_count * original_stride_bytes);
|
|
388
|
+
for (nk_size_t reserved_index = 0; reserved_index < 7; reserved_index++) header->reserved[reserved_index] = 0;
|
|
389
|
+
|
|
390
|
+
// Pointers to data regions (A-side is guaranteed 64B-aligned)
|
|
391
|
+
nk_i8_t *a_side_base = (nk_i8_t *)((char *)packed + a_side_offset);
|
|
392
|
+
char *b_side_base = (char *)packed + header->b_side_offset;
|
|
393
|
+
char *originals_base = (char *)packed + header->originals_offset;
|
|
394
|
+
nk_f32_t *inverse_norms = (nk_f32_t *)((char *)packed + header->norms_offset);
|
|
395
|
+
|
|
396
|
+
// Zero all A-side tiles (aligned stores — A-side offset is 64B-aligned)
|
|
397
|
+
{
|
|
398
|
+
__m512i zero_i32x16 = _mm512_setzero_si512();
|
|
399
|
+
for (nk_size_t byte_offset = 0; byte_offset < a_side_total_bytes; byte_offset += 64)
|
|
400
|
+
_mm512_store_si512((void *)(a_side_base + byte_offset), zero_i32x16);
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
// Quantize vectors and scatter into A-side tiles, copy originals, compute inverse norms
|
|
404
|
+
nk_size_t const stride_elements = stride / sizeof(nk_f16_t);
|
|
405
|
+
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
406
|
+
nk_f16_t const *source_vector = vectors + vector_index * stride_elements;
|
|
407
|
+
|
|
408
|
+
// Pass 1: find absmax and norm_squared (convert f16 → f32)
|
|
409
|
+
nk_f32_t absmax_f32 = 0.0f;
|
|
410
|
+
nk_f32_t norm_squared_f32 = 0.0f;
|
|
411
|
+
for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
|
|
412
|
+
nk_f32_t element_f32;
|
|
413
|
+
nk_f16_to_f32_haswell(&source_vector[dimension_index], &element_f32);
|
|
414
|
+
nk_f32_t abs_element_f32 = nk_f32_abs_(element_f32);
|
|
415
|
+
if (abs_element_f32 > absmax_f32) absmax_f32 = abs_element_f32;
|
|
416
|
+
norm_squared_f32 += element_f32 * element_f32;
|
|
417
|
+
}
|
|
418
|
+
|
|
419
|
+
// Pass 2: quantize to i8 [-127,127] and scatter into A-side tile positions
|
|
420
|
+
nk_f32_t inverse_absmax_f32 = (absmax_f32 > 0.0f) ? (1.0f / absmax_f32) : 0.0f;
|
|
421
|
+
nk_size_t column_tile_index = vector_index / 16;
|
|
422
|
+
nk_size_t row_in_tile = vector_index % 16;
|
|
423
|
+
|
|
424
|
+
for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
|
|
425
|
+
nk_f32_t element_f32;
|
|
426
|
+
nk_f16_to_f32_haswell(&source_vector[dimension_index], &element_f32);
|
|
427
|
+
nk_f32_t scaled_f32 = element_f32 * inverse_absmax_f32 * 127.0f;
|
|
428
|
+
nk_i8_t quantized_i8 = (nk_i8_t)(scaled_f32 + (element_f32 > 0.0f ? 0.5f : -0.5f));
|
|
429
|
+
|
|
430
|
+
nk_size_t depth_tile_index = dimension_index / 64;
|
|
431
|
+
nk_size_t column_in_tile = dimension_index % 64;
|
|
432
|
+
nk_size_t tile_flat_index = column_tile_index * depth_tile_count + depth_tile_index;
|
|
433
|
+
a_side_base[tile_flat_index * 1024 + row_in_tile * 64 + column_in_tile] = quantized_i8;
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
// Store inverse norm
|
|
437
|
+
inverse_norms[vector_index] = (norm_squared_f32 > 0.0f) ? nk_f32_rsqrt_haswell(norm_squared_f32) : 0.0f;
|
|
438
|
+
|
|
439
|
+
// Copy original f16 vector with 64B-aligned stride
|
|
440
|
+
char *destination_original = originals_base + vector_index * original_stride_bytes;
|
|
441
|
+
nk_copy_bytes_(destination_original, (char const *)source_vector, depth * sizeof(nk_f16_t));
|
|
442
|
+
for (nk_size_t byte_index = depth * sizeof(nk_f16_t); byte_index < original_stride_bytes; byte_index++)
|
|
443
|
+
destination_original[byte_index] = 0;
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
// Transpose each A-side tile to B-side (both are 64B-aligned via header padding)
|
|
447
|
+
for (nk_size_t tile_flat_index = 0; tile_flat_index < column_tile_count * depth_tile_count; tile_flat_index++) {
|
|
448
|
+
nk_dots_i8_a16x64_sapphireamx_t const *a_tile =
|
|
449
|
+
(nk_dots_i8_a16x64_sapphireamx_t const *)(a_side_base + tile_flat_index * 1024);
|
|
450
|
+
nk_dots_i8_b64x16_sapphireamx_t *b_tile = (nk_dots_i8_b64x16_sapphireamx_t *)(b_side_base +
|
|
451
|
+
tile_flat_index * 1024);
|
|
452
|
+
nk_dots_pack_i8_transposed_sapphireamx_(a_tile, b_tile);
|
|
453
|
+
}
|
|
454
|
+
}
|
|
455
|
+
|
|
456
|
+
/**
 * MaxSim angular distance between packed f16 queries and documents (both produced by
 * `nk_maxsim_pack_f16_sapphireamx`).
 *
 * Coarse pass:
 * - Each group of 16 queries is multiplied against every document tile with i8 AMX tile
 *   products (`_tile_dpbssd`) accumulated over all depth tiles.
 * - An AVX-512 gather extracts one document column at a time.
 * - A running max/argmax keeps the best-scoring document per query; ties keep the earliest
 *   document because the comparison is strictly greater-than.
 *
 * Refinement: the chosen document is re-scored against the original f16 vectors with
 * `nk_dot_f16` and scaled by both stored inverse norms to a cosine. `max(0, 1 - cosine)` is
 * summed over all queries.
 *
 * @param query_packed     Packed query buffer.
 * @param document_packed  Packed document buffer.
 * @param query_count      Number of real queries (expected to match the packed header).
 * @param document_count   Number of real documents (expected to match the packed header).
 * @param depth            Elements per vector, forwarded to `nk_dot_f16`.
 * @param result           Receives the summed angular distance; 0 when there are no documents.
 */
NK_PUBLIC void nk_maxsim_packed_f16_sapphireamx( //
    void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
    nk_size_t depth, nk_f32_t *result) {

    nk_maxsim_sapphireamx_i8_header_t const *query_header = (nk_maxsim_sapphireamx_i8_header_t const *)query_packed;
    nk_maxsim_sapphireamx_i8_header_t const *document_header =
        (nk_maxsim_sapphireamx_i8_header_t const *)document_packed;

    nk_size_t const depth_tile_count = query_header->depth_tile_count;
    nk_size_t const query_tile_count = query_header->column_tile_count;
    nk_size_t const document_tile_count = document_header->column_tile_count;

    // With no documents there is no best match to refine. Without this guard the argmax
    // defaults to document 0, and the refinement reads past the end of the document buffer.
    if (document_count == 0 || document_tile_count == 0) {
        *result = 0.0f;
        return;
    }

    // Only tiles whose 16 columns are all real documents may take the unmasked 4-tile fast path.
    // The zero-padded columns of a partial tail tile score exactly 0. They would win the argmax
    // whenever every real similarity is negative, producing an out-of-range document index.
    nk_size_t const full_document_tile_count = (document_count / 16 < document_tile_count) ? (document_count / 16)
                                                                                          : document_tile_count;

    // Query loads from A-side (64B-aligned), documents from B-side
    char const *query_a_side_base = (char const *)query_packed + query_header->a_side_offset;
    char const *document_b_side_base = (char const *)document_packed + document_header->b_side_offset;

    // Original vectors for refinement
    char const *query_originals = (char const *)query_packed + query_header->originals_offset;
    char const *document_originals = (char const *)document_packed + document_header->originals_offset;
    nk_size_t const query_original_stride = query_header->original_stride_bytes;
    nk_size_t const document_original_stride = document_header->original_stride_bytes;

    nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
    nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
                                                                document_header->norms_offset);

    nk_amx_tile_configure_sapphireamx_();

    // Element offsets of column 0 in each of the 16 rows of a 16×16 i32 result tile.
    __m512i const row_stride_indices_i32x16 = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192,
                                                                208, 224, 240);

    nk_f64_t total_angular_distance_f64 = 0.0;

    for (nk_size_t query_tile_index = 0; query_tile_index < query_tile_count; query_tile_index++) {
        nk_size_t query_row_start = query_tile_index * 16;
        if (query_row_start >= query_count) break; // header describes more queries than requested
        nk_size_t valid_queries = (query_row_start + 16 <= query_count) ? 16 : (query_count - query_row_start);

        __m512i running_maximum_i32x16 = _mm512_set1_epi32(NK_I32_MIN);
        __m512i running_argmax_i32x16 = _mm512_setzero_si512();

        NK_ALIGN64 nk_i32_t tile_results_i32[4][16][16];
        nk_size_t document_tile_index = 0;

        // Fast path: 4 fully-populated document tiles at a time
        for (; document_tile_index + 4 <= full_document_tile_count; document_tile_index += 4) {
            _tile_zero(4);
            _tile_zero(5);
            _tile_zero(6);
            _tile_zero(7);

            for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
                nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;

                _tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);

                nk_size_t document_tile_flat_0 = (document_tile_index + 0) * depth_tile_count + depth_step_index;
                nk_size_t document_tile_flat_1 = (document_tile_index + 1) * depth_tile_count + depth_step_index;
                nk_size_t document_tile_flat_2 = (document_tile_index + 2) * depth_tile_count + depth_step_index;
                nk_size_t document_tile_flat_3 = (document_tile_index + 3) * depth_tile_count + depth_step_index;

                _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_0 * 1024), 64);
                _tile_dpbssd(4, 0, 1);
                _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_1 * 1024), 64);
                _tile_dpbssd(5, 0, 1);
                _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_2 * 1024), 64);
                _tile_dpbssd(6, 0, 1);
                _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_3 * 1024), 64);
                _tile_dpbssd(7, 0, 1);
            }

            _tile_stored(4, tile_results_i32[0], 64);
            _tile_stored(5, tile_results_i32[1], 64);
            _tile_stored(6, tile_results_i32[2], 64);
            _tile_stored(7, tile_results_i32[3], 64);

            // Column extraction from 4 tiles: all 16 columns are real documents here
            for (nk_size_t tile_offset = 0; tile_offset < 4; tile_offset++) {
                nk_size_t document_column_start = (document_tile_index + tile_offset) * 16;
                for (nk_size_t column_within_tile = 0; column_within_tile < 16; column_within_tile++) {
                    __m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
                                                                   _mm512_set1_epi32((int)column_within_tile));
                    __m512i column_dots_i32x16 = _mm512_i32gather_epi32(gather_index_i32x16,
                                                                        tile_results_i32[tile_offset], 4);
                    __mmask16 is_better_bx16 = _mm512_cmpgt_epi32_mask(column_dots_i32x16, running_maximum_i32x16);
                    running_maximum_i32x16 = _mm512_mask_mov_epi32(running_maximum_i32x16, is_better_bx16,
                                                                   column_dots_i32x16);
                    running_argmax_i32x16 = _mm512_mask_mov_epi32(
                        running_argmax_i32x16, is_better_bx16,
                        _mm512_set1_epi32((int)(document_column_start + column_within_tile)));
                }
            }
        }

        // Remainder, including any partial tail tile: one tile at a time, columns limited to real documents
        for (; document_tile_index < document_tile_count; document_tile_index++) {
            nk_size_t document_column_start = document_tile_index * 16;
            if (document_column_start >= document_count) break; // padding-only tile carries no documents
            nk_size_t valid_documents = (document_column_start + 16 <= document_count)
                                            ? 16
                                            : (document_count - document_column_start);

            _tile_zero(4);

            for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
                nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
                nk_size_t document_tile_flat_index = document_tile_index * depth_tile_count + depth_step_index;

                _tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
                _tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_index * 1024), 64);
                _tile_dpbssd(4, 0, 1);
            }

            _tile_stored(4, tile_results_i32[0], 64);

            for (nk_size_t column_within_tile = 0; column_within_tile < valid_documents; column_within_tile++) {
                __m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
                                                               _mm512_set1_epi32((int)column_within_tile));
                __m512i column_dots_i32x16 = _mm512_i32gather_epi32(gather_index_i32x16, tile_results_i32[0], 4);
                __mmask16 is_better_bx16 = _mm512_cmpgt_epi32_mask(column_dots_i32x16, running_maximum_i32x16);
                running_maximum_i32x16 = _mm512_mask_mov_epi32(running_maximum_i32x16, is_better_bx16,
                                                               column_dots_i32x16);
                running_argmax_i32x16 = _mm512_mask_mov_epi32(
                    running_argmax_i32x16, is_better_bx16,
                    _mm512_set1_epi32((int)(document_column_start + column_within_tile)));
            }
        }

        // Refinement: for each valid query, compute full-precision dot with best document
        NK_ALIGN64 nk_i32_t best_document_indices_i32[16];
        _mm512_store_si512(best_document_indices_i32, running_argmax_i32x16);

        for (nk_size_t query_in_tile = 0; query_in_tile < valid_queries; query_in_tile++) {
            nk_size_t query_index = query_row_start + query_in_tile;
            nk_u32_t best_document_index = (nk_u32_t)best_document_indices_i32[query_in_tile];

            nk_f32_t dot_result_f32;
            nk_dot_f16((nk_f16_t const *)(query_originals + query_index * query_original_stride),
                       (nk_f16_t const *)(document_originals + best_document_index * document_original_stride), depth,
                       &dot_result_f32);

            nk_f32_t cosine_f32 = dot_result_f32 * query_inverse_norms[query_index] *
                                  document_inverse_norms[best_document_index];
            nk_f32_t angular_distance_f32 = 1.0f - cosine_f32;
            if (angular_distance_f32 < 0.0f) angular_distance_f32 = 0.0f; // clamp rounding overshoot
            total_angular_distance_f64 += (nk_f64_t)angular_distance_f32;
        }
    }

    *result = (nk_f32_t)total_angular_distance_f64;
}
|
|
604
|
+
|
|
605
|
+
#pragma endregion
|
|
606
|
+
|
|
607
|
+
#pragma region Brain Floats (Fused AMX)
|
|
608
|
+
|
|
609
|
+
/**
|
|
610
|
+
* BF16 packed buffer header for AMX fused MaxSim (64 bytes).
|
|
611
|
+
* Stores both A-side (row-major) and B-side (pair-interleaved) tile formats
|
|
612
|
+
* plus per-vector inverse norms for angular distance finalization.
|
|
613
|
+
*/
|
|
614
|
+
typedef struct {
|
|
615
|
+
nk_u32_t column_tile_count; ///< ceil(n / 16) — number of row-tile groups
|
|
616
|
+
nk_u32_t depth_tile_count; ///< ceil(depth / 32) — BF16 TDPBF16PS depth granularity
|
|
617
|
+
nk_u32_t columns; ///< actual vector count
|
|
618
|
+
nk_u32_t depth; ///< actual depth (dimensions per vector)
|
|
619
|
+
nk_u32_t a_side_offset; ///< byte offset from buffer start to 64B-aligned A-side tiles
|
|
620
|
+
nk_u32_t b_side_offset; ///< byte offset from buffer start to B-side tiles
|
|
621
|
+
nk_u32_t norms_offset; ///< byte offset from buffer start to inverse norms (f32)
|
|
622
|
+
nk_u32_t reserved[9]; ///< padding to 64 bytes
|
|
623
|
+
} nk_maxsim_sapphireamx_bf16_header_t;
|
|
624
|
+
|
|
625
|
+
NK_STATIC_ASSERT(sizeof(nk_maxsim_sapphireamx_bf16_header_t) == 64, nk_maxsim_sapphireamx_bf16_header_must_be_64_bytes);
|
|
626
|
+
|
|
627
|
+
/**
 * Bytes required by `nk_maxsim_pack_bf16_sapphireamx` for `vector_count` vectors of `depth` bf16 values.
 *
 * The packed buffer holds, in order:
 * - the header;
 * - up to 63 bytes of slack, used to align the A-side to 64 bytes;
 * - A-side and B-side tiles of 1 KiB each (16 vectors × 32 bf16 values × 2 bytes);
 * - one f32 inverse norm per vector.
 */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
    nk_size_t const tiles_per_side = nk_size_divide_round_up_(vector_count, 16) * nk_size_divide_round_up_(depth, 32);
    nk_size_t const header_and_slack = sizeof(nk_maxsim_sapphireamx_bf16_header_t) + 63;
    nk_size_t const tile_region_bytes = 2 * tiles_per_side * 1024; // A-side plus B-side
    nk_size_t const norm_region_bytes = vector_count * sizeof(nk_f32_t);
    return header_and_slack + tile_region_bytes + norm_region_bytes;
}
|
|
636
|
+
|
|
637
|
+
NK_PUBLIC void nk_maxsim_pack_bf16_sapphireamx( //
|
|
638
|
+
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
|
|
639
|
+
|
|
640
|
+
nk_size_t const tile_bytes = 1024;
|
|
641
|
+
nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
|
|
642
|
+
nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
|
|
643
|
+
nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 32);
|
|
644
|
+
|
|
645
|
+
// Set up header — compute 64B-aligned A-side offset
|
|
646
|
+
nk_maxsim_sapphireamx_bf16_header_t *header = (nk_maxsim_sapphireamx_bf16_header_t *)packed;
|
|
647
|
+
nk_u32_t a_side_offset = (nk_u32_t)(nk_size_round_up_to_multiple_(
|
|
648
|
+
(nk_size_t)((char *)packed + sizeof(nk_maxsim_sapphireamx_bf16_header_t)),
|
|
649
|
+
64) -
|
|
650
|
+
(nk_size_t)(char *)packed);
|
|
651
|
+
header->column_tile_count = (nk_u32_t)column_tile_count;
|
|
652
|
+
header->depth_tile_count = (nk_u32_t)depth_tile_count;
|
|
653
|
+
header->columns = (nk_u32_t)vector_count;
|
|
654
|
+
header->depth = (nk_u32_t)depth;
|
|
655
|
+
header->a_side_offset = a_side_offset;
|
|
656
|
+
|
|
657
|
+
nk_size_t a_side_total_bytes = column_tile_count * depth_tile_count * tile_bytes;
|
|
658
|
+
nk_size_t b_side_total_bytes = column_tile_count * depth_tile_count * tile_bytes;
|
|
659
|
+
header->b_side_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes);
|
|
660
|
+
header->norms_offset = (nk_u32_t)(a_side_offset + a_side_total_bytes + b_side_total_bytes);
|
|
661
|
+
for (nk_size_t reserved_index = 0; reserved_index < 9; reserved_index++) header->reserved[reserved_index] = 0;
|
|
662
|
+
|
|
663
|
+
// Pointers to data regions (A-side is guaranteed 64B-aligned)
|
|
664
|
+
char *a_side_base = (char *)packed + a_side_offset;
|
|
665
|
+
char *b_side_base = (char *)packed + header->b_side_offset;
|
|
666
|
+
nk_f32_t *inverse_norms = (nk_f32_t *)((char *)packed + header->norms_offset);
|
|
667
|
+
|
|
668
|
+
// Pack tiles: for each column tile × depth tile, store both A-side and B-side
|
|
669
|
+
for (nk_size_t column_tile_index = 0; column_tile_index < column_tile_count; column_tile_index++) {
|
|
670
|
+
nk_size_t row_start = column_tile_index * 16;
|
|
671
|
+
nk_size_t valid_rows = (row_start + 16 <= vector_count) ? 16 : (vector_count - row_start);
|
|
672
|
+
|
|
673
|
+
for (nk_size_t depth_tile_index = 0; depth_tile_index < depth_tile_count; depth_tile_index++) {
|
|
674
|
+
nk_size_t depth_start = depth_tile_index * 32;
|
|
675
|
+
nk_size_t valid_columns = (depth_start + 32 <= depth) ? 32 : (depth - depth_start);
|
|
676
|
+
|
|
677
|
+
nk_size_t tile_flat_index = column_tile_index * depth_tile_count + depth_tile_index;
|
|
678
|
+
|
|
679
|
+
// Load source vectors into A-side tile (row-major, zero-padded)
|
|
680
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile;
|
|
681
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile, vectors + row_start * stride_elements + depth_start,
|
|
682
|
+
stride_elements, valid_rows, valid_columns);
|
|
683
|
+
|
|
684
|
+
// Store A-side tile to packed buffer
|
|
685
|
+
nk_copy_bytes_(a_side_base + tile_flat_index * tile_bytes, &a_tile, tile_bytes);
|
|
686
|
+
|
|
687
|
+
// Transpose to B-side tile (pair-interleaved) and store
|
|
688
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
689
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&a_tile, &b_tile);
|
|
690
|
+
nk_copy_bytes_(b_side_base + tile_flat_index * tile_bytes, &b_tile, tile_bytes);
|
|
691
|
+
}
|
|
692
|
+
}
|
|
693
|
+
|
|
694
|
+
// Compute inverse norms for each vector
|
|
695
|
+
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
696
|
+
nk_bf16_t const *source_vector = vectors + vector_index * stride_elements;
|
|
697
|
+
nk_f32_t norm_squared_f32 = 0.0f;
|
|
698
|
+
for (nk_size_t dimension_index = 0; dimension_index < depth; dimension_index++) {
|
|
699
|
+
nk_f32_t element_f32;
|
|
700
|
+
nk_bf16_to_f32_serial(&source_vector[dimension_index], &element_f32);
|
|
701
|
+
norm_squared_f32 += element_f32 * element_f32;
|
|
702
|
+
}
|
|
703
|
+
inverse_norms[vector_index] = (norm_squared_f32 > 0.0f) ? nk_f32_rsqrt_haswell(norm_squared_f32) : 0.0f;
|
|
704
|
+
}
|
|
705
|
+
}
|
|
706
|
+
|
|
707
|
+
/**
|
|
708
|
+
* BF16 fused AMX compute: TDPBF16PS tile multiply + column extraction + angular finalization.
|
|
709
|
+
*
|
|
710
|
+
* For each group of 16 queries, processes all document tiles via AMX TDPBF16PS.
|
|
711
|
+
* Fast path uses 4 accumulators (TMM4-7) for 4-way document tile pipelining.
|
|
712
|
+
* Column extraction from the 16×16 f32 accumulator tiles uses AVX-512 gather
|
|
713
|
+
* to build per-document dot product vectors, then element-wise max tracks the
|
|
714
|
+
* running best document per query.
|
|
715
|
+
*/
|
|
716
|
+
NK_PUBLIC void nk_maxsim_packed_bf16_sapphireamx( //
|
|
717
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
718
|
+
nk_size_t depth, nk_f32_t *result) {
|
|
719
|
+
|
|
720
|
+
nk_unused_(depth); // tile counts from header encode depth
|
|
721
|
+
|
|
722
|
+
nk_maxsim_sapphireamx_bf16_header_t const *query_header = (nk_maxsim_sapphireamx_bf16_header_t const *)query_packed;
|
|
723
|
+
nk_maxsim_sapphireamx_bf16_header_t const *document_header =
|
|
724
|
+
(nk_maxsim_sapphireamx_bf16_header_t const *)document_packed;
|
|
725
|
+
|
|
726
|
+
nk_size_t const depth_tile_count = query_header->depth_tile_count;
|
|
727
|
+
nk_size_t const query_column_tile_count = query_header->column_tile_count;
|
|
728
|
+
nk_size_t const document_column_tile_count = document_header->column_tile_count;
|
|
729
|
+
|
|
730
|
+
// Query loads from A-side tiles (64B-aligned), documents from B-side tiles
|
|
731
|
+
char const *query_a_side_base = (char const *)query_packed + query_header->a_side_offset;
|
|
732
|
+
char const *document_b_side_base = (char const *)document_packed + document_header->b_side_offset;
|
|
733
|
+
|
|
734
|
+
nk_f32_t const *query_inverse_norms = (nk_f32_t const *)((char const *)query_packed + query_header->norms_offset);
|
|
735
|
+
nk_f32_t const *document_inverse_norms = (nk_f32_t const *)((char const *)document_packed +
|
|
736
|
+
document_header->norms_offset);
|
|
737
|
+
|
|
738
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
739
|
+
|
|
740
|
+
// Gather indices for column extraction from 16×16 f32 tile:
|
|
741
|
+
// tile_result[row][col] is at f32 offset row*16 + col
|
|
742
|
+
__m512i const row_stride_indices_i32x16 = _mm512_setr_epi32(0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192,
|
|
743
|
+
208, 224, 240);
|
|
744
|
+
|
|
745
|
+
nk_f64_t total_angular_distance_f64 = 0.0;
|
|
746
|
+
|
|
747
|
+
for (nk_size_t query_tile_index = 0; query_tile_index < query_column_tile_count; query_tile_index++) {
|
|
748
|
+
nk_size_t query_row_start = query_tile_index * 16;
|
|
749
|
+
nk_size_t valid_queries = (query_row_start + 16 <= query_count) ? 16 : (query_count - query_row_start);
|
|
750
|
+
__mmask16 valid_query_mask_bx16 = (valid_queries >= 16) ? (__mmask16)0xFFFF
|
|
751
|
+
: (__mmask16)((1u << valid_queries) - 1);
|
|
752
|
+
|
|
753
|
+
__m512 running_maximum_f32x16 = _mm512_set1_ps(NK_F32_MIN);
|
|
754
|
+
__m512i running_argmax_i32x16 = _mm512_setzero_si512();
|
|
755
|
+
|
|
756
|
+
NK_ALIGN64 nk_f32_t tile_results_f32[4][16][16];
|
|
757
|
+
nk_size_t document_tile_index = 0;
|
|
758
|
+
|
|
759
|
+
// Fast path: 4 document tiles at a time using TMM4-7
|
|
760
|
+
for (; document_tile_index + 4 <= document_column_tile_count; document_tile_index += 4) {
|
|
761
|
+
_tile_zero(4);
|
|
762
|
+
_tile_zero(5);
|
|
763
|
+
_tile_zero(6);
|
|
764
|
+
_tile_zero(7);
|
|
765
|
+
|
|
766
|
+
for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
|
|
767
|
+
nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
|
|
768
|
+
|
|
769
|
+
_tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
|
|
770
|
+
|
|
771
|
+
nk_size_t document_tile_flat_0 = (document_tile_index + 0) * depth_tile_count + depth_step_index;
|
|
772
|
+
nk_size_t document_tile_flat_1 = (document_tile_index + 1) * depth_tile_count + depth_step_index;
|
|
773
|
+
nk_size_t document_tile_flat_2 = (document_tile_index + 2) * depth_tile_count + depth_step_index;
|
|
774
|
+
nk_size_t document_tile_flat_3 = (document_tile_index + 3) * depth_tile_count + depth_step_index;
|
|
775
|
+
|
|
776
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_0 * 1024), 64);
|
|
777
|
+
_tile_dpbf16ps(4, 0, 1);
|
|
778
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_1 * 1024), 64);
|
|
779
|
+
_tile_dpbf16ps(5, 0, 1);
|
|
780
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_2 * 1024), 64);
|
|
781
|
+
_tile_dpbf16ps(6, 0, 1);
|
|
782
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_3 * 1024), 64);
|
|
783
|
+
_tile_dpbf16ps(7, 0, 1);
|
|
784
|
+
}
|
|
785
|
+
|
|
786
|
+
_tile_stored(4, tile_results_f32[0], 64);
|
|
787
|
+
_tile_stored(5, tile_results_f32[1], 64);
|
|
788
|
+
_tile_stored(6, tile_results_f32[2], 64);
|
|
789
|
+
_tile_stored(7, tile_results_f32[3], 64);
|
|
790
|
+
|
|
791
|
+
// Column extraction from 4 tiles
|
|
792
|
+
for (nk_size_t tile_offset = 0; tile_offset < 4; tile_offset++) {
|
|
793
|
+
nk_size_t document_column_start = (document_tile_index + tile_offset) * 16;
|
|
794
|
+
for (nk_size_t column_within_tile = 0; column_within_tile < 16; column_within_tile++) {
|
|
795
|
+
__m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
|
|
796
|
+
_mm512_set1_epi32((int)column_within_tile));
|
|
797
|
+
__m512 column_dots_f32x16 = _mm512_i32gather_ps(gather_index_i32x16,
|
|
798
|
+
(float const *)tile_results_f32[tile_offset], 4);
|
|
799
|
+
__mmask16 is_better_bx16 = _mm512_cmp_ps_mask(column_dots_f32x16, running_maximum_f32x16,
|
|
800
|
+
_CMP_GT_OQ);
|
|
801
|
+
running_maximum_f32x16 = _mm512_mask_mov_ps(running_maximum_f32x16, is_better_bx16,
|
|
802
|
+
column_dots_f32x16);
|
|
803
|
+
running_argmax_i32x16 = _mm512_mask_mov_epi32(
|
|
804
|
+
running_argmax_i32x16, is_better_bx16,
|
|
805
|
+
_mm512_set1_epi32((int)(document_column_start + column_within_tile)));
|
|
806
|
+
}
|
|
807
|
+
}
|
|
808
|
+
}
|
|
809
|
+
|
|
810
|
+
// Remainder: 1 document tile at a time using TMM4 only
|
|
811
|
+
for (; document_tile_index < document_column_tile_count; document_tile_index++) {
|
|
812
|
+
nk_size_t document_column_start = document_tile_index * 16;
|
|
813
|
+
nk_size_t valid_documents = (document_column_start + 16 <= document_count)
|
|
814
|
+
? 16
|
|
815
|
+
: (document_count - document_column_start);
|
|
816
|
+
|
|
817
|
+
_tile_zero(4);
|
|
818
|
+
|
|
819
|
+
for (nk_size_t depth_step_index = 0; depth_step_index < depth_tile_count; depth_step_index++) {
|
|
820
|
+
nk_size_t query_tile_flat_index = query_tile_index * depth_tile_count + depth_step_index;
|
|
821
|
+
nk_size_t document_tile_flat_index = document_tile_index * depth_tile_count + depth_step_index;
|
|
822
|
+
|
|
823
|
+
_tile_loadd(0, (void const *)(query_a_side_base + query_tile_flat_index * 1024), 64);
|
|
824
|
+
_tile_loadd(1, (void const *)(document_b_side_base + document_tile_flat_index * 1024), 64);
|
|
825
|
+
_tile_dpbf16ps(4, 0, 1);
|
|
826
|
+
}
|
|
827
|
+
|
|
828
|
+
_tile_stored(4, tile_results_f32[0], 64);
|
|
829
|
+
|
|
830
|
+
for (nk_size_t column_within_tile = 0; column_within_tile < valid_documents; column_within_tile++) {
|
|
831
|
+
__m512i gather_index_i32x16 = _mm512_add_epi32(row_stride_indices_i32x16,
|
|
832
|
+
_mm512_set1_epi32((int)column_within_tile));
|
|
833
|
+
__m512 column_dots_f32x16 = _mm512_i32gather_ps(gather_index_i32x16, (float const *)tile_results_f32[0],
|
|
834
|
+
4);
|
|
835
|
+
__mmask16 is_better_bx16 = _mm512_cmp_ps_mask(column_dots_f32x16, running_maximum_f32x16, _CMP_GT_OQ);
|
|
836
|
+
running_maximum_f32x16 = _mm512_mask_mov_ps(running_maximum_f32x16, is_better_bx16, column_dots_f32x16);
|
|
837
|
+
running_argmax_i32x16 = _mm512_mask_mov_epi32(
|
|
838
|
+
running_argmax_i32x16, is_better_bx16,
|
|
839
|
+
_mm512_set1_epi32((int)(document_column_start + column_within_tile)));
|
|
840
|
+
}
|
|
841
|
+
}
|
|
842
|
+
|
|
843
|
+
// Angular distance finalization using AVX-512
|
|
844
|
+
__m512 query_inverse_norms_f32x16 = _mm512_maskz_loadu_ps(valid_query_mask_bx16,
|
|
845
|
+
query_inverse_norms + query_row_start);
|
|
846
|
+
__m512 document_inverse_norms_f32x16 = _mm512_i32gather_ps(running_argmax_i32x16, document_inverse_norms, 4);
|
|
847
|
+
|
|
848
|
+
// cosine = dot × inv_norm_q × inv_norm_d
|
|
849
|
+
__m512 cosine_f32x16 = _mm512_mul_ps(_mm512_mul_ps(running_maximum_f32x16, query_inverse_norms_f32x16),
|
|
850
|
+
document_inverse_norms_f32x16);
|
|
851
|
+
|
|
852
|
+
// angular = max(1 - cosine, 0), masked to valid queries only
|
|
853
|
+
__m512 angular_distance_f32x16 = _mm512_max_ps(_mm512_sub_ps(_mm512_set1_ps(1.0f), cosine_f32x16),
|
|
854
|
+
_mm512_setzero_ps());
|
|
855
|
+
angular_distance_f32x16 = _mm512_maskz_mov_ps(valid_query_mask_bx16, angular_distance_f32x16);
|
|
856
|
+
|
|
857
|
+
total_angular_distance_f64 += (nk_f64_t)_mm512_reduce_add_ps(angular_distance_f32x16);
|
|
858
|
+
}
|
|
859
|
+
|
|
860
|
+
*result = (nk_f32_t)total_angular_distance_f64;
|
|
861
|
+
}
|
|
862
|
+
|
|
863
|
+
#pragma endregion
|
|
864
|
+
|
|
865
|
+
#if defined(__clang__)
|
|
866
|
+
#pragma clang attribute pop
|
|
867
|
+
#elif defined(__GNUC__)
|
|
868
|
+
#pragma GCC pop_options
|
|
869
|
+
#endif
|
|
870
|
+
|
|
871
|
+
#if defined(__cplusplus)
|
|
872
|
+
} // extern "C"
|
|
873
|
+
#endif
|
|
874
|
+
|
|
875
|
+
#endif // NK_TARGET_SAPPHIREAMX
|
|
876
|
+
#endif // NK_TARGET_X86_
|
|
877
|
+
#endif // NK_MAXSIM_SAPPHIREAMX_H
|