numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
package/include/numkong/dots.hpp
CHANGED
|
@@ -160,7 +160,7 @@ void dots_packed(in_type_ const *a, void const *b_packed, result_type_ *c, size_
|
|
|
160
160
|
/**
|
|
161
161
|
* @brief Symmetric dot products: C = A × Aᵀ where C[i,j] = ⟨A[i], A[j]⟩
|
|
162
162
|
* @param[in] a Matrix A [n x k] (n vectors of dimension k)
|
|
163
|
-
* @param[in]
|
|
163
|
+
* @param[in] vectors_count Number of vectors (n)
|
|
164
164
|
* @param[in] depth Dimension of each vector (k)
|
|
165
165
|
* @param[in] a_stride_in_bytes Stride between vectors in A
|
|
166
166
|
* @param[out] c Output matrix C [n x n]
|
|
@@ -172,59 +172,59 @@ void dots_packed(in_type_ const *a, void const *b_packed, result_type_ *c, size_
|
|
|
172
172
|
*/
|
|
173
173
|
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::dot_result_t,
|
|
174
174
|
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
175
|
-
void dots_symmetric(in_type_ const *a, std::size_t
|
|
175
|
+
void dots_symmetric(in_type_ const *a, std::size_t vectors_count, std::size_t depth, std::size_t a_stride_in_bytes,
|
|
176
176
|
result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
|
|
177
177
|
std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
|
|
178
|
-
if (row_count == std::numeric_limits<std::size_t>::max()) row_count =
|
|
178
|
+
if (row_count == std::numeric_limits<std::size_t>::max()) row_count = vectors_count;
|
|
179
179
|
constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
|
|
180
180
|
std::is_same_v<result_type_, typename in_type_::dot_result_t>;
|
|
181
181
|
|
|
182
182
|
if constexpr (std::is_same_v<in_type_, f64_t> && dispatch)
|
|
183
|
-
nk_dots_symmetric_f64(&a->raw_,
|
|
183
|
+
nk_dots_symmetric_f64(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
184
184
|
row_count);
|
|
185
185
|
else if constexpr (std::is_same_v<in_type_, f32_t> && dispatch)
|
|
186
|
-
nk_dots_symmetric_f32(&a->raw_,
|
|
186
|
+
nk_dots_symmetric_f32(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
187
187
|
row_count);
|
|
188
188
|
else if constexpr (std::is_same_v<in_type_, f16_t> && dispatch)
|
|
189
|
-
nk_dots_symmetric_f16(&a->raw_,
|
|
189
|
+
nk_dots_symmetric_f16(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
190
190
|
row_count);
|
|
191
191
|
else if constexpr (std::is_same_v<in_type_, bf16_t> && dispatch)
|
|
192
|
-
nk_dots_symmetric_bf16(&a->raw_,
|
|
193
|
-
row_count);
|
|
192
|
+
nk_dots_symmetric_bf16(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
193
|
+
row_start, row_count);
|
|
194
194
|
else if constexpr (std::is_same_v<in_type_, i8_t> && dispatch)
|
|
195
|
-
nk_dots_symmetric_i8(&a->raw_,
|
|
195
|
+
nk_dots_symmetric_i8(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
196
196
|
row_count);
|
|
197
197
|
else if constexpr (std::is_same_v<in_type_, u8_t> && dispatch)
|
|
198
|
-
nk_dots_symmetric_u8(&a->raw_,
|
|
198
|
+
nk_dots_symmetric_u8(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
199
199
|
row_count);
|
|
200
200
|
else if constexpr (std::is_same_v<in_type_, e4m3_t> && dispatch)
|
|
201
|
-
nk_dots_symmetric_e4m3(&a->raw_,
|
|
202
|
-
row_count);
|
|
201
|
+
nk_dots_symmetric_e4m3(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
202
|
+
row_start, row_count);
|
|
203
203
|
else if constexpr (std::is_same_v<in_type_, e5m2_t> && dispatch)
|
|
204
|
-
nk_dots_symmetric_e5m2(&a->raw_,
|
|
205
|
-
row_count);
|
|
204
|
+
nk_dots_symmetric_e5m2(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
205
|
+
row_start, row_count);
|
|
206
206
|
else if constexpr (std::is_same_v<in_type_, e2m3_t> && dispatch)
|
|
207
|
-
nk_dots_symmetric_e2m3(&a->raw_,
|
|
208
|
-
row_count);
|
|
207
|
+
nk_dots_symmetric_e2m3(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
208
|
+
row_start, row_count);
|
|
209
209
|
else if constexpr (std::is_same_v<in_type_, e3m2_t> && dispatch)
|
|
210
|
-
nk_dots_symmetric_e3m2(&a->raw_,
|
|
211
|
-
row_count);
|
|
210
|
+
nk_dots_symmetric_e3m2(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
211
|
+
row_start, row_count);
|
|
212
212
|
else if constexpr (std::is_same_v<in_type_, u4x2_t> && dispatch)
|
|
213
|
-
nk_dots_symmetric_u4(&a->raw_,
|
|
213
|
+
nk_dots_symmetric_u4(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
214
214
|
row_count);
|
|
215
215
|
else if constexpr (std::is_same_v<in_type_, i4x2_t> && dispatch)
|
|
216
|
-
nk_dots_symmetric_i4(&a->raw_,
|
|
216
|
+
nk_dots_symmetric_i4(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes, row_start,
|
|
217
217
|
row_count);
|
|
218
218
|
else {
|
|
219
219
|
std::size_t depth_values = divide_round_up(depth, dimensions_per_value<in_type_>());
|
|
220
220
|
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
221
221
|
char *c_bytes = reinterpret_cast<char *>(c);
|
|
222
|
-
std::size_t row_end = row_start + row_count <
|
|
222
|
+
std::size_t row_end = row_start + row_count < vectors_count ? row_start + row_count : vectors_count;
|
|
223
223
|
|
|
224
224
|
for (std::size_t i = row_start; i < row_end; i++) {
|
|
225
225
|
in_type_ const *a_i = reinterpret_cast<in_type_ const *>(a_bytes + i * a_stride_in_bytes);
|
|
226
226
|
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
227
|
-
for (std::size_t j = 0; j <
|
|
227
|
+
for (std::size_t j = 0; j < vectors_count; j++) {
|
|
228
228
|
in_type_ const *a_j = reinterpret_cast<in_type_ const *>(a_bytes + j * a_stride_in_bytes);
|
|
229
229
|
result_type_ sum {};
|
|
230
230
|
for (std::size_t l = 0; l < depth_values; l++) sum = fma(a_i[l], a_j[l], sum);
|
|
@@ -236,11 +236,11 @@ void dots_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth,
|
|
|
236
236
|
|
|
237
237
|
/**
|
|
238
238
|
* @brief Symmetric Hamming distance matrix: C[i,j] = hamming(A[i], A[j])
|
|
239
|
-
* @param[in] a Input matrix (
|
|
240
|
-
* @param[in]
|
|
239
|
+
* @param[in] a Input matrix (vectors_count x depth)
|
|
240
|
+
* @param[in] vectors_count Number of vectors
|
|
241
241
|
* @param[in] depth Number of dimensions per vector
|
|
242
242
|
* @param[in] a_stride_in_bytes Row stride in bytes
|
|
243
|
-
* @param[out] c Output matrix (
|
|
243
|
+
* @param[out] c Output matrix (vectors_count x vectors_count)
|
|
244
244
|
* @param[in] c_stride_in_bytes Output row stride in bytes
|
|
245
245
|
* @param[in] row_start Starting row index (default 0)
|
|
246
246
|
* @param[in] row_count Number of rows to compute (default all)
|
|
@@ -254,28 +254,28 @@ void dots_symmetric(in_type_ const *a, std::size_t n_vectors, std::size_t depth,
|
|
|
254
254
|
*/
|
|
255
255
|
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::hamming_result_t,
|
|
256
256
|
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
257
|
-
void hammings_symmetric(in_type_ const *a, std::size_t
|
|
257
|
+
void hammings_symmetric(in_type_ const *a, std::size_t vectors_count, std::size_t depth, std::size_t a_stride_in_bytes,
|
|
258
258
|
result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
|
|
259
259
|
std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
|
|
260
|
-
if (row_count == std::numeric_limits<std::size_t>::max()) row_count =
|
|
260
|
+
if (row_count == std::numeric_limits<std::size_t>::max()) row_count = vectors_count;
|
|
261
261
|
constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
|
|
262
262
|
std::is_same_v<result_type_, typename in_type_::hamming_result_t>;
|
|
263
263
|
|
|
264
264
|
if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch)
|
|
265
|
-
nk_hammings_symmetric_u1(&a->raw_,
|
|
266
|
-
row_count);
|
|
265
|
+
nk_hammings_symmetric_u1(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
266
|
+
row_start, row_count);
|
|
267
267
|
else {
|
|
268
268
|
using raw_t = typename in_type_::raw_t;
|
|
269
269
|
std::size_t depth_bytes = divide_round_up(depth, 8);
|
|
270
270
|
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
271
271
|
char *c_bytes = reinterpret_cast<char *>(c);
|
|
272
|
-
std::size_t row_end = row_start + row_count <
|
|
272
|
+
std::size_t row_end = row_start + row_count < vectors_count ? row_start + row_count : vectors_count;
|
|
273
273
|
|
|
274
274
|
for (std::size_t i = row_start; i < row_end; i++) {
|
|
275
275
|
raw_t const *a_i = reinterpret_cast<raw_t const *>(a_bytes + i * a_stride_in_bytes);
|
|
276
276
|
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
277
277
|
|
|
278
|
-
for (std::size_t j = 0; j <
|
|
278
|
+
for (std::size_t j = 0; j < vectors_count; j++) {
|
|
279
279
|
raw_t const *a_j = reinterpret_cast<raw_t const *>(a_bytes + j * a_stride_in_bytes);
|
|
280
280
|
typename result_type_::raw_t distance = 0;
|
|
281
281
|
for (std::size_t b = 0; b < depth_bytes; b++) {
|
|
@@ -362,35 +362,36 @@ void hammings_packed(in_type_ const *a, void const *b_packed, result_type_ *c, s
|
|
|
362
362
|
*/
|
|
363
363
|
template <numeric_dtype in_type_, numeric_dtype result_type_ = typename in_type_::jaccard_result_t,
|
|
364
364
|
allow_simd_t allow_simd_ = prefer_simd_k>
|
|
365
|
-
void jaccards_symmetric(in_type_ const *a, std::size_t
|
|
365
|
+
void jaccards_symmetric(in_type_ const *a, std::size_t vectors_count, std::size_t depth, std::size_t a_stride_in_bytes,
|
|
366
366
|
result_type_ *c, std::size_t c_stride_in_bytes, std::size_t row_start = 0,
|
|
367
367
|
std::size_t row_count = std::numeric_limits<std::size_t>::max()) noexcept {
|
|
368
|
-
if (row_count == std::numeric_limits<std::size_t>::max()) row_count =
|
|
368
|
+
if (row_count == std::numeric_limits<std::size_t>::max()) row_count = vectors_count;
|
|
369
369
|
constexpr bool dispatch = allow_simd_ == prefer_simd_k &&
|
|
370
370
|
std::is_same_v<result_type_, typename in_type_::jaccard_result_t>;
|
|
371
371
|
|
|
372
372
|
if constexpr (std::is_same_v<in_type_, u1x8_t> && dispatch)
|
|
373
|
-
nk_jaccards_symmetric_u1(&a->raw_,
|
|
374
|
-
row_count);
|
|
373
|
+
nk_jaccards_symmetric_u1(&a->raw_, vectors_count, depth, a_stride_in_bytes, &c->raw_, c_stride_in_bytes,
|
|
374
|
+
row_start, row_count);
|
|
375
375
|
else {
|
|
376
376
|
using raw_t = typename in_type_::raw_t;
|
|
377
377
|
std::size_t depth_bytes = divide_round_up(depth, 8);
|
|
378
378
|
char const *a_bytes = reinterpret_cast<char const *>(a);
|
|
379
379
|
char *c_bytes = reinterpret_cast<char *>(c);
|
|
380
|
-
std::size_t row_end = row_start + row_count <
|
|
380
|
+
std::size_t row_end = row_start + row_count < vectors_count ? row_start + row_count : vectors_count;
|
|
381
381
|
|
|
382
382
|
for (std::size_t i = row_start; i < row_end; i++) {
|
|
383
383
|
raw_t const *a_i = reinterpret_cast<raw_t const *>(a_bytes + i * a_stride_in_bytes);
|
|
384
384
|
result_type_ *c_row = reinterpret_cast<result_type_ *>(c_bytes + i * c_stride_in_bytes);
|
|
385
385
|
|
|
386
|
-
for (std::size_t j = 0; j <
|
|
386
|
+
for (std::size_t j = 0; j < vectors_count; j++) {
|
|
387
387
|
raw_t const *a_j = reinterpret_cast<raw_t const *>(a_bytes + j * a_stride_in_bytes);
|
|
388
388
|
unsigned intersection = 0, union_ = 0;
|
|
389
389
|
for (std::size_t b = 0; b < depth_bytes; b++) {
|
|
390
390
|
intersection += std::popcount(static_cast<unsigned>(a_i[b] & a_j[b]));
|
|
391
391
|
union_ += std::popcount(static_cast<unsigned>(a_i[b] | a_j[b]));
|
|
392
392
|
}
|
|
393
|
-
c_row[j] = result_type_::from_raw(
|
|
393
|
+
c_row[j] = result_type_::from_raw(
|
|
394
|
+
union_ ? 1.0f - static_cast<float>(intersection) / static_cast<float>(union_) : 0.0f);
|
|
394
395
|
}
|
|
395
396
|
}
|
|
396
397
|
}
|
|
@@ -440,7 +441,8 @@ void jaccards_packed(in_type_ const *a, void const *b_packed, result_type_ *c, s
|
|
|
440
441
|
intersection += std::popcount(static_cast<unsigned>(a_row[byte_idx] & b_row[byte_idx]));
|
|
441
442
|
union_ += std::popcount(static_cast<unsigned>(a_row[byte_idx] | b_row[byte_idx]));
|
|
442
443
|
}
|
|
443
|
-
c_row[j] = result_type_::from_raw(
|
|
444
|
+
c_row[j] = result_type_::from_raw(
|
|
445
|
+
union_ ? 1.0f - static_cast<float>(intersection) / static_cast<float>(union_) : 0.0f);
|
|
444
446
|
}
|
|
445
447
|
}
|
|
446
448
|
}
|
|
@@ -452,7 +454,7 @@ void jaccards_packed(in_type_ const *a, void const *b_packed, result_type_ *c, s
|
|
|
452
454
|
|
|
453
455
|
namespace ashvardanian::numkong {
|
|
454
456
|
|
|
455
|
-
#pragma region
|
|
457
|
+
#pragma region Concept Constrained Symmetric Dot Products
|
|
456
458
|
|
|
457
459
|
/** @brief C = A × Aᵀ where C[i,j] = ⟨A[i], A[j]⟩. */
|
|
458
460
|
template <numeric_dtype value_type_, const_matrix_of<value_type_> input_matrix_,
|
|
@@ -547,9 +549,9 @@ matrix<typename value_type_::jaccard_result_t, allocator_type_> try_jaccards_sym
|
|
|
547
549
|
return result;
|
|
548
550
|
}
|
|
549
551
|
|
|
550
|
-
#pragma endregion
|
|
552
|
+
#pragma endregion Concept Constrained Symmetric Dot Products
|
|
551
553
|
|
|
552
|
-
#pragma region
|
|
554
|
+
#pragma region Concept Constrained Packed Dot Products
|
|
553
555
|
|
|
554
556
|
/** @brief Packed dot products: C = A × B_packedᵀ. */
|
|
555
557
|
template <numeric_dtype value_type_, packed_matrix_like packed_type_, const_matrix_of<value_type_> input_matrix_,
|
|
@@ -632,7 +634,7 @@ matrix<typename value_type_::jaccard_result_t, allocator_type_> try_jaccards_pac
|
|
|
632
634
|
return c;
|
|
633
635
|
}
|
|
634
636
|
|
|
635
|
-
#pragma endregion
|
|
637
|
+
#pragma endregion Concept Constrained Packed Dot Products
|
|
636
638
|
|
|
637
639
|
} // namespace ashvardanian::numkong
|
|
638
640
|
|