numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -19,8 +19,8 @@ namespace ashvardanian::numkong {
|
|
|
19
19
|
|
|
20
20
|
/**
|
|
21
21
|
* @brief Symmetric angular distance matrix: C[i,j] = angular(A[i], A[j])
|
|
22
|
-
* @param[in] a Matrix A [
|
|
23
|
-
* @param[in]
|
|
22
|
+
* @param[in] a Matrix A [vectors_count x depth]
|
|
23
|
+
* @param[in] vectors_count Number of vectors (n)
|
|
24
24
|
* @param[in] depth Dimension of each vector (k)
|
|
25
25
|
* @param[in] a_stride_in_bytes Stride between vectors in A
|
|
26
26
|
* @param[out] c Output matrix C [n x n]
|
|
@@ -34,59 +34,59 @@ namespace ashvardanian::numkong {
|
|
|
34
34
|
*/
|
|
35
35
|
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::angular_result_t,
|
|
36
36
|
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
37
|
-
void angulars_symmetric(in_type_ const *a, std::size_t
|
|
37
|
+
void angulars_symmetric(in_type_ const *a, std::size_t vectors_count, std::size_t depth, std::size_t a_stride_in_bytes,
|
|
38
38
|
result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
|
|
39
39
|
std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
|
|
40
|
-
if (row_count == std::numeric_limits<std::size_t>::max()) row_count =
|
|
40
|
+
if (row_count == std::numeric_limits<std::size_t>::max()) row_count = vectors_count;
|
|
41
41
|
constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
|
|
42
42
|
std::is_same_v<result_type_, typename in_type_::angular_result_t>;
|
|
43
43
|
|
|
44
44
|
if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
|
|
45
|
-
nk_angulars_symmetric_f64(&a->raw_,
|
|
46
|
-
row_count);
|
|
45
|
+
nk_angulars_symmetric_f64(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
46
|
+
row_start, row_count);
|
|
47
47
|
else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
|
|
48
|
-
nk_angulars_symmetric_f32(&a->raw_,
|
|
49
|
-
row_count);
|
|
48
|
+
nk_angulars_symmetric_f32(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
49
|
+
row_start, row_count);
|
|
50
50
|
else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
|
|
51
|
-
nk_angulars_symmetric_f16(&a->raw_,
|
|
52
|
-
row_count);
|
|
51
|
+
nk_angulars_symmetric_f16(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
52
|
+
row_start, row_count);
|
|
53
53
|
else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
|
|
54
|
-
nk_angulars_symmetric_bf16(&a->raw_,
|
|
54
|
+
nk_angulars_symmetric_bf16(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
55
55
|
row_start, row_count);
|
|
56
56
|
else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
|
|
57
|
-
nk_angulars_symmetric_e4m3(&a->raw_,
|
|
57
|
+
nk_angulars_symmetric_e4m3(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
58
58
|
row_start, row_count);
|
|
59
59
|
else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
|
|
60
|
-
nk_angulars_symmetric_e5m2(&a->raw_,
|
|
60
|
+
nk_angulars_symmetric_e5m2(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
61
61
|
row_start, row_count);
|
|
62
62
|
else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
|
|
63
|
-
nk_angulars_symmetric_e2m3(&a->raw_,
|
|
63
|
+
nk_angulars_symmetric_e2m3(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
64
64
|
row_start, row_count);
|
|
65
65
|
else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
|
|
66
|
-
nk_angulars_symmetric_e3m2(&a->raw_,
|
|
66
|
+
nk_angulars_symmetric_e3m2(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
67
67
|
row_start, row_count);
|
|
68
68
|
else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
|
|
69
|
-
nk_angulars_symmetric_i8(&a->raw_,
|
|
70
|
-
row_count);
|
|
69
|
+
nk_angulars_symmetric_i8(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
70
|
+
row_start, row_count);
|
|
71
71
|
else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
|
|
72
|
-
nk_angulars_symmetric_u8(&a->raw_,
|
|
73
|
-
row_count);
|
|
72
|
+
nk_angulars_symmetric_u8(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
73
|
+
row_start, row_count);
|
|
74
74
|
else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
|
|
75
|
-
nk_angulars_symmetric_i4(&a->raw_,
|
|
76
|
-
row_count);
|
|
75
|
+
nk_angulars_symmetric_i4(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
76
|
+
row_start, row_count);
|
|
77
77
|
else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
|
|
78
|
-
nk_angulars_symmetric_u4(&a->raw_,
|
|
79
|
-
row_count);
|
|
78
|
+
nk_angulars_symmetric_u4(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
79
|
+
row_start, row_count);
|
|
80
80
|
else {
|
|
81
81
|
std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
|
|
82
82
|
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
83
83
|
char *c_bytes = reinterpret_cast<char *>(c);
|
|
84
|
-
std::size_t row_end = row_start + row_count <
|
|
84
|
+
std::size_t row_end = row_start + row_count < vectors_count ? row_start + row_count : vectors_count;
|
|
85
85
|
|
|
86
86
|
for (std::size_t i = row_start; i < row_end; i++) {
|
|
87
87
|
in_type_ const *a_i = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
|
|
88
88
|
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
89
|
-
for (std::size_t j = 0; j <
|
|
89
|
+
for (std::size_t j = 0; j < vectors_count; j++) {
|
|
90
90
|
in_type_ const *a_j = reinterpret_cast<in_type_ const *>(a_bytes + j * a_stride_in_bytes);
|
|
91
91
|
result_type_ ab {}, aa {}, bb {};
|
|
92
92
|
for (std::size_t l = 0; l < depth_values; l++) {
|
|
@@ -104,8 +104,8 @@ void angulars_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t de
|
|
|
104
104
|
|
|
105
105
|
/**
|
|
106
106
|
* @brief Symmetric Euclidean distance matrix: C[i,j] = euclidean(A[i], A[j])
|
|
107
|
-
* @param[in] a Matrix A [
|
|
108
|
-
* @param[in]
|
|
107
|
+
* @param[in] a Matrix A [vectors_count x depth]
|
|
108
|
+
* @param[in] vectors_count Number of vectors (n)
|
|
109
109
|
* @param[in] depth Dimension of each vector (k)
|
|
110
110
|
* @param[in] a_stride_in_bytes Stride between vectors in A
|
|
111
111
|
* @param[out] c Output matrix C [n x n]
|
|
@@ -119,59 +119,60 @@ void angulars_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t de
|
|
|
119
119
|
*/
|
|
120
120
|
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::euclidean_result_t,
|
|
121
121
|
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
122
|
-
void euclideans_symmetric(in_type_ const *a, std::size_t
|
|
123
|
-
result_type_ *c, std::size_t c_stride_in_bytes,
|
|
122
|
+
void euclideans_symmetric(in_type_ const *a, std::size_t vectors_count, std::size_t depth,
|
|
123
|
+
std::size_t a_stride_in_bytes, result_type_ *c, std::size_t c_stride_in_bytes,
|
|
124
|
+
std::size_t row_start = 0,
|
|
124
125
|
std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
|
|
125
|
-
if (row_count == std::numeric_limits<std::size_t>::max()) row_count =
|
|
126
|
+
if (row_count == std::numeric_limits<std::size_t>::max()) row_count = vectors_count;
|
|
126
127
|
constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
|
|
127
128
|
std::is_same_v<result_type_, typename in_type_::euclidean_result_t>;
|
|
128
129
|
|
|
129
130
|
if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
|
|
130
|
-
nk_euclideans_symmetric_f64(&a->raw_,
|
|
131
|
+
nk_euclideans_symmetric_f64(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
131
132
|
row_start, row_count);
|
|
132
133
|
else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
|
|
133
|
-
nk_euclideans_symmetric_f32(&a->raw_,
|
|
134
|
+
nk_euclideans_symmetric_f32(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
134
135
|
row_start, row_count);
|
|
135
136
|
else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
|
|
136
|
-
nk_euclideans_symmetric_f16(&a->raw_,
|
|
137
|
+
nk_euclideans_symmetric_f16(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
137
138
|
row_start, row_count);
|
|
138
139
|
else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
|
|
139
|
-
nk_euclideans_symmetric_bf16(&a->raw_,
|
|
140
|
+
nk_euclideans_symmetric_bf16(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
140
141
|
row_start, row_count);
|
|
141
142
|
else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
|
|
142
|
-
nk_euclideans_symmetric_e4m3(&a->raw_,
|
|
143
|
+
nk_euclideans_symmetric_e4m3(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
143
144
|
row_start, row_count);
|
|
144
145
|
else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
|
|
145
|
-
nk_euclideans_symmetric_e5m2(&a->raw_,
|
|
146
|
+
nk_euclideans_symmetric_e5m2(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
146
147
|
row_start, row_count);
|
|
147
148
|
else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
|
|
148
|
-
nk_euclideans_symmetric_e2m3(&a->raw_,
|
|
149
|
+
nk_euclideans_symmetric_e2m3(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
149
150
|
row_start, row_count);
|
|
150
151
|
else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
|
|
151
|
-
nk_euclideans_symmetric_e3m2(&a->raw_,
|
|
152
|
+
nk_euclideans_symmetric_e3m2(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
152
153
|
row_start, row_count);
|
|
153
154
|
else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
|
|
154
|
-
nk_euclideans_symmetric_i8(&a->raw_,
|
|
155
|
+
nk_euclideans_symmetric_i8(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
155
156
|
row_start, row_count);
|
|
156
157
|
else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
|
|
157
|
-
nk_euclideans_symmetric_u8(&a->raw_,
|
|
158
|
+
nk_euclideans_symmetric_u8(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
158
159
|
row_start, row_count);
|
|
159
160
|
else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
|
|
160
|
-
nk_euclideans_symmetric_i4(&a->raw_,
|
|
161
|
+
nk_euclideans_symmetric_i4(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
161
162
|
row_start, row_count);
|
|
162
163
|
else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
|
|
163
|
-
nk_euclideans_symmetric_u4(&a->raw_,
|
|
164
|
+
nk_euclideans_symmetric_u4(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
164
165
|
row_start, row_count);
|
|
165
166
|
else {
|
|
166
167
|
std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
|
|
167
168
|
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
168
169
|
char *c_bytes = reinterpret_cast<char *>(c);
|
|
169
|
-
std::size_t row_end = row_start + row_count <
|
|
170
|
+
std::size_t row_end = row_start + row_count < vectors_count ? row_start + row_count : vectors_count;
|
|
170
171
|
|
|
171
172
|
for (std::size_t i = row_start; i < row_end; i++) {
|
|
172
173
|
in_type_ const *a_i = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
|
|
173
174
|
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
174
|
-
for (std::size_t j = 0; j <
|
|
175
|
+
for (std::size_t j = 0; j < vectors_count; j++) {
|
|
175
176
|
in_type_ const *a_j = reinterpret_cast<in_type_ const *>(a_bytes + j * a_stride_in_bytes);
|
|
176
177
|
result_type_ sum {};
|
|
177
178
|
for (std::size_t l = 0; l < depth_values; l++) sum = fdsa(a_i[l], a_j[l], sum);
|
|
@@ -361,7 +362,7 @@ void euclideans_packed(in_type_ const *a, void const *b_packed, result_type_ *c,
|
|
|
361
362
|
|
|
362
363
|
namespace ashvardanian::numkong {
|
|
363
364
|
|
|
364
|
-
#pragma region
|
|
365
|
+
#pragma region Concept Constrained Symmetric Spatial Distances
|
|
365
366
|
|
|
366
367
|
/** @brief Symmetric angular distances: C[i,j] = angular(A[i], A[j]). */
|
|
367
368
|
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
@@ -443,9 +444,9 @@ bool euclideans_symmetric(input_matrix_ const &input, output_matrix_ &&output, s
|
|
|
443
444
|
return true;
|
|
444
445
|
}
|
|
445
446
|
|
|
446
|
-
#pragma endregion
|
|
447
|
+
#pragma endregion Concept Constrained Symmetric Spatial Distances
|
|
447
448
|
|
|
448
|
-
#pragma region
|
|
449
|
+
#pragma region Concept Constrained Packed Spatial Distances
|
|
449
450
|
|
|
450
451
|
/** @brief Packed angular distances: C = angular(A, B_packed). */
|
|
451
452
|
template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
|
|
@@ -501,7 +502,7 @@ matrix<typename value_type_::euclidean_result_t, allocator_type_> try_euclideans
|
|
|
501
502
|
return c;
|
|
502
503
|
}
|
|
503
504
|
|
|
504
|
-
#pragma endregion
|
|
505
|
+
#pragma endregion Concept Constrained Packed Spatial Distances
|
|
505
506
|
|
|
506
507
|
} // namespace ashvardanian::numkong
|
|
507
508
|
|