numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,336 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief NumKong pre-packed containers: `packed_matrix` for GEMM and `packed_maxsim` for MaxSim scoring.
|
|
3
|
+
* @file include/numkong/matrix.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 2026
|
|
6
|
+
*
|
|
7
|
+
* Provides a pre-packed matrix type that wraps `dots_pack` / `dots_packed` for
|
|
8
|
+
* cache-efficient matrix multiplication.
|
|
9
|
+
*
|
|
10
|
+
* @code
|
|
11
|
+
* auto b = nk::tensor<nk::f32_t>::try_zeros({256, 512});
|
|
12
|
+
* auto packed = nk::packed_matrix<nk::f32_t>::try_pack(b.view());
|
|
13
|
+
* // multiply many times with different A matrices
|
|
14
|
+
* @endcode
|
|
15
|
+
*/
|
|
16
|
+
|
|
17
|
+
#ifndef NK_MATRIX_HPP
|
|
18
|
+
#define NK_MATRIX_HPP
|
|
19
|
+
|
|
20
|
+
#include <cstring>
|
|
21
|
+
#include <type_traits>
|
|
22
|
+
|
|
23
|
+
#include "numkong/dots.h"
|
|
24
|
+
#include "numkong/maxsim.h"
|
|
25
|
+
#include "numkong/tensor.hpp"
|
|
26
|
+
|
|
27
|
+
namespace ashvardanian::numkong {
|
|
28
|
+
|
|
29
|
+
#pragma region - Packing Utilities
|
|
30
|
+
|
|
31
|
+
/**
|
|
32
|
+
* @brief Estimates the memory requirements for packed B matrix.
|
|
33
|
+
* @param[in] row_count Number of rows in B (n)
|
|
34
|
+
* @param[in] depth Number of dimensions per row (k)
|
|
35
|
+
* @return Size in bytes for row-major B data plus stride metadata
|
|
36
|
+
*
|
|
37
|
+
* @tparam in_type_ Input element type
|
|
38
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
39
|
+
*/
|
|
40
|
+
template <numeric_dtype in_type_, allow_simd_t allow_simd_ = prefer_simd_k>
|
|
41
|
+
NK_PUBLIC size_t dots_packed_size(size_t row_count, size_t depth) {
|
|
42
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k;
|
|
43
|
+
|
|
44
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd) return nk_dots_packed_size_f64(row_count, depth);
|
|
45
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd) return nk_dots_packed_size_f32(row_count, depth);
|
|
46
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd) return nk_dots_packed_size_f16(row_count, depth);
|
|
47
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd) return nk_dots_packed_size_bf16(row_count, depth);
|
|
48
|
+
else if constexpr (std::is_same_v<in_type_, i8_t> && simd) return nk_dots_packed_size_i8(row_count, depth);
|
|
49
|
+
else if constexpr (std::is_same_v<in_type_, u8_t> && simd) return nk_dots_packed_size_u8(row_count, depth);
|
|
50
|
+
else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd) return nk_dots_packed_size_e4m3(row_count, depth);
|
|
51
|
+
else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd) return nk_dots_packed_size_e5m2(row_count, depth);
|
|
52
|
+
else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd) return nk_dots_packed_size_e2m3(row_count, depth);
|
|
53
|
+
else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd) return nk_dots_packed_size_e3m2(row_count, depth);
|
|
54
|
+
else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd) return nk_dots_packed_size_u4(row_count, depth);
|
|
55
|
+
else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd) return nk_dots_packed_size_i4(row_count, depth);
|
|
56
|
+
else {
|
|
57
|
+
// We need enough space for the pointer to the original B matrix and its stride
|
|
58
|
+
return sizeof(void *) + sizeof(size_t);
|
|
59
|
+
}
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
/**
|
|
63
|
+
* @brief Packs matrix B into row-major form for efficient dots_packed access.
|
|
64
|
+
* @param[in] b Input matrix B in row-major form [row_count x depth]
|
|
65
|
+
* @param[in] row_count Number of rows in B (n)
|
|
66
|
+
* @param[in] depth Number of dimensions per row (k)
|
|
67
|
+
* @param[in] b_stride_in_bytes Stride between rows of B in bytes
|
|
68
|
+
* @param[out] b_packed Output buffer for packed row-major B with metadata
|
|
69
|
+
*
|
|
70
|
+
* @tparam in_type_ Input element type
|
|
71
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
72
|
+
*/
|
|
73
|
+
template <numeric_dtype in_type_, allow_simd_t allow_simd_ = prefer_simd_k>
|
|
74
|
+
NK_PUBLIC void dots_pack(in_type_ const *b, size_t row_count, size_t depth, size_t b_stride_in_bytes, void *b_packed) {
|
|
75
|
+
using raw_t = typename in_type_::raw_t;
|
|
76
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k;
|
|
77
|
+
|
|
78
|
+
if constexpr (std::is_same_v<in_type_, f64_t> && simd)
|
|
79
|
+
nk_dots_pack_f64(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
80
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
|
|
81
|
+
nk_dots_pack_f32(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
82
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
|
|
83
|
+
nk_dots_pack_f16(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
84
|
+
else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
|
|
85
|
+
nk_dots_pack_bf16(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
86
|
+
else if constexpr (std::is_same_v<in_type_, i8_t> && simd)
|
|
87
|
+
nk_dots_pack_i8(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
88
|
+
else if constexpr (std::is_same_v<in_type_, u8_t> && simd)
|
|
89
|
+
nk_dots_pack_u8(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
90
|
+
else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd)
|
|
91
|
+
nk_dots_pack_e4m3(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
92
|
+
else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd)
|
|
93
|
+
nk_dots_pack_e5m2(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
94
|
+
else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd)
|
|
95
|
+
nk_dots_pack_e2m3(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
96
|
+
else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd)
|
|
97
|
+
nk_dots_pack_e3m2(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
98
|
+
else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd)
|
|
99
|
+
nk_dots_pack_u4(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
100
|
+
else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd)
|
|
101
|
+
nk_dots_pack_i4(reinterpret_cast<raw_t const *>(b), row_count, depth, b_stride_in_bytes, b_packed);
|
|
102
|
+
else {
|
|
103
|
+
// Persist the pointer to the original B matrix and its stride
|
|
104
|
+
char *b_packed_bytes = reinterpret_cast<char *>(b_packed);
|
|
105
|
+
std::memcpy(b_packed_bytes, &b, sizeof(void *));
|
|
106
|
+
std::memcpy(b_packed_bytes + sizeof(void *), &b_stride_in_bytes, sizeof(size_t));
|
|
107
|
+
}
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
/**
|
|
111
|
+
* @brief Estimates the memory requirements for a maxsim packed vector set.
|
|
112
|
+
* @param[in] vector_count Number of vectors to pack.
|
|
113
|
+
* @param[in] depth Number of dimensions per vector.
|
|
114
|
+
* @return Size in bytes for the packed buffer.
|
|
115
|
+
*
|
|
116
|
+
* @tparam in_type_ Input element type (bf16_t, f32_t, f16_t).
|
|
117
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`.
|
|
118
|
+
*/
|
|
119
|
+
template <numeric_dtype in_type_, allow_simd_t allow_simd_ = prefer_simd_k>
|
|
120
|
+
NK_PUBLIC std::size_t maxsim_packed_size(std::size_t vector_count, std::size_t depth) {
|
|
121
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k;
|
|
122
|
+
|
|
123
|
+
if constexpr (std::is_same_v<in_type_, bf16_t> && simd) return nk_maxsim_packed_size_bf16(vector_count, depth);
|
|
124
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd) return nk_maxsim_packed_size_f32(vector_count, depth);
|
|
125
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd) return nk_maxsim_packed_size_f16(vector_count, depth);
|
|
126
|
+
else return sizeof(void *) + sizeof(std::size_t);
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
/**
|
|
130
|
+
* @brief Packs vectors into a backend-specific layout for maxsim computation.
|
|
131
|
+
* @param[in] vectors Input vectors in row-major order.
|
|
132
|
+
* @param[in] vector_count Number of vectors.
|
|
133
|
+
* @param[in] depth Number of dimensions per vector.
|
|
134
|
+
* @param[in] stride Row stride in bytes for the input vectors.
|
|
135
|
+
* @param[out] packed Output packed buffer from maxsim_packed_size.
|
|
136
|
+
*
|
|
137
|
+
* @tparam in_type_ Input element type (bf16_t, f32_t, f16_t).
|
|
138
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`.
|
|
139
|
+
*/
|
|
140
|
+
template <numeric_dtype in_type_, allow_simd_t allow_simd_ = prefer_simd_k>
|
|
141
|
+
NK_PUBLIC void maxsim_pack(typename in_type_::raw_t const *vectors, std::size_t vector_count, std::size_t depth,
|
|
142
|
+
std::size_t stride, void *packed) {
|
|
143
|
+
constexpr bool simd = allow_simd_ == prefer_simd_k;
|
|
144
|
+
|
|
145
|
+
if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
|
|
146
|
+
nk_maxsim_pack_bf16(vectors, vector_count, depth, stride, packed);
|
|
147
|
+
else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
|
|
148
|
+
nk_maxsim_pack_f32(vectors, vector_count, depth, stride, packed);
|
|
149
|
+
else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
|
|
150
|
+
nk_maxsim_pack_f16(vectors, vector_count, depth, stride, packed);
|
|
151
|
+
else {
|
|
152
|
+
char *packed_bytes = reinterpret_cast<char *>(packed);
|
|
153
|
+
std::memcpy(packed_bytes, &vectors, sizeof(void *));
|
|
154
|
+
std::memcpy(packed_bytes + sizeof(void *), &stride, sizeof(std::size_t));
|
|
155
|
+
}
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
#pragma endregion - Packing Utilities
|
|
159
|
+
|
|
160
|
+
#pragma region - Packed Containers
|
|
161
|
+
|
|
162
|
+
/**
|
|
163
|
+
* @brief Owning, move-only, pre-packed matrix for efficient GEMM.
|
|
164
|
+
* @tparam value_type_ Element type (e.g., f32_t, bf16_t).
|
|
165
|
+
* @tparam allocator_type_ Allocator for the packed buffer (default: aligned_allocator<char>).
|
|
166
|
+
*
|
|
167
|
+
* Wraps `dots_pack` to pre-arrange a matrix B into a cache-friendly layout.
|
|
168
|
+
* Use `try_pack()` to create from a matrix_view, then pass to `dots_packed()` for computation.
|
|
169
|
+
*/
|
|
170
|
+
template <numeric_dtype value_type_, typename allocator_type_ = aligned_allocator<char>>
|
|
171
|
+
struct packed_matrix {
|
|
172
|
+
using value_type = value_type_;
|
|
173
|
+
using result_type = typename value_type_::dot_result_t;
|
|
174
|
+
using allocator_type = allocator_type_;
|
|
175
|
+
using alloc_traits = std::allocator_traits<allocator_type_>;
|
|
176
|
+
using size_type = std::size_t;
|
|
177
|
+
|
|
178
|
+
private:
|
|
179
|
+
char *data_ = nullptr;
|
|
180
|
+
size_type size_bytes_ = 0;
|
|
181
|
+
size_type rows_ = 0; // n (number of rows in B)
|
|
182
|
+
size_type depth_ = 0; // k (number of columns in B)
|
|
183
|
+
[[no_unique_address]] allocator_type_ alloc_;
|
|
184
|
+
|
|
185
|
+
public:
|
|
186
|
+
packed_matrix() noexcept = default;
|
|
187
|
+
|
|
188
|
+
explicit packed_matrix(allocator_type_ const &alloc) noexcept : alloc_(alloc) {}
|
|
189
|
+
|
|
190
|
+
~packed_matrix() noexcept {
|
|
191
|
+
if (data_) alloc_traits::deallocate(alloc_, data_, size_bytes_);
|
|
192
|
+
}
|
|
193
|
+
|
|
194
|
+
packed_matrix(packed_matrix &&other) noexcept
|
|
195
|
+
: data_(std::exchange(other.data_, nullptr)), size_bytes_(std::exchange(other.size_bytes_, 0)),
|
|
196
|
+
rows_(std::exchange(other.rows_, 0)), depth_(std::exchange(other.depth_, 0)),
|
|
197
|
+
alloc_(std::move(other.alloc_)) {}
|
|
198
|
+
|
|
199
|
+
packed_matrix &operator=(packed_matrix &&other) noexcept {
|
|
200
|
+
if (this != &other) {
|
|
201
|
+
if (data_) alloc_traits::deallocate(alloc_, data_, size_bytes_);
|
|
202
|
+
if constexpr (alloc_traits::propagate_on_container_move_assignment::value) alloc_ = std::move(other.alloc_);
|
|
203
|
+
data_ = std::exchange(other.data_, nullptr);
|
|
204
|
+
size_bytes_ = std::exchange(other.size_bytes_, 0);
|
|
205
|
+
rows_ = std::exchange(other.rows_, 0);
|
|
206
|
+
depth_ = std::exchange(other.depth_, 0);
|
|
207
|
+
}
|
|
208
|
+
return *this;
|
|
209
|
+
}
|
|
210
|
+
|
|
211
|
+
packed_matrix(packed_matrix const &) = delete;
|
|
212
|
+
packed_matrix &operator=(packed_matrix const &) = delete;
|
|
213
|
+
|
|
214
|
+
/**
|
|
215
|
+
* @brief Pack a 2D matrix_view into cache-efficient layout.
|
|
216
|
+
* @param b 2D matrix view. Uses extents[0] as rows, extents[1] as depth.
|
|
217
|
+
* @param alloc Allocator instance.
|
|
218
|
+
* @return Non-empty packed_matrix on success, empty on failure.
|
|
219
|
+
*/
|
|
220
|
+
[[nodiscard]] static packed_matrix try_pack(matrix_view<value_type_> b, allocator_type_ alloc = {}) noexcept {
|
|
221
|
+
packed_matrix pm(alloc);
|
|
222
|
+
if (b.rank() < 2) return pm;
|
|
223
|
+
|
|
224
|
+
pm.rows_ = b.extent(0);
|
|
225
|
+
pm.depth_ = b.extent(1);
|
|
226
|
+
pm.size_bytes_ = dots_packed_size<value_type_>(pm.rows_, pm.depth_);
|
|
227
|
+
if (pm.size_bytes_ == 0) return pm;
|
|
228
|
+
|
|
229
|
+
pm.data_ = alloc_traits::allocate(pm.alloc_, pm.size_bytes_);
|
|
230
|
+
if (!pm.data_) {
|
|
231
|
+
pm.size_bytes_ = 0;
|
|
232
|
+
return pm;
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
dots_pack<value_type_>(b.data(), pm.rows_, pm.depth_, static_cast<size_type>(b.stride_bytes(0)), pm.data_);
|
|
236
|
+
return pm;
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
/** @brief Number of rows in the packed matrix (n). */
|
|
240
|
+
constexpr size_type rows() const noexcept { return rows_; }
|
|
241
|
+
|
|
242
|
+
/** @brief Number of columns / depth (k). */
|
|
243
|
+
constexpr size_type depth() const noexcept { return depth_; }
|
|
244
|
+
|
|
245
|
+
/** @brief Size of the packed buffer in bytes. */
|
|
246
|
+
constexpr size_type size_bytes() const noexcept { return size_bytes_; }
|
|
247
|
+
|
|
248
|
+
/** @brief True if no matrix is packed. */
|
|
249
|
+
constexpr bool empty() const noexcept { return data_ == nullptr; }
|
|
250
|
+
|
|
251
|
+
/** @brief Raw pointer to the packed data. */
|
|
252
|
+
constexpr void const *data() const noexcept { return data_; }
|
|
253
|
+
};
|
|
254
|
+
|
|
255
|
+
/**
|
|
256
|
+
* @brief Pre-packed vector set for MaxSim (ColBERT late-interaction).
|
|
257
|
+
*
|
|
258
|
+
* MaxSim computes Σᵢ minⱼ angular(qᵢ, dⱼ) using quantized i8 screening
|
|
259
|
+
* followed by full-precision refinement. Both queries and documents must
|
|
260
|
+
* be independently packed before calling `maxsim()`.
|
|
261
|
+
*
|
|
262
|
+
* Supported types: bf16_t, f32_t, f16_t.
|
|
263
|
+
*/
|
|
264
|
+
template <numeric_dtype value_type_, typename allocator_type_ = aligned_allocator<char>>
|
|
265
|
+
class packed_maxsim {
|
|
266
|
+
using alloc_traits = std::allocator_traits<allocator_type_>;
|
|
267
|
+
|
|
268
|
+
char *data_ = nullptr;
|
|
269
|
+
std::size_t size_bytes_ = 0;
|
|
270
|
+
std::size_t vector_count_ = 0;
|
|
271
|
+
std::size_t depth_ = 0;
|
|
272
|
+
[[no_unique_address]] allocator_type_ alloc_;
|
|
273
|
+
|
|
274
|
+
public:
|
|
275
|
+
packed_maxsim() noexcept = default;
|
|
276
|
+
explicit packed_maxsim(allocator_type_ const &alloc) noexcept : alloc_(alloc) {}
|
|
277
|
+
|
|
278
|
+
~packed_maxsim() noexcept {
|
|
279
|
+
if (data_) alloc_traits::deallocate(alloc_, data_, size_bytes_);
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
packed_maxsim(packed_maxsim &&o) noexcept
|
|
283
|
+
: data_(std::exchange(o.data_, nullptr)), size_bytes_(std::exchange(o.size_bytes_, 0)),
|
|
284
|
+
vector_count_(std::exchange(o.vector_count_, 0)), depth_(std::exchange(o.depth_, 0)),
|
|
285
|
+
alloc_(std::move(o.alloc_)) {}
|
|
286
|
+
|
|
287
|
+
packed_maxsim &operator=(packed_maxsim &&o) noexcept {
|
|
288
|
+
if (this != &o) {
|
|
289
|
+
if (data_) alloc_traits::deallocate(alloc_, data_, size_bytes_);
|
|
290
|
+
if constexpr (alloc_traits::propagate_on_container_move_assignment::value) alloc_ = std::move(o.alloc_);
|
|
291
|
+
data_ = std::exchange(o.data_, nullptr);
|
|
292
|
+
size_bytes_ = std::exchange(o.size_bytes_, 0);
|
|
293
|
+
vector_count_ = std::exchange(o.vector_count_, 0);
|
|
294
|
+
depth_ = std::exchange(o.depth_, 0);
|
|
295
|
+
}
|
|
296
|
+
return *this;
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
packed_maxsim(packed_maxsim const &) = delete;
|
|
300
|
+
packed_maxsim &operator=(packed_maxsim const &) = delete;
|
|
301
|
+
|
|
302
|
+
/** @brief Pack a 2D matrix of vectors. Returns empty on failure. */
|
|
303
|
+
[[nodiscard]] static packed_maxsim try_pack(matrix_view<value_type_> vectors, allocator_type_ alloc = {}) noexcept {
|
|
304
|
+
packed_maxsim pm(alloc);
|
|
305
|
+
if (vectors.rank() < 2) return pm;
|
|
306
|
+
|
|
307
|
+
pm.vector_count_ = vectors.extent(0);
|
|
308
|
+
pm.depth_ = vectors.extent(1);
|
|
309
|
+
pm.size_bytes_ = maxsim_packed_size<value_type_>(pm.vector_count_, pm.depth_);
|
|
310
|
+
if (pm.size_bytes_ == 0) return pm;
|
|
311
|
+
|
|
312
|
+
pm.data_ = alloc_traits::allocate(pm.alloc_, pm.size_bytes_);
|
|
313
|
+
if (!pm.data_) {
|
|
314
|
+
pm.size_bytes_ = 0;
|
|
315
|
+
return pm;
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
maxsim_pack<value_type_>(reinterpret_cast<typename value_type_::raw_t const *>(vectors.data()),
|
|
319
|
+
pm.vector_count_, pm.depth_, static_cast<std::size_t>(vectors.stride_bytes(0)),
|
|
320
|
+
pm.data_);
|
|
321
|
+
return pm;
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
std::size_t vector_count() const noexcept { return vector_count_; }
|
|
325
|
+
std::size_t rows() const noexcept { return vector_count_; }
|
|
326
|
+
std::size_t depth() const noexcept { return depth_; }
|
|
327
|
+
bool empty() const noexcept { return data_ == nullptr; }
|
|
328
|
+
void const *data() const noexcept { return data_; }
|
|
329
|
+
std::size_t size_bytes() const noexcept { return size_bytes_; }
|
|
330
|
+
};
|
|
331
|
+
|
|
332
|
+
#pragma endregion - Packed Containers
|
|
333
|
+
|
|
334
|
+
} // namespace ashvardanian::numkong
|
|
335
|
+
|
|
336
|
+
#endif // NK_MATRIX_HPP
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
# MaxSim Late-Interaction Scoring in NumKong
|
|
2
|
+
|
|
3
|
+
NumKong implements ColBERT-style late-interaction scoring: the MaxSim score sums, over each query token, the minimum angular distance to any document token. A two-stage coarse-to-fine strategy first uses i8-quantized screening to find the best-matching document token for each query token, then full-precision refinement computes the final angular distance for that pair.
|
|
4
|
+
|
|
5
|
+
MaxSim score:
|
|
6
|
+
|
|
7
|
+
```math
|
|
8
|
+
\text{MaxSim}(Q, D) = \sum_{i=0}^{m-1} \min_{j=0}^{n-1} \text{angular}(q_i, d_j)
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
Coarse screening picks the best-matching document token $j^*$ for each query token $q_i$, using i8 dot products as a proxy for the angular argmin:
|
|
12
|
+
|
|
13
|
+
```math
|
|
14
|
+
j^* = \arg\max_j \text{dot}_{\text{i8}}(q_i, d_j)
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
Full-precision refinement:
|
|
18
|
+
|
|
19
|
+
```math
|
|
20
|
+
\text{angular}(q_i, d_{j^*}) = 1 - \frac{\text{dot}(q_i, d_{j^*})}{\|q_i\| \cdot \|d_{j^*}\|}
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
Reformulating as Python pseudocode:
|
|
24
|
+
|
|
25
|
+
```python
|
|
26
|
+
import numpy as np
|
|
27
|
+
|
|
28
|
+
def maxsim(queries: np.ndarray, documents: np.ndarray) -> float:
|
|
29
|
+
score = 0.0
|
|
30
|
+
for q in queries:
|
|
31
|
+
dots = documents @ q
|
|
32
|
+
best = np.argmax(dots)
|
|
33
|
+
d = documents[best]
|
|
34
|
+
angular = 1 - np.dot(q, d) / (np.linalg.norm(q) * np.linalg.norm(d))
|
|
35
|
+
score += angular
|
|
36
|
+
return score
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
## Input & Output Types
|
|
40
|
+
|
|
41
|
+
| Input Type | Output Type | Description |
|
|
42
|
+
| ---------- | ----------- | ---------------------------------- |
|
|
43
|
+
| `bf16` | `f32` | 16-bit brain float, widened output |
|
|
44
|
+
| `f32` | `f32` | 32-bit IEEE 754 single precision |
|
|
45
|
+
| `f16` | `f32` | 16-bit IEEE 754 half precision |
|
|
46
|
+
|
|
47
|
+
## Optimizations
|
|
48
|
+
|
|
49
|
+
### Dual Pre-Packing Advantage
|
|
50
|
+
|
|
51
|
+
`nk_maxsim_packed_bf16_sme`, `nk_maxsim_packed_f32_sme` benefit from having _both_ query and document matrices pre-packed into identical contiguous formats, unlike the `nk_dots_packed_*` family where only B is pre-packed and A is accessed with arbitrary stride.
|
|
52
|
+
In the dots GEMM, one ZA tile must be reserved for A-side staging (loading unpacked A rows into the tile array), leaving 3 ZA tiles for accumulation.
|
|
53
|
+
With both sides pre-packed, all 4 ZA tiles (ZA0–ZA3) serve as accumulators — a +33% increase in MOPA throughput.
|
|
54
|
+
No output matrix materialization: dots_packed writes a full M×N f32 result matrix, while maxsim reduces each query row to a single argmax index in-flight, eliminating the M×N memory round-trip.
|
|
55
|
+
Benchmark data (Apple M4, SVL=512):
|
|
56
|
+
|
|
57
|
+
| Dimensions | dots_packed GEMM | maxsim fused | GEMM speedup | End-to-end |
|
|
58
|
+
| -------------------- | ---------------: | -----------: | -----------: | ---------: |
|
|
59
|
+
| 32×128×128 (ColBERT) | 840 GFLOPS | 1516 GFLOPS | 1.81× | 5.10× |
|
|
60
|
+
| 32×256×128 | 1037 GFLOPS | 1591 GFLOPS | 1.53× | 5.17× |
|
|
61
|
+
| 64×512×128 | 1016 GFLOPS | 1651 GFLOPS | 1.62× | 5.42× |
|
|
62
|
+
| 32×128×256 | 859 GFLOPS | 1725 GFLOPS | 2.01× | 4.06× |
|
|
63
|
+
| 32×1024×768 (BERT) | 1124 GFLOPS | 1932 GFLOPS | 1.72× | 2.61× |
|
|
64
|
+
|
|
65
|
+
End-to-end speedup (2.6–5.4×) exceeds GEMM-only speedup (1.5–2×) because maxsim eliminates output materialization and fuses argmax+angular refinement into the tile extraction loop.
|
|
66
|
+
|
|
67
|
+
### Two-Stage Coarse-to-Fine Scoring
|
|
68
|
+
|
|
69
|
+
All backends use i8-quantized coarse screening at O(m·n·k) with 1 byte/element instead of 2–4, followed by full-precision refinement at O(m·k) for only the winning pairs.
|
|
70
|
+
Break-even at ~4 documents per query — beyond that, coarse screening dominates and the i8 bandwidth advantage compounds.
|
|
71
|
+
|
|
72
|
+
### ISA-Specific Quantization Ranges
|
|
73
|
+
|
|
74
|
+
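Haswell uses [-79, 79] — `VPMADDUBSW` adds each adjacent pair of u8×i8 products into a saturating i16. After the `0x80` bias shifts queries into [49, 207], the worst-case pair must still fit: $2 \cdot (128 + 79) \cdot 79 = 32706 \le 32767$. A range of [-80, 80] would already saturate.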
|
|
75
|
+
Alder Lake and Ice Lake use [-127, 127] — `VPDPBUSD` accumulates directly to i32, no i16 bottleneck.
|
|
76
|
+
WASM v128relaxed uses [-63, 63] — `i32x4_relaxed_dot_i8x16_i7x16_add` requires 7-bit operands.
|
|
77
|
+
Serial uses [-127, 127].
|
|
78
|
+
|
|
79
|
+
### XOR-0x80 Bias Correction
|
|
80
|
+
|
|
81
|
+
`nk_maxsim_packed_bf16_haswell`, `nk_maxsim_packed_f32_alder`, `nk_maxsim_packed_f32_icelake` work around the unsigned×signed operand requirement of `VPMADDUBSW` and `VPDPBUSD`.
|
|
82
|
+
Both query and document are signed after quantization, so queries are XOR'd with `0x80` to shift to unsigned range.
|
|
83
|
+
Since $(q_k + 128) \cdot d_k = q_k \cdot d_k + 128 \cdot d_k$, the post-multiply correction subtracts $128 \cdot \sum_k d_{j,k}$ per document vector $d_j$, where these i8 sums are precomputed in the packed metadata.
|
|
84
|
+
|
|
85
|
+
### Vertical Column Extraction on SME
|
|
86
|
+
|
|
87
|
+
`nk_maxsim_packed_bf16_sme`, `nk_maxsim_packed_f32_sme` accumulate Q×D dot products into ZA tiles (4 tiles ZA0–ZA3, each SVL×SVL).
|
|
88
|
+
The argmax operation needs to find the best document for each query.
|
|
89
|
+
The naive approach reads rows horizontally (`svread_hor_za32`) and reduces each row with `svmaxv` — but `svmaxv` is a horizontal reduction costing ~8 cycles on typical SVE implementations.
|
|
90
|
+
Vertical column extraction flips the access pattern: `svread_ver_za32_f32_m` reads one _column_ of ZA, returning one dot-product score per query for a single document.
|
|
91
|
+
Element-wise `svcmpgt_f32` + `svsel_f32` (~1 cycle each) update the running maximum across all queries simultaneously.
|
|
92
|
+
For 32 queries × 256 documents: horizontal approach = 32 × 256 × `svmaxv` = 8,192 horizontal reductions; vertical approach = 256 column reads × 1 element-wise `svmax` = 256 vertical reads + 256 comparisons (~270 cycles vs ~2,048 cycles for the argmax phase alone).
|
|
93
|
+
The argmax index is tracked in-flight using `svsel` to conditionally update an index vector alongside the maximum values — no separate argmax pass needed.
|
|
94
|
+
After finding the best document index per query, full-precision angular refinement uses the originals stored in the packed buffer's third region.
|
|
95
|
+
|
|
96
|
+
### Three-Region Packed Buffer
|
|
97
|
+
|
|
98
|
+
All backends use a packed buffer made of a 64-byte header followed by three 64B-aligned regions: [Header 64B] [i8 vectors] [metadata] [originals].
|
|
99
|
+
Per-vector metadata (12 bytes) stores quantization scale, i8 sum (for bias correction), and inverse norm (for angular finalization).
|
|
100
|
+
The originals region stores full-precision vectors for refinement via existing `nk_dot_*` primitives.
|
|
101
|
+
|
|
102
|
+
## Performance
|
|
103
|
+
|
|
104
|
+
The following performance tables are produced by manually re-running the included `nk_test` and `nk_bench` internal tools, which measure both accuracy and throughput at different input shapes.
|
|
105
|
+
The input size is controlled by `NK_MATRIX_HEIGHT`, `NK_MATRIX_WIDTH`, and `NK_MATRIX_DEPTH` environment variables, all set to the same value for late-interaction scoring over square matrices.
|
|
106
|
+
Columns show throughput for 256³, 1024³, and 4096³ configurations.
|
|
107
|
+
The throughput is measured in GSO/s as Giga Scalar Operations per Second, with $\text{ops} = 2 \cdot M \cdot N \cdot K$ complexity for scoring $M$ query tokens against $N$ document tokens of dimension $K$.
|
|
108
|
+
Accuracy is reported as mean ULP (units in last place) unless noted otherwise — the average number of representable floating-point values between the result and the exact answer.
|
|
109
|
+
Each kernel runs for at least 20 seconds per configuration.
|
|
110
|
+
Benchmark threads are pinned to specific cores; on machines with heterogeneous core types (e.g., Apple P/E cores), only the fastest cores are used.
|
|
111
|
+
Workloads that significantly degrade CPU frequencies (Intel AMX, Apple SME) run in separate passes to avoid affecting throughput measurements of other kernels.
|
|
112
|
+
|
|
113
|
+
### Intel Sapphire Rapids
|
|
114
|
+
|
|
115
|
+
#### Native
|
|
116
|
+
|
|
117
|
+
| Kernel | 256³ | 1024³ | 4096³ |
|
|
118
|
+
| :---------------------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
119
|
+
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
120
|
+
| `nk_maxsim_packed_f32_serial` | 15.7 gso/s, 48.9K ulp | 15.2 gso/s, 48.9K ulp | 16.3 gso/s, 48.9K ulp |
|
|
121
|
+
| `nk_maxsim_packed_f32_haswell` | 77.2 gso/s, 49.3K ulp | 70.7 gso/s, 49.3K ulp | 74.5 gso/s, 49.3K ulp |
|
|
122
|
+
| `nk_maxsim_packed_f32_alder` | 99.7 gso/s, 48.9K ulp | 97.7 gso/s, 48.9K ulp | 94.5 gso/s, 48.9K ulp |
|
|
123
|
+
| `nk_maxsim_packed_f32_icelake` | 131 gso/s, 48.9K ulp | 124 gso/s, 48.9K ulp | 136 gso/s, 48.9K ulp |
|
|
124
|
+
| `nk_maxsim_packed_f32_sapphireamx` | 273 gso/s, 48.9K ulp | 293 gso/s, 48.9K ulp | 285 gso/s, 48.9K ulp |
|
|
125
|
+
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
126
|
+
| `nk_maxsim_packed_bf16_serial` | 15.9 gso/s, 49.0K ulp | 17.0 gso/s, 49.0K ulp | 15.3 gso/s, 49.0K ulp |
|
|
127
|
+
| `nk_maxsim_packed_bf16_haswell` | 79.2 gso/s, 49.3K ulp | 85.0 gso/s, 49.3K ulp | 81.0 gso/s, 49.3K ulp |
|
|
128
|
+
| `nk_maxsim_packed_bf16_alder` | 114 gso/s, 49.0K ulp | 110 gso/s, 49.0K ulp | 115 gso/s, 49.0K ulp |
|
|
129
|
+
| `nk_maxsim_packed_bf16_genoa` | 163 gso/s, 49.0K ulp | 165 gso/s, 49.0K ulp | 174 gso/s, 49.0K ulp |
|
|
130
|
+
| `nk_maxsim_packed_bf16_sapphireamx` | 418 gso/s, 994 ulp | 418 gso/s, 994 ulp | 445 gso/s, 994 ulp |
|
|
131
|
+
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
132
|
+
| `nk_maxsim_packed_f16_serial` | 15.5 gso/s, 49.4K ulp | 15.6 gso/s, 49.4K ulp | 16.9 gso/s, 49.4K ulp |
|
|
133
|
+
| `nk_maxsim_packed_f16_haswell` | 79.1 gso/s, 49.8K ulp | 78.1 gso/s, 49.8K ulp | 79.1 gso/s, 49.8K ulp |
|
|
134
|
+
| `nk_maxsim_packed_f16_alder` | 113 gso/s, 49.4K ulp | 112 gso/s, 49.4K ulp | 107 gso/s, 49.4K ulp |
|
|
135
|
+
| `nk_maxsim_packed_f16_icelake` | 154 gso/s, 49.4K ulp | 164 gso/s, 49.4K ulp | 163 gso/s, 49.4K ulp |
|
|
136
|
+
| `nk_maxsim_packed_f16_sapphireamx` | 339 gso/s, 49.5K ulp | 395 gso/s, 49.5K ulp | 381 gso/s, 49.5K ulp |
|
|
137
|
+
|
|
138
|
+
#### WASM
|
|
139
|
+
|
|
140
|
+
Measured with Wasmtime v42 (Cranelift backend).
|
|
141
|
+
|
|
142
|
+
| Kernel | 256³ | 1024³ | 4096³ |
|
|
143
|
+
| :---------------------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
144
|
+
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
145
|
+
| `nk_maxsim_packed_f32_serial` | ? gso/s, 46.8K ulp | ? gso/s, 46.8K ulp | ? gso/s, 46.8K ulp |
|
|
146
|
+
| `nk_maxsim_packed_f32_v128relaxed` | ? gso/s, 1.58M ulp | ? gso/s, 1.58M ulp | ? gso/s, 1.58M ulp |
|
|
147
|
+
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
148
|
+
| `nk_maxsim_packed_bf16_serial` | ? gso/s, 47.0K ulp | ? gso/s, 47.0K ulp | ? gso/s, 47.0K ulp |
|
|
149
|
+
| `nk_maxsim_packed_bf16_v128relaxed` | ? gso/s, 1.58M ulp | ? gso/s, 1.58M ulp | ? gso/s, 1.58M ulp |
|
|
150
|
+
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
151
|
+
| `nk_maxsim_packed_f16_serial` | ? gso/s, 46.4K ulp | ? gso/s, 46.4K ulp | ? gso/s, 46.4K ulp |
|
|
152
|
+
| `nk_maxsim_packed_f16_v128relaxed` | ? gso/s, 1.58M ulp | ? gso/s, 1.58M ulp | ? gso/s, 1.58M ulp |
|
|
153
|
+
|
|
154
|
+
### Apple M4
|
|
155
|
+
|
|
156
|
+
#### Native
|
|
157
|
+
|
|
158
|
+
| Kernel | 256³ | 1024³ | 4096³ |
|
|
159
|
+
| :------------------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
160
|
+
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
161
|
+
| `nk_maxsim_packed_f32_serial` | 124 gso/s, 166K ulp | 136 gso/s, 104K ulp | 130 gso/s, 55.1K ulp |
|
|
162
|
+
| `nk_maxsim_packed_f32_neonsdot` | 170 gso/s, 167K ulp | 240 gso/s, 104K ulp | 167 gso/s, 55.1K ulp |
|
|
163
|
+
| `nk_maxsim_packed_f32_sme` | 291 gso/s, 200K ulp | 1,800 gso/s, 64.6K ulp | ? gso/s, ? ulp |
|
|
164
|
+
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
165
|
+
| `nk_maxsim_packed_bf16_serial` | 135 gso/s, 167K ulp | 139 gso/s, 105K ulp | 132 gso/s, 54.8K ulp |
|
|
166
|
+
| `nk_maxsim_packed_bf16_neonsdot` | 192 gso/s, 167K ulp | 257 gso/s, 105K ulp | 161 gso/s, 54.8K ulp |
|
|
167
|
+
| `nk_maxsim_packed_bf16_sme` | 580 gso/s, 16.1K ulp | 1,620 gso/s, 735 ulp | ? gso/s, ? ulp |
|
|
168
|
+
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
169
|
+
| `nk_maxsim_packed_f16_serial` | 136 gso/s, 169K ulp | 140 gso/s, 104K ulp | 134 gso/s, 55.1K ulp |
|
|
170
|
+
| `nk_maxsim_packed_f16_neonsdot` | 193 gso/s, 166K ulp | 255 gso/s, 104K ulp | 172 gso/s, 55.1K ulp |
|
|
171
|
+
| `nk_maxsim_packed_f16_sme` | 573 gso/s, 16.0K ulp | 1,620 gso/s, 725 ulp | ? gso/s, ? ulp |
|
|
172
|
+
|
|
173
|
+
#### WASM
|
|
174
|
+
|
|
175
|
+
Measured with Wasmtime v42 (Cranelift backend).
|
|
176
|
+
|
|
177
|
+
| Kernel | 256³ | 1024³ | 4096³ |
|
|
178
|
+
| :---------------------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
179
|
+
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
180
|
+
| `nk_maxsim_packed_f32_serial` | 9.22 gso/s, 46.8K ulp | 10.1 gso/s, 46.8K ulp | 10.5 gso/s, 46.8K ulp |
|
|
181
|
+
| `nk_maxsim_packed_f32_v128relaxed` | 28.9 gso/s, 46.0K ulp | 31.2 gso/s, 46.0K ulp | 32.0 gso/s, 46.0K ulp |
|
|
182
|
+
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
183
|
+
| `nk_maxsim_packed_bf16_serial` | 8.95 gso/s, 49.2K ulp | 10.1 gso/s, 49.2K ulp | 10.0 gso/s, 49.2K ulp |
|
|
184
|
+
| `nk_maxsim_packed_bf16_v128relaxed` | 29.6 gso/s, 49.4K ulp | 31.9 gso/s, 49.4K ulp | 31.6 gso/s, 49.4K ulp |
|
|
185
|
+
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
186
|
+
| `nk_maxsim_packed_f16_serial` | 9.21 gso/s, 49.5K ulp | 10.3 gso/s, 49.5K ulp | 10.6 gso/s, 49.5K ulp |
|
|
187
|
+
| `nk_maxsim_packed_f16_v128relaxed` | 27.2 gso/s, 49.3K ulp | 33.7 gso/s, 49.3K ulp | 31.5 gso/s, 49.3K ulp |
|