numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,639 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief C++ bindings for multi-target dot-product kernels.
|
|
3
|
+
* @file include/numkong/dots.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 5, 2026
|
|
6
|
+
*/
|
|
7
|
+
#ifndef NK_DOTS_HPP
|
|
8
|
+
#define NK_DOTS_HPP
|
|
9
|
+
|
|
10
|
+
#include <bit>
|
|
11
|
+
#include <cstdint>
|
|
12
|
+
#include <cstring>
|
|
13
|
+
#include <limits>
|
|
14
|
+
#include <type_traits>
|
|
15
|
+
|
|
16
|
+
#include "numkong/dot.h"
|
|
17
|
+
#include "numkong/dots.h"
|
|
18
|
+
#include "numkong/sets.h"
|
|
19
|
+
|
|
20
|
+
#include "numkong/types.hpp"
|
|
21
|
+
|
|
22
|
+
namespace ashvardanian::numkong {
|
|
23
|
+
|
|
24
|
+
/**
|
|
25
|
+
* @brief Reference unpacked GEMM: C = A × Bᵀ (row-major A and B, B transposed).
|
|
26
|
+
*
|
|
27
|
+
* This matches BLAS sgemm/dgemm with CblasNoTrans for A and CblasTrans for B.
|
|
28
|
+
* Useful as a reference implementation for validating BLAS/MKL/Accelerate.
|
|
29
|
+
*
|
|
30
|
+
* @param a Matrix A [m x k] row-major
|
|
31
|
+
* @param b Matrix B [n x k] row-major (accessed as Bᵀ)
|
|
32
|
+
* @param c Output matrix C [m x n] row-major
|
|
33
|
+
* @param row_count Rows of A and C (m)
|
|
34
|
+
* @param column_count Rows of B and columns of C (n)
|
|
35
|
+
* @param depth Columns of A and B (k)
|
|
36
|
+
* @param a_stride_in_bytes Stride between rows of A in bytes
|
|
37
|
+
* @param b_stride_in_bytes Stride between rows of B in bytes
|
|
38
|
+
* @param c_stride_in_bytes Stride between rows of C in bytes
|
|
39
|
+
* @tparam in_type_ Input element type (e.g., f32_t, bf16_t)
|
|
40
|
+
* @tparam result_type_ Accumulator/output type (e.g., f32_t, f118_t for high precision)
|
|
41
|
+
*/
|
|
42
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t>
|
|
43
|
+
void dots_unpacked(in_type_ const *a, in_type_ const *b, result_type_ *c, size_t row_count, size_t column_count,
|
|
44
|
+
size_t depth, size_t a_stride_in_bytes, size_t b_stride_in_bytes,
|
|
45
|
+
size_t c_stride_in_bytes) noexcept {
|
|
46
|
+
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
47
|
+
char const *b_bytes = reinterpret_cast<char const *>(b);
|
|
48
|
+
char *c_bytes = reinterpret_cast<char *>(c);
|
|
49
|
+
std::size_t const depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
|
|
50
|
+
|
|
51
|
+
for (size_t i = 0; i < row_count; i++) {
|
|
52
|
+
in_type_ const *a_row = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
|
|
53
|
+
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
54
|
+
for (size_t j = 0; j < column_count; j++) {
|
|
55
|
+
in_type_ const *b_row = reinterpret_cast<in_type_ const *>(b_bytes + j * b_stride_in_bytes);
|
|
56
|
+
result_type_ sum {};
|
|
57
|
+
for (size_t l = 0; l < depth_values; l++) sum = fma(a_row[l], b_row[l], sum);
|
|
58
|
+
c_row[j] = sum;
|
|
59
|
+
}
|
|
60
|
+
}
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
/**
|
|
64
|
+
* @brief Conjugated unpacked dot products: C = A × Bᴴ (Hermitian inner product, row-major)
|
|
65
|
+
*
|
|
66
|
+
* Same as `dots_unpacked`, but conjugates elements of B before multiplication.
|
|
67
|
+
* For real types this is identical to `dots_unpacked`. For complex types this
|
|
68
|
+
* computes the standard Hermitian inner product matching `cblas_{c,z}gemm` with
|
|
69
|
+
* `CblasConjTrans`.
|
|
70
|
+
*/
|
|
71
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t>
|
|
72
|
+
void dots_unpacked_conjugated(in_type_ const *a, in_type_ const *b, result_type_ *c, size_t row_count,
|
|
73
|
+
size_t column_count, size_t depth, size_t a_stride_in_bytes, size_t b_stride_in_bytes,
|
|
74
|
+
size_t c_stride_in_bytes) noexcept {
|
|
75
|
+
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
76
|
+
char const *b_bytes = reinterpret_cast<char const *>(b);
|
|
77
|
+
char *c_bytes = reinterpret_cast<char *>(c);
|
|
78
|
+
std::size_t const depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
|
|
79
|
+
|
|
80
|
+
for (size_t i = 0; i < row_count; i++) {
|
|
81
|
+
in_type_ const *a_row = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
|
|
82
|
+
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
83
|
+
for (size_t j = 0; j < column_count; j++) {
|
|
84
|
+
in_type_ const *b_row = reinterpret_cast<in_type_ const *>(b_bytes + j * b_stride_in_bytes);
|
|
85
|
+
result_type_ sum {};
|
|
86
|
+
for (size_t l = 0; l < depth_values; l++) sum = fcma(b_row[l], a_row[l], sum);
|
|
87
|
+
c_row[j] = sum;
|
|
88
|
+
}
|
|
89
|
+
}
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
/**
 * @brief Packed dot products (batch matrix multiply): C = A × B (row-major)
 * @param[in] a Matrix A [m x k]
 * @param[in] b_packed Packed matrix B [k x n] with stride metadata appended
 * @param[out] c Output matrix C [m x n]
 * @param[in] row_count Rows of A and C (m)
 * @param[in] column_count Columns of B and C (n)
 * @param[in] depth Columns of A, Rows of B (k)
 * @param[in] a_stride_in_bytes Stride between rows of A in bytes
 * @param[in] c_stride_in_bytes Stride between rows of C in bytes
 *
 * @tparam in_type_ Input element type
 * @tparam result_type_ Accumulator/output type, defaults to `in_type_::dot_result_t`
 * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
 */
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t,
          allow_simd_t allow_simd_ = prefer_simd_k>
void dots_packed(in_type_ const *a, void const *b_packed, result_type_ *c, size_t row_count, size_t column_count,
                 size_t depth, size_t a_stride_in_bytes, size_t c_stride_in_bytes) noexcept {
    // SIMD dispatch is only legal when the caller keeps the default accumulator
    // type, because the C kernels below emit `in_type_::dot_result_t` outputs.
    constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
                              std::is_same_v<result_type_, typename in_type_::dot_result_t>;
    // Compile-time dispatch onto the matching C kernel; exactly one branch survives.
    if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
        nk_dots_packed_f64(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                           c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
        nk_dots_packed_f32(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                           c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
        nk_dots_packed_f16(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                           c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
        nk_dots_packed_bf16(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                            c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
        nk_dots_packed_i8(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                          c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
        nk_dots_packed_u8(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                          c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
        nk_dots_packed_e4m3(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                            c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
        nk_dots_packed_e5m2(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                            c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
        nk_dots_packed_e2m3(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                            c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
        nk_dots_packed_e3m2(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                            c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
        nk_dots_packed_u4(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                          c_stride_in_bytes);
    else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
        nk_dots_packed_i4(&a->raw_, b_packed, &c->raw_, row_count, column_count, depth, a_stride_in_bytes,
                          c_stride_in_bytes);
    else {
        // Serial fallback: this path reads `b_packed` as a {pointer, stride} header
        // referring back to the original unpacked B — presumably laid down by the
        // matching pack routine (not visible in this file; confirm against it).
        in_type_ const *b;
        size_t b_stride_in_bytes;
        char const *b_packed_bytes = reinterpret_cast<char const *>(b_packed);
        // memcpy avoids unaligned/aliasing UB when extracting the header fields.
        std::memcpy(&b, b_packed_bytes, sizeof(void *));
        std::memcpy(&b_stride_in_bytes, b_packed_bytes + sizeof(void *), sizeof(size_t));
        dots_unpacked<in_type_, result_type_>(a, b, c, row_count, column_count, depth, a_stride_in_bytes,
                                              b_stride_in_bytes, c_stride_in_bytes);
    }
}
|
|
159
|
+
|
|
160
|
+
/**
 * @brief Symmetric dot products: C = A × Aᵀ where C[i,j] = ⟨A[i], A[j]⟩
 * @param[in] a Matrix A [n x k] (n vectors of dimension k)
 * @param[in] n_vectors Number of vectors (n)
 * @param[in] depth Dimension of each vector (k)
 * @param[in] a_stride_in_bytes Stride between vectors in A
 * @param[out] c Output matrix C [n x n]
 * @param[in] c_stride_in_bytes Stride between rows of C in bytes
 * @param[in] row_start First row of C to compute (default 0), enabling tiled/parallel use
 * @param[in] row_count Number of rows to compute; the `size_t` max sentinel means "all rows"
 *
 * @tparam in_type_ Input element type
 * @tparam result_type_ Accumulator/output type, defaults to `in_type_::dot_result_t`
 * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
 */
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t,
          allow_simd_t allow_simd_ = prefer_simd_k>
void dots_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth, std::size_t a_stride_in_bytes,
                    result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
                    std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
    // Expand the "all rows" sentinel before dispatching to the C kernels.
    if (row_count == std::numeric_limits<std::size_t>::max()) row_count = n_vectors;
    // SIMD kernels only produce the default accumulator type; custom accumulators
    // must take the serial fallback below.
    constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
                              std::is_same_v<result_type_, typename in_type_::dot_result_t>;

    // Compile-time dispatch onto the matching C kernel; exactly one branch survives.
    if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
        nk_dots_symmetric_f64(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                              row_count);
    else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
        nk_dots_symmetric_f32(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                              row_count);
    else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
        nk_dots_symmetric_f16(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                              row_count);
    else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
        nk_dots_symmetric_bf16(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                               row_count);
    else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
        nk_dots_symmetric_i8(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                             row_count);
    else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
        nk_dots_symmetric_u8(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                             row_count);
    else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
        nk_dots_symmetric_e4m3(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                               row_count);
    else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
        nk_dots_symmetric_e5m2(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                               row_count);
    else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
        nk_dots_symmetric_e2m3(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                               row_count);
    else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
        nk_dots_symmetric_e3m2(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                               row_count);
    else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
        nk_dots_symmetric_u4(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                             row_count);
    else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
        nk_dots_symmetric_i4(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
                             row_count);
    else {
        // Serial fallback: brute-force all pairs within the requested row window.
        // Stored values per row, accounting for sub-byte packed dtypes.
        std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
        char const *a_bytes = reinterpret_cast<char const *>(a);
        char *c_bytes = reinterpret_cast<char *>(c);
        // Clamp the window to the matrix; NOTE(review): `row_start + row_count`
        // could wrap for near-max non-sentinel counts — callers pass sane values.
        std::size_t row_end = row_start + row_count < n_vectors ? row_start + row_count : n_vectors;

        for (std::size_t i = row_start; i < row_end; i++) {
            in_type_ const *a_i = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
            result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
            // Full row is computed (not just the upper triangle), so C is filled
            // symmetrically without a mirroring pass.
            for (std::size_t j = 0; j < n_vectors; j++) {
                in_type_ const *a_j = reinterpret_cast<in_type_ const *>(a_bytes + j * a_stride_in_bytes);
                result_type_ sum {};
                for (std::size_t l = 0; l < depth_values; l++) sum = fma(a_i[l], a_j[l], sum);
                c_row[j] = sum;
            }
        }
    }
}
|
|
236
|
+
|
|
237
|
+
/**
|
|
238
|
+
* @brief Symmetric Hamming distance matrix: C[i,j] = hamming(A[i], A[j])
|
|
239
|
+
* @param[in] a Input matrix (n_vectors x depth)
|
|
240
|
+
* @param[in] n_vectors Number of vectors
|
|
241
|
+
* @param[in] depth Number of dimensions per vector
|
|
242
|
+
* @param[in] a_stride_in_bytes Row stride in bytes
|
|
243
|
+
* @param[out] c Output matrix (n_vectors x n_vectors)
|
|
244
|
+
* @param[in] c_stride_in_bytes Output row stride in bytes
|
|
245
|
+
* @param[in] row_start Starting row index (default 0)
|
|
246
|
+
* @param[in] row_count Number of rows to compute (default all)
|
|
247
|
+
*
|
|
248
|
+
* Computes Hamming distances between all pairs of binary vectors.
|
|
249
|
+
* For u1x8_t inputs, distances are exact bit counts (u32_t outputs).
|
|
250
|
+
*
|
|
251
|
+
* @tparam in_type_ Input element type (u1x8_t)
|
|
252
|
+
* @tparam result_type_ Output type (u32_t for Hamming distances)
|
|
253
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
|
|
254
|
+
*/
|
|
255
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::hamming_result_t,
|
|
256
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
257
|
+
void hammings_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth, std::size_t a_stride_in_bytes,
|
|
258
|
+
result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
|
|
259
|
+
std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
|
|
260
|
+
if (row_count == std::numeric_limits<std::size_t>::max()) row_count = n_vectors;
|
|
261
|
+
constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
|
|
262
|
+
std::is_same_v<result_type_, typename in_type_::hamming_result_t>;
|
|
263
|
+
|
|
264
|
+
if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch)
|
|
265
|
+
nk_hammings_symmetric_u1(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
266
|
+
row_count);
|
|
267
|
+
else {
|
|
268
|
+
using raw_t = typename in_type_::raw_t;
|
|
269
|
+
std::size_t depth_bytes = divide_round_up(depth, 8);
|
|
270
|
+
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
271
|
+
char *c_bytes = reinterpret_cast<char *>(c);
|
|
272
|
+
std::size_t row_end = row_start + row_count < n_vectors ? row_start + row_count : n_vectors;
|
|
273
|
+
|
|
274
|
+
for (std::size_t i = row_start; i < row_end; i++) {
|
|
275
|
+
raw_t const *a_i = reinterpret_cast<raw_t const *>(a_bytes + i * a_stride_in_bytes);
|
|
276
|
+
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
277
|
+
|
|
278
|
+
for (std::size_t j = 0; j < n_vectors; j++) {
|
|
279
|
+
raw_t const *a_j = reinterpret_cast<raw_t const *>(a_bytes + j * a_stride_in_bytes);
|
|
280
|
+
typename result_type_::raw_t distance = 0;
|
|
281
|
+
for (std::size_t b = 0; b < depth_bytes; b++) {
|
|
282
|
+
auto xor_val = a_i[b] ^ a_j[b];
|
|
283
|
+
distance += std::popcount(static_cast<unsigned>(xor_val));
|
|
284
|
+
}
|
|
285
|
+
c_row[j] = result_type_::from_raw(distance);
|
|
286
|
+
}
|
|
287
|
+
}
|
|
288
|
+
}
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
/**
 * @brief Computes Hamming distances between rows of A and columns of packed B.
 * @param[in] a Pointer to the first matrix (m x k).
 * @param[in] b_packed Pointer to the packed second matrix (k x n).
 * @param[out] c Pointer to the output matrix (m x n).
 * @param[in] row_count Number of rows in A (m).
 * @param[in] column_count Number of columns in B (n).
 * @param[in] depth Depth dimension in bits (k).
 * @param[in] a_stride_in_bytes Stride between consecutive rows of A in bytes; 0 means densely packed.
 * @param[in] c_stride_in_bytes Stride between consecutive rows of C in bytes; 0 means densely packed.
 *
 * Computes Hamming distances between binary vectors using optimized packed format.
 * For u1x8_t inputs, distances are exact bit counts (u32_t outputs).
 *
 * @tparam in_type_ Input element type (u1x8_t)
 * @tparam result_type_ Output type (u32_t for Hamming distances)
 * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
 */
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::hamming_result_t,
          allow_simd_t allow_simd_ = prefer_simd_k>
void hammings_packed(in_type_ const *a, void const *b_packed, result_type_ *c, std::size_t row_count,
                     std::size_t column_count, std::size_t depth, std::size_t a_stride_in_bytes = 0,
                     std::size_t c_stride_in_bytes = 0) noexcept {
    // Compute default strides: a zero stride means the rows are densely packed.
    if (!a_stride_in_bytes) a_stride_in_bytes = divide_round_up(depth, 8) * sizeof(in_type_);
    if (!c_stride_in_bytes) c_stride_in_bytes = column_count * sizeof(result_type_);

    // The SIMD kernel is only valid when the caller did not override the result type.
    constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
                              std::is_same_v<result_type_, typename in_type_::hamming_result_t>;

    if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch) {
        nk_hammings_packed_u1(reinterpret_cast<nk_u1x8_t const *>(a), b_packed, reinterpret_cast<nk_u32_t *>(c),
                              row_count, column_count, depth, a_stride_in_bytes, c_stride_in_bytes);
    }
    else {
        // Scalar fallback: extract pointer and stride from b_packed, then compute directly.
        // As read here, the packed buffer begins with a {B pointer, B row stride} header;
        // `memcpy` reads it without alignment or strict-aliasing UB.
        in_type_ const *b;
        size_t b_stride_in_bytes;
        char const *b_packed_bytes = reinterpret_cast<char const *>(b_packed);
        std::memcpy(&b, b_packed_bytes, sizeof(void *));
        std::memcpy(&b_stride_in_bytes, b_packed_bytes + sizeof(void *), sizeof(size_t));

        // Compute Hamming distances using unpacked matrices.
        char const *a_bytes = reinterpret_cast<char const *>(a);
        char const *b_bytes = reinterpret_cast<char const *>(b);
        char *c_bytes = reinterpret_cast<char *>(c);
        std::size_t depth_bytes = divide_round_up(depth, 8); // `depth` is in bits

        for (std::size_t i = 0; i < row_count; i++) {
            typename in_type_::raw_t const *a_row = reinterpret_cast<typename in_type_::raw_t const *>(
                a_bytes + i * a_stride_in_bytes);
            result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);

            for (std::size_t j = 0; j < column_count; j++) {
                typename in_type_::raw_t const *b_row = reinterpret_cast<typename in_type_::raw_t const *>(
                    b_bytes + j * b_stride_in_bytes);

                // Compute Hamming distance: XOR then popcount.
                typename result_type_::raw_t distance = 0;
                for (std::size_t byte_idx = 0; byte_idx < depth_bytes; byte_idx++) {
                    auto xor_val = a_row[byte_idx] ^ b_row[byte_idx];
                    distance += std::popcount(static_cast<unsigned>(xor_val));
                }
                c_row[j] = result_type_::from_raw(distance);
            }
        }
    }
}
|
|
359
|
+
|
|
360
|
+
/**
|
|
361
|
+
* @brief Symmetric Jaccard distance matrix: C[i,j] = jaccard(A[i], A[j])
|
|
362
|
+
*/
|
|
363
|
+
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::jaccard_result_t,
|
|
364
|
+
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
365
|
+
void jaccards_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth, std::size_t a_stride_in_bytes,
|
|
366
|
+
result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
|
|
367
|
+
std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
|
|
368
|
+
if (row_count == std::numeric_limits<std::size_t>::max()) row_count = n_vectors;
|
|
369
|
+
constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
|
|
370
|
+
std::is_same_v<result_type_, typename in_type_::jaccard_result_t>;
|
|
371
|
+
|
|
372
|
+
if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch)
|
|
373
|
+
nk_jaccards_symmetric_u1(&a->raw_, n_vectors, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
374
|
+
row_count);
|
|
375
|
+
else {
|
|
376
|
+
using raw_t = typename in_type_::raw_t;
|
|
377
|
+
std::size_t depth_bytes = divide_round_up(depth, 8);
|
|
378
|
+
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
379
|
+
char *c_bytes = reinterpret_cast<char *>(c);
|
|
380
|
+
std::size_t row_end = row_start + row_count < n_vectors ? row_start + row_count : n_vectors;
|
|
381
|
+
|
|
382
|
+
for (std::size_t i = row_start; i < row_end; i++) {
|
|
383
|
+
raw_t const *a_i = reinterpret_cast<raw_t const *>(a_bytes + i * a_stride_in_bytes);
|
|
384
|
+
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
385
|
+
|
|
386
|
+
for (std::size_t j = 0; j < n_vectors; j++) {
|
|
387
|
+
raw_t const *a_j = reinterpret_cast<raw_t const *>(a_bytes + j * a_stride_in_bytes);
|
|
388
|
+
unsigned intersection = 0, union_ = 0;
|
|
389
|
+
for (std::size_t b = 0; b < depth_bytes; b++) {
|
|
390
|
+
intersection += std::popcount(static_cast<unsigned>(a_i[b] & a_j[b]));
|
|
391
|
+
union_ += std::popcount(static_cast<unsigned>(a_i[b] | a_j[b]));
|
|
392
|
+
}
|
|
393
|
+
c_row[j] = result_type_::from_raw(union_ ? 1.0f - static_cast<float>(intersection) / union_ : 0.0f);
|
|
394
|
+
}
|
|
395
|
+
}
|
|
396
|
+
}
|
|
397
|
+
}
|
|
398
|
+
|
|
399
|
+
/**
 * @brief Computes Jaccard distances between rows of A and columns of packed B.
 * @param[in] a Pointer to the first matrix (m x k bits).
 * @param[in] b_packed Pointer to the packed second matrix.
 * @param[out] c Pointer to the output matrix (m x n).
 * @param[in] row_count Number of rows in A (m).
 * @param[in] column_count Number of columns in B (n).
 * @param[in] depth Depth dimension in bits (k).
 * @param[in] a_stride_in_bytes Stride between consecutive rows of A in bytes; 0 means densely packed.
 * @param[in] c_stride_in_bytes Stride between consecutive rows of C in bytes; 0 means densely packed.
 */
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::jaccard_result_t,
          allow_simd_t allow_simd_ = prefer_simd_k>
void jaccards_packed(in_type_ const *a, void const *b_packed, result_type_ *c, std::size_t row_count,
                     std::size_t column_count, std::size_t depth, std::size_t a_stride_in_bytes = 0,
                     std::size_t c_stride_in_bytes = 0) noexcept {
    // A zero stride means the rows are densely packed.
    if (!a_stride_in_bytes) a_stride_in_bytes = divide_round_up(depth, 8) * sizeof(in_type_);
    if (!c_stride_in_bytes) c_stride_in_bytes = column_count * sizeof(result_type_);

    // The SIMD kernel is only valid when the caller did not override the result type.
    constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
                              std::is_same_v<result_type_, typename in_type_::jaccard_result_t>;

    if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch) {
        nk_jaccards_packed_u1(reinterpret_cast<nk_u1x8_t const *>(a), b_packed, reinterpret_cast<nk_f32_t *>(c),
                              row_count, column_count, depth, a_stride_in_bytes, c_stride_in_bytes);
    }
    else {
        // Scalar fallback: extract pointer and stride from b_packed, then compute directly.
        // As read here, the packed buffer begins with a {B pointer, B row stride} header;
        // `memcpy` reads it without alignment or strict-aliasing UB.
        in_type_ const *b;
        size_t b_stride_in_bytes;
        char const *b_packed_bytes = reinterpret_cast<char const *>(b_packed);
        std::memcpy(&b, b_packed_bytes, sizeof(void *));
        std::memcpy(&b_stride_in_bytes, b_packed_bytes + sizeof(void *), sizeof(size_t));

        char const *a_bytes = reinterpret_cast<char const *>(a);
        char const *b_bytes = reinterpret_cast<char const *>(b);
        char *c_bytes = reinterpret_cast<char *>(c);
        std::size_t depth_bytes = divide_round_up(depth, 8); // `depth` is in bits

        for (std::size_t i = 0; i < row_count; i++) {
            typename in_type_::raw_t const *a_row = reinterpret_cast<typename in_type_::raw_t const *>(
                a_bytes + i * a_stride_in_bytes);
            result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);

            for (std::size_t j = 0; j < column_count; j++) {
                typename in_type_::raw_t const *b_row = reinterpret_cast<typename in_type_::raw_t const *>(
                    b_bytes + j * b_stride_in_bytes);
                // Jaccard distance: 1 - |A ∩ B| / |A ∪ B| over set bits.
                unsigned intersection = 0, union_ = 0;
                for (std::size_t byte_idx = 0; byte_idx < depth_bytes; byte_idx++) {
                    intersection += std::popcount(static_cast<unsigned>(a_row[byte_idx] & b_row[byte_idx]));
                    union_ += std::popcount(static_cast<unsigned>(a_row[byte_idx] | b_row[byte_idx]));
                }
                // Two empty sets are treated as identical (distance 0) to avoid dividing by zero.
                c_row[j] = result_type_::from_raw(union_ ? 1.0f - static_cast<float>(intersection) / union_ : 0.0f);
            }
        }
    }
}
|
|
448
|
+
|
|
449
|
+
} // namespace ashvardanian::numkong
|
|
450
|
+
|
|
451
|
+
#include "numkong/tensor.hpp"
|
|
452
|
+
|
|
453
|
+
namespace ashvardanian::numkong {
|
|
454
|
+
|
|
455
|
+
#pragma region - Concept-Constrained Symmetric Dot Products
|
|
456
|
+
|
|
457
|
+
/** @brief C = A × Aᵀ where C[i,j] = ⟨A[i], A[j]⟩. */
|
|
458
|
+
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
459
|
+
mutable_matrix_of<typename value_type_::dot_result_t> output_matrix_>
|
|
460
|
+
bool dots_symmetric(input_matrix_ const &input, output_matrix_ &&output) noexcept {
|
|
461
|
+
std::size_t num_vectors = input.extent(0);
|
|
462
|
+
if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
|
|
463
|
+
numkong::dots_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
|
|
464
|
+
static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
|
|
465
|
+
static_cast<std::size_t>(output.stride_bytes(0)));
|
|
466
|
+
return true;
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
/** @brief Partitioned symmetric dot products for parallel row-range work. */
|
|
470
|
+
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
471
|
+
mutable_matrix_of<typename value_type_::dot_result_t> output_matrix_>
|
|
472
|
+
bool dots_symmetric(input_matrix_ const &input, output_matrix_ output, std::size_t row_start,
|
|
473
|
+
std::size_t row_count) noexcept {
|
|
474
|
+
std::size_t num_vectors = input.extent(0);
|
|
475
|
+
if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
|
|
476
|
+
numkong::dots_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
|
|
477
|
+
static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
|
|
478
|
+
static_cast<std::size_t>(output.stride_bytes(0)), row_start, row_count);
|
|
479
|
+
return true;
|
|
480
|
+
}
|
|
481
|
+
|
|
482
|
+
/** @brief Allocating symmetric dot products: C = A × Aᵀ. */
|
|
483
|
+
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
484
|
+
typename allocator_type_ = aligned_allocator<typename value_type_::dot_result_t>>
|
|
485
|
+
matrix<typename value_type_::dot_result_t, allocator_type_> try_dots_symmetric(input_matrix_ const &input) noexcept {
|
|
486
|
+
using result_t = typename value_type_::dot_result_t;
|
|
487
|
+
using out_tensor_t = matrix<result_t, allocator_type_>;
|
|
488
|
+
if (input.empty()) return out_tensor_t {};
|
|
489
|
+
std::size_t num_vectors = input.extent(0);
|
|
490
|
+
auto result = out_tensor_t::try_zeros({num_vectors, num_vectors});
|
|
491
|
+
if (result.empty()) return result;
|
|
492
|
+
if (!dots_symmetric<value_type_>(input, result.span())) return out_tensor_t {};
|
|
493
|
+
return result;
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
/** @brief Symmetric Hamming distances: C[i,j] = hamming(A[i], A[j]). */
|
|
497
|
+
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
498
|
+
mutable_matrix_of<typename value_type_::hamming_result_t> output_matrix_>
|
|
499
|
+
bool hammings_symmetric(input_matrix_ const &input, output_matrix_ &&output) noexcept {
|
|
500
|
+
std::size_t num_vectors = input.extent(0);
|
|
501
|
+
if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
|
|
502
|
+
numkong::hammings_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
|
|
503
|
+
static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
|
|
504
|
+
static_cast<std::size_t>(output.stride_bytes(0)));
|
|
505
|
+
return true;
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
/** @brief Allocating symmetric Hamming distances. */
|
|
509
|
+
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
510
|
+
typename allocator_type_ = aligned_allocator<typename value_type_::hamming_result_t>>
|
|
511
|
+
matrix<typename value_type_::hamming_result_t, allocator_type_> try_hammings_symmetric(
|
|
512
|
+
input_matrix_ const &input) noexcept {
|
|
513
|
+
using result_t = typename value_type_::hamming_result_t;
|
|
514
|
+
using out_tensor_t = matrix<result_t, allocator_type_>;
|
|
515
|
+
if (input.empty()) return out_tensor_t {};
|
|
516
|
+
std::size_t num_vectors = input.extent(0);
|
|
517
|
+
auto result = out_tensor_t::try_zeros({num_vectors, num_vectors});
|
|
518
|
+
if (result.empty()) return result;
|
|
519
|
+
if (!hammings_symmetric<value_type_>(input, result.span())) return out_tensor_t {};
|
|
520
|
+
return result;
|
|
521
|
+
}
|
|
522
|
+
|
|
523
|
+
/** @brief Symmetric Jaccard distances: C[i,j] = jaccard(A[i], A[j]). */
|
|
524
|
+
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
525
|
+
mutable_matrix_of<typename value_type_::jaccard_result_t> output_matrix_>
|
|
526
|
+
bool jaccards_symmetric(input_matrix_ const &input, output_matrix_ &&output) noexcept {
|
|
527
|
+
std::size_t num_vectors = input.extent(0);
|
|
528
|
+
if (output.extent(0) != num_vectors || output.extent(1) != num_vectors) return false;
|
|
529
|
+
numkong::jaccards_symmetric<value_type_>(input.data(), num_vectors, input.extent(1),
|
|
530
|
+
static_cast<std::size_t>(input.stride_bytes(0)), output.data(),
|
|
531
|
+
static_cast<std::size_t>(output.stride_bytes(0)));
|
|
532
|
+
return true;
|
|
533
|
+
}
|
|
534
|
+
|
|
535
|
+
/** @brief Allocating symmetric Jaccard distances. */
|
|
536
|
+
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
537
|
+
typename allocator_type_ = aligned_allocator<typename value_type_::jaccard_result_t>>
|
|
538
|
+
matrix<typename value_type_::jaccard_result_t, allocator_type_> try_jaccards_symmetric(
|
|
539
|
+
input_matrix_ const &input) noexcept {
|
|
540
|
+
using result_t = typename value_type_::jaccard_result_t;
|
|
541
|
+
using out_tensor_t = matrix<result_t, allocator_type_>;
|
|
542
|
+
if (input.empty()) return out_tensor_t {};
|
|
543
|
+
std::size_t num_vectors = input.extent(0);
|
|
544
|
+
auto result = out_tensor_t::try_zeros({num_vectors, num_vectors});
|
|
545
|
+
if (result.empty()) return result;
|
|
546
|
+
if (!jaccards_symmetric<value_type_>(input, result.span())) return out_tensor_t {};
|
|
547
|
+
return result;
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
#pragma endregion - Concept - Constrained Symmetric Dot Products
|
|
551
|
+
|
|
552
|
+
#pragma region - Concept-Constrained Packed Dot Products
|
|
553
|
+
|
|
554
|
+
/** @brief Packed dot products: C = A × B_packedᵀ. */
|
|
555
|
+
template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
|
|
556
|
+
mutable_matrix_of<typename value_type_::dot_result_t> output_matrix_>
|
|
557
|
+
bool dots_packed(input_matrix_ const &a, packed_type_ const &packed_b, output_matrix_ &&c) noexcept {
|
|
558
|
+
if (packed_b.empty() || a.rank() < 2 || c.rank() < 2) return false;
|
|
559
|
+
if (a.extent(1) != packed_b.depth()) return false;
|
|
560
|
+
if (c.extent(0) != a.extent(0) || c.extent(1) != packed_b.rows()) return false;
|
|
561
|
+
numkong::dots_packed<value_type_>(a.data(), packed_b.data(), c.data(), a.extent(0), packed_b.rows(),
|
|
562
|
+
packed_b.depth(), static_cast<std::size_t>(a.stride_bytes(0)),
|
|
563
|
+
static_cast<std::size_t>(c.stride_bytes(0)));
|
|
564
|
+
return true;
|
|
565
|
+
}
|
|
566
|
+
|
|
567
|
+
/** @brief Allocating packed dot products: C = A × B_packedᵀ. */
|
|
568
|
+
template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
|
|
569
|
+
typename allocator_type_ = aligned_allocator<typename value_type_::dot_result_t>>
|
|
570
|
+
matrix<typename value_type_::dot_result_t, allocator_type_> try_dots_packed(input_matrix_ const &a,
|
|
571
|
+
packed_type_ const &packed_b) noexcept {
|
|
572
|
+
using result_t = typename value_type_::dot_result_t;
|
|
573
|
+
using out_t = matrix<result_t, allocator_type_>;
|
|
574
|
+
if (packed_b.empty() || a.rank() < 2) return out_t {};
|
|
575
|
+
auto c = out_t::try_empty({a.extent(0), packed_b.rows()});
|
|
576
|
+
if (c.empty()) return c;
|
|
577
|
+
if (!dots_packed<value_type_>(a, packed_b, c.as_matrix_span())) return out_t {};
|
|
578
|
+
return c;
|
|
579
|
+
}
|
|
580
|
+
|
|
581
|
+
/** @brief Packed Hamming distances: C = hamming(A, B_packed). */
|
|
582
|
+
template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
|
|
583
|
+
mutable_matrix_of<typename value_type_::hamming_result_t> output_matrix_>
|
|
584
|
+
bool hammings_packed(input_matrix_ const &a, packed_type_ const &packed_b, output_matrix_ &&c) noexcept {
|
|
585
|
+
if (packed_b.empty() || a.rank() < 2 || c.rank() < 2) return false;
|
|
586
|
+
if (a.extent(1) != packed_b.depth()) return false;
|
|
587
|
+
if (c.extent(0) != a.extent(0) || c.extent(1) != packed_b.rows()) return false;
|
|
588
|
+
numkong::hammings_packed<value_type_>(a.data(), packed_b.data(), c.data(), a.extent(0), packed_b.rows(),
|
|
589
|
+
packed_b.depth(), static_cast<std::size_t>(a.stride_bytes(0)),
|
|
590
|
+
static_cast<std::size_t>(c.stride_bytes(0)));
|
|
591
|
+
return true;
|
|
592
|
+
}
|
|
593
|
+
|
|
594
|
+
/** @brief Allocating packed Hamming distances. */
|
|
595
|
+
template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
|
|
596
|
+
typename allocator_type_ = aligned_allocator<typename value_type_::hamming_result_t>>
|
|
597
|
+
matrix<typename value_type_::hamming_result_t, allocator_type_> try_hammings_packed(
|
|
598
|
+
input_matrix_ const &a, packed_type_ const &packed_b) noexcept {
|
|
599
|
+
using result_t = typename value_type_::hamming_result_t;
|
|
600
|
+
using out_t = matrix<result_t, allocator_type_>;
|
|
601
|
+
if (packed_b.empty() || a.rank() < 2) return out_t {};
|
|
602
|
+
auto c = out_t::try_empty({a.extent(0), packed_b.rows()});
|
|
603
|
+
if (c.empty()) return c;
|
|
604
|
+
if (!hammings_packed<value_type_>(a, packed_b, c.as_matrix_span())) return out_t {};
|
|
605
|
+
return c;
|
|
606
|
+
}
|
|
607
|
+
|
|
608
|
+
/** @brief Packed Jaccard distances: C = jaccard(A, B_packed). */
|
|
609
|
+
template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
|
|
610
|
+
mutable_matrix_of<typename value_type_::jaccard_result_t> output_matrix_>
|
|
611
|
+
bool jaccards_packed(input_matrix_ const &a, packed_type_ const &packed_b, output_matrix_ &&c) noexcept {
|
|
612
|
+
if (packed_b.empty() || a.rank() < 2 || c.rank() < 2) return false;
|
|
613
|
+
if (a.extent(1) != packed_b.depth()) return false;
|
|
614
|
+
if (c.extent(0) != a.extent(0) || c.extent(1) != packed_b.rows()) return false;
|
|
615
|
+
numkong::jaccards_packed<value_type_>(a.data(), packed_b.data(), c.data(), a.extent(0), packed_b.rows(),
|
|
616
|
+
packed_b.depth(), static_cast<std::size_t>(a.stride_bytes(0)),
|
|
617
|
+
static_cast<std::size_t>(c.stride_bytes(0)));
|
|
618
|
+
return true;
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
/** @brief Allocating packed Jaccard distances. */
|
|
622
|
+
template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
|
|
623
|
+
typename allocator_type_ = aligned_allocator<typename value_type_::jaccard_result_t>>
|
|
624
|
+
matrix<typename value_type_::jaccard_result_t, allocator_type_> try_jaccards_packed(
|
|
625
|
+
input_matrix_ const &a, packed_type_ const &packed_b) noexcept {
|
|
626
|
+
using result_t = typename value_type_::jaccard_result_t;
|
|
627
|
+
using out_t = matrix<result_t, allocator_type_>;
|
|
628
|
+
if (packed_b.empty() || a.rank() < 2) return out_t {};
|
|
629
|
+
auto c = out_t::try_empty({a.extent(0), packed_b.rows()});
|
|
630
|
+
if (c.empty()) return c;
|
|
631
|
+
if (!jaccards_packed<value_type_>(a, packed_b, c.as_matrix_span())) return out_t {};
|
|
632
|
+
return c;
|
|
633
|
+
}
|
|
634
|
+
|
|
635
|
+
#pragma endregion - Concept - Constrained Packed Dot Products
|
|
636
|
+
|
|
637
|
+
} // namespace ashvardanian::numkong
|
|
638
|
+
|
|
639
|
+
#endif // NK_DOTS_HPP
|