numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -184,7 +184,7 @@ NK_INTERNAL void nk_euclideans_row_u32dots_sapphireamx_(nk_f32_t *results, nk_u3
|
|
|
184
184
|
}
|
|
185
185
|
}
|
|
186
186
|
|
|
187
|
-
#pragma endregion
|
|
187
|
+
#pragma endregion Row Finalize Helpers
|
|
188
188
|
|
|
189
189
|
#pragma region BF16 Packed
|
|
190
190
|
|
|
@@ -234,11 +234,11 @@ NK_PUBLIC void nk_euclideans_packed_bf16_sapphireamx( //
|
|
|
234
234
|
c_stride_elements);
|
|
235
235
|
}
|
|
236
236
|
|
|
237
|
-
#pragma endregion
|
|
237
|
+
#pragma endregion BF16 Packed
|
|
238
238
|
|
|
239
239
|
#pragma region BF16 Symmetric
|
|
240
240
|
|
|
241
|
-
NK_INTERNAL void nk_angulars_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t const *vectors, nk_size_t
|
|
241
|
+
NK_INTERNAL void nk_angulars_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t const *vectors, nk_size_t vectors_count,
|
|
242
242
|
nk_size_t depth, nk_size_t stride_elements,
|
|
243
243
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
244
244
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -249,8 +249,8 @@ NK_INTERNAL void nk_angulars_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t cons
|
|
|
249
249
|
|
|
250
250
|
// Phase 2: 256-column chunks with cached norms
|
|
251
251
|
nk_f32_t column_norms_cache[256];
|
|
252
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
253
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
252
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
253
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
254
254
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
255
255
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
|
|
256
256
|
|
|
@@ -267,18 +267,18 @@ NK_INTERNAL void nk_angulars_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t cons
|
|
|
267
267
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
268
268
|
}
|
|
269
269
|
|
|
270
|
-
NK_PUBLIC void nk_angulars_symmetric_bf16_sapphireamx(
|
|
271
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
272
|
-
nk_f32_t *result, nk_size_t
|
|
273
|
-
nk_size_t const stride_elements =
|
|
274
|
-
nk_size_t const result_stride_elements =
|
|
275
|
-
nk_dots_symmetric_bf16_sapphireamx(vectors,
|
|
276
|
-
row_count);
|
|
277
|
-
nk_angulars_symmetric_bf16_sapphireamx_finalize_(vectors,
|
|
270
|
+
NK_PUBLIC void nk_angulars_symmetric_bf16_sapphireamx( //
|
|
271
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
272
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
273
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
274
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
275
|
+
nk_dots_symmetric_bf16_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
276
|
+
result_stride_in_bytes, row_start, row_count);
|
|
277
|
+
nk_angulars_symmetric_bf16_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
278
278
|
result_stride_elements, row_start, row_count);
|
|
279
279
|
}
|
|
280
280
|
|
|
281
|
-
NK_INTERNAL void nk_euclideans_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t const *vectors, nk_size_t
|
|
281
|
+
NK_INTERNAL void nk_euclideans_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t const *vectors, nk_size_t vectors_count,
|
|
282
282
|
nk_size_t depth, nk_size_t stride_elements,
|
|
283
283
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
284
284
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -289,8 +289,8 @@ NK_INTERNAL void nk_euclideans_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t co
|
|
|
289
289
|
|
|
290
290
|
// Phase 2: 256-column chunks with cached norms
|
|
291
291
|
nk_f32_t column_norms_cache[256];
|
|
292
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
293
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
292
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
293
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
294
294
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
295
295
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
|
|
296
296
|
|
|
@@ -307,20 +307,20 @@ NK_INTERNAL void nk_euclideans_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t co
|
|
|
307
307
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
308
308
|
}
|
|
309
309
|
|
|
310
|
-
NK_PUBLIC void nk_euclideans_symmetric_bf16_sapphireamx(
|
|
311
|
-
nk_bf16_t const *vectors, nk_size_t
|
|
312
|
-
nk_f32_t *result, nk_size_t
|
|
313
|
-
nk_size_t const stride_elements =
|
|
314
|
-
nk_size_t const result_stride_elements =
|
|
315
|
-
nk_dots_symmetric_bf16_sapphireamx(vectors,
|
|
316
|
-
row_count);
|
|
317
|
-
nk_euclideans_symmetric_bf16_sapphireamx_finalize_(vectors,
|
|
310
|
+
NK_PUBLIC void nk_euclideans_symmetric_bf16_sapphireamx( //
|
|
311
|
+
nk_bf16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
312
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
313
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
314
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
315
|
+
nk_dots_symmetric_bf16_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
316
|
+
result_stride_in_bytes, row_start, row_count);
|
|
317
|
+
nk_euclideans_symmetric_bf16_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
318
318
|
result_stride_elements, row_start, row_count);
|
|
319
319
|
}
|
|
320
320
|
|
|
321
|
-
#pragma endregion
|
|
321
|
+
#pragma endregion BF16 Symmetric
|
|
322
322
|
|
|
323
|
-
#pragma region
|
|
323
|
+
#pragma region I8 Packed
|
|
324
324
|
|
|
325
325
|
NK_INTERNAL void nk_angulars_packed_i8_sapphireamx_finalize_(nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
326
326
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -369,11 +369,11 @@ NK_PUBLIC void nk_euclideans_packed_i8_sapphireamx( //
|
|
|
369
369
|
c_stride_elements);
|
|
370
370
|
}
|
|
371
371
|
|
|
372
|
-
#pragma endregion
|
|
372
|
+
#pragma endregion I8 Packed
|
|
373
373
|
|
|
374
|
-
#pragma region
|
|
374
|
+
#pragma region I8 Symmetric
|
|
375
375
|
|
|
376
|
-
NK_INTERNAL void nk_angulars_symmetric_i8_sapphireamx_finalize_(nk_i8_t const *vectors, nk_size_t
|
|
376
|
+
NK_INTERNAL void nk_angulars_symmetric_i8_sapphireamx_finalize_(nk_i8_t const *vectors, nk_size_t vectors_count,
|
|
377
377
|
nk_size_t depth, nk_size_t stride_elements,
|
|
378
378
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
379
379
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -385,8 +385,8 @@ NK_INTERNAL void nk_angulars_symmetric_i8_sapphireamx_finalize_(nk_i8_t const *v
|
|
|
385
385
|
|
|
386
386
|
// Phase 2: 256-column chunks with cached norms
|
|
387
387
|
nk_u32_t column_norms_cache[256];
|
|
388
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
389
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
388
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
389
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
390
390
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
391
391
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
|
|
392
392
|
|
|
@@ -404,18 +404,18 @@ NK_INTERNAL void nk_angulars_symmetric_i8_sapphireamx_finalize_(nk_i8_t const *v
|
|
|
404
404
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
405
405
|
}
|
|
406
406
|
|
|
407
|
-
NK_PUBLIC void nk_angulars_symmetric_i8_sapphireamx(
|
|
408
|
-
nk_i8_t const *vectors, nk_size_t
|
|
409
|
-
nk_f32_t *result, nk_size_t
|
|
410
|
-
nk_size_t const stride_elements =
|
|
411
|
-
nk_size_t const result_stride_elements =
|
|
412
|
-
nk_dots_symmetric_i8_sapphireamx(vectors,
|
|
413
|
-
row_count);
|
|
414
|
-
nk_angulars_symmetric_i8_sapphireamx_finalize_(vectors,
|
|
407
|
+
NK_PUBLIC void nk_angulars_symmetric_i8_sapphireamx( //
|
|
408
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
409
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
410
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
411
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
412
|
+
nk_dots_symmetric_i8_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_i32_t *)result,
|
|
413
|
+
result_stride_in_bytes, row_start, row_count);
|
|
414
|
+
nk_angulars_symmetric_i8_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
415
415
|
result_stride_elements, row_start, row_count);
|
|
416
416
|
}
|
|
417
417
|
|
|
418
|
-
NK_INTERNAL void nk_euclideans_symmetric_i8_sapphireamx_finalize_(nk_i8_t const *vectors, nk_size_t
|
|
418
|
+
NK_INTERNAL void nk_euclideans_symmetric_i8_sapphireamx_finalize_(nk_i8_t const *vectors, nk_size_t vectors_count,
|
|
419
419
|
nk_size_t depth, nk_size_t stride_elements,
|
|
420
420
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
421
421
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -427,8 +427,8 @@ NK_INTERNAL void nk_euclideans_symmetric_i8_sapphireamx_finalize_(nk_i8_t const
|
|
|
427
427
|
|
|
428
428
|
// Phase 2: 256-column chunks with cached norms
|
|
429
429
|
nk_u32_t column_norms_cache[256];
|
|
430
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
431
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
430
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
431
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
432
432
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
433
433
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
|
|
434
434
|
|
|
@@ -446,20 +446,20 @@ NK_INTERNAL void nk_euclideans_symmetric_i8_sapphireamx_finalize_(nk_i8_t const
|
|
|
446
446
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
447
447
|
}
|
|
448
448
|
|
|
449
|
-
NK_PUBLIC void nk_euclideans_symmetric_i8_sapphireamx(
|
|
450
|
-
nk_i8_t const *vectors, nk_size_t
|
|
451
|
-
nk_f32_t *result, nk_size_t
|
|
452
|
-
nk_size_t const stride_elements =
|
|
453
|
-
nk_size_t const result_stride_elements =
|
|
454
|
-
nk_dots_symmetric_i8_sapphireamx(vectors,
|
|
455
|
-
row_count);
|
|
456
|
-
nk_euclideans_symmetric_i8_sapphireamx_finalize_(vectors,
|
|
449
|
+
NK_PUBLIC void nk_euclideans_symmetric_i8_sapphireamx( //
|
|
450
|
+
nk_i8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
451
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
452
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_i8_t);
|
|
453
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
454
|
+
nk_dots_symmetric_i8_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_i32_t *)result,
|
|
455
|
+
result_stride_in_bytes, row_start, row_count);
|
|
456
|
+
nk_euclideans_symmetric_i8_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
457
457
|
result_stride_elements, row_start, row_count);
|
|
458
458
|
}
|
|
459
459
|
|
|
460
|
-
#pragma endregion
|
|
460
|
+
#pragma endregion I8 Symmetric
|
|
461
461
|
|
|
462
|
-
#pragma region
|
|
462
|
+
#pragma region U8 Packed
|
|
463
463
|
|
|
464
464
|
NK_INTERNAL void nk_angulars_packed_u8_sapphireamx_finalize_(nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
|
|
465
465
|
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
@@ -508,11 +508,11 @@ NK_PUBLIC void nk_euclideans_packed_u8_sapphireamx( //
|
|
|
508
508
|
c_stride_elements);
|
|
509
509
|
}
|
|
510
510
|
|
|
511
|
-
#pragma endregion
|
|
511
|
+
#pragma endregion U8 Packed
|
|
512
512
|
|
|
513
|
-
#pragma region
|
|
513
|
+
#pragma region U8 Symmetric
|
|
514
514
|
|
|
515
|
-
NK_INTERNAL void nk_angulars_symmetric_u8_sapphireamx_finalize_(nk_u8_t const *vectors, nk_size_t
|
|
515
|
+
NK_INTERNAL void nk_angulars_symmetric_u8_sapphireamx_finalize_(nk_u8_t const *vectors, nk_size_t vectors_count,
|
|
516
516
|
nk_size_t depth, nk_size_t stride_elements,
|
|
517
517
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
518
518
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -524,8 +524,8 @@ NK_INTERNAL void nk_angulars_symmetric_u8_sapphireamx_finalize_(nk_u8_t const *v
|
|
|
524
524
|
|
|
525
525
|
// Phase 2: 256-column chunks with cached norms
|
|
526
526
|
nk_u32_t column_norms_cache[256];
|
|
527
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
528
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
527
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
528
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
529
529
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
530
530
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
|
|
531
531
|
|
|
@@ -543,18 +543,18 @@ NK_INTERNAL void nk_angulars_symmetric_u8_sapphireamx_finalize_(nk_u8_t const *v
|
|
|
543
543
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
544
544
|
}
|
|
545
545
|
|
|
546
|
-
NK_PUBLIC void nk_angulars_symmetric_u8_sapphireamx(
|
|
547
|
-
nk_u8_t const *vectors, nk_size_t
|
|
548
|
-
nk_f32_t *result, nk_size_t
|
|
549
|
-
nk_size_t const stride_elements =
|
|
550
|
-
nk_size_t const result_stride_elements =
|
|
551
|
-
nk_dots_symmetric_u8_sapphireamx(vectors,
|
|
552
|
-
row_count);
|
|
553
|
-
nk_angulars_symmetric_u8_sapphireamx_finalize_(vectors,
|
|
546
|
+
NK_PUBLIC void nk_angulars_symmetric_u8_sapphireamx( //
|
|
547
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
548
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
549
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
550
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
551
|
+
nk_dots_symmetric_u8_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_u32_t *)result,
|
|
552
|
+
result_stride_in_bytes, row_start, row_count);
|
|
553
|
+
nk_angulars_symmetric_u8_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
554
554
|
result_stride_elements, row_start, row_count);
|
|
555
555
|
}
|
|
556
556
|
|
|
557
|
-
NK_INTERNAL void nk_euclideans_symmetric_u8_sapphireamx_finalize_(nk_u8_t const *vectors, nk_size_t
|
|
557
|
+
NK_INTERNAL void nk_euclideans_symmetric_u8_sapphireamx_finalize_(nk_u8_t const *vectors, nk_size_t vectors_count,
|
|
558
558
|
nk_size_t depth, nk_size_t stride_elements,
|
|
559
559
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
560
560
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -566,8 +566,8 @@ NK_INTERNAL void nk_euclideans_symmetric_u8_sapphireamx_finalize_(nk_u8_t const
|
|
|
566
566
|
|
|
567
567
|
// Phase 2: 256-column chunks with cached norms
|
|
568
568
|
nk_u32_t column_norms_cache[256];
|
|
569
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
570
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
569
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
570
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
571
571
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
572
572
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
|
|
573
573
|
|
|
@@ -585,18 +585,18 @@ NK_INTERNAL void nk_euclideans_symmetric_u8_sapphireamx_finalize_(nk_u8_t const
|
|
|
585
585
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
586
586
|
}
|
|
587
587
|
|
|
588
|
-
NK_PUBLIC void nk_euclideans_symmetric_u8_sapphireamx(
|
|
589
|
-
nk_u8_t const *vectors, nk_size_t
|
|
590
|
-
nk_f32_t *result, nk_size_t
|
|
591
|
-
nk_size_t const stride_elements =
|
|
592
|
-
nk_size_t const result_stride_elements =
|
|
593
|
-
nk_dots_symmetric_u8_sapphireamx(vectors,
|
|
594
|
-
row_count);
|
|
595
|
-
nk_euclideans_symmetric_u8_sapphireamx_finalize_(vectors,
|
|
588
|
+
NK_PUBLIC void nk_euclideans_symmetric_u8_sapphireamx( //
|
|
589
|
+
nk_u8_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
590
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
591
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_u8_t);
|
|
592
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
593
|
+
nk_dots_symmetric_u8_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_u32_t *)result,
|
|
594
|
+
result_stride_in_bytes, row_start, row_count);
|
|
595
|
+
nk_euclideans_symmetric_u8_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
596
596
|
result_stride_elements, row_start, row_count);
|
|
597
597
|
}
|
|
598
598
|
|
|
599
|
-
#pragma endregion
|
|
599
|
+
#pragma endregion U8 Symmetric
|
|
600
600
|
|
|
601
601
|
#pragma region E4M3 Packed
|
|
602
602
|
|
|
@@ -646,7 +646,7 @@ NK_PUBLIC void nk_euclideans_packed_e4m3_sapphireamx( //
|
|
|
646
646
|
c_stride_elements);
|
|
647
647
|
}
|
|
648
648
|
|
|
649
|
-
#pragma endregion
|
|
649
|
+
#pragma endregion E4M3 Packed
|
|
650
650
|
|
|
651
651
|
#pragma region E5M2 Packed
|
|
652
652
|
|
|
@@ -696,11 +696,11 @@ NK_PUBLIC void nk_euclideans_packed_e5m2_sapphireamx( //
|
|
|
696
696
|
c_stride_elements);
|
|
697
697
|
}
|
|
698
698
|
|
|
699
|
-
#pragma endregion
|
|
699
|
+
#pragma endregion E5M2 Packed
|
|
700
700
|
|
|
701
701
|
#pragma region E5M2 Symmetric
|
|
702
702
|
|
|
703
|
-
NK_INTERNAL void nk_angulars_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t const *vectors, nk_size_t
|
|
703
|
+
NK_INTERNAL void nk_angulars_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t const *vectors, nk_size_t vectors_count,
|
|
704
704
|
nk_size_t depth, nk_size_t stride_elements,
|
|
705
705
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
706
706
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -711,8 +711,8 @@ NK_INTERNAL void nk_angulars_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t cons
|
|
|
711
711
|
|
|
712
712
|
// Phase 2: 256-column chunks with cached norms
|
|
713
713
|
nk_f32_t column_norms_cache[256];
|
|
714
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
715
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
714
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
715
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
716
716
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
717
717
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
|
|
718
718
|
|
|
@@ -729,18 +729,18 @@ NK_INTERNAL void nk_angulars_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t cons
|
|
|
729
729
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
730
730
|
}
|
|
731
731
|
|
|
732
|
-
NK_PUBLIC void nk_angulars_symmetric_e5m2_sapphireamx(
|
|
733
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
734
|
-
nk_f32_t *result, nk_size_t
|
|
735
|
-
nk_size_t const stride_elements =
|
|
736
|
-
nk_size_t const result_stride_elements =
|
|
737
|
-
nk_dots_symmetric_e5m2_sapphireamx(vectors,
|
|
738
|
-
row_count);
|
|
739
|
-
nk_angulars_symmetric_e5m2_sapphireamx_finalize_(vectors,
|
|
732
|
+
NK_PUBLIC void nk_angulars_symmetric_e5m2_sapphireamx( //
|
|
733
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
734
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
735
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
736
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
737
|
+
nk_dots_symmetric_e5m2_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
738
|
+
result_stride_in_bytes, row_start, row_count);
|
|
739
|
+
nk_angulars_symmetric_e5m2_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
740
740
|
result_stride_elements, row_start, row_count);
|
|
741
741
|
}
|
|
742
742
|
|
|
743
|
-
NK_INTERNAL void nk_euclideans_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t const *vectors, nk_size_t
|
|
743
|
+
NK_INTERNAL void nk_euclideans_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t const *vectors, nk_size_t vectors_count,
|
|
744
744
|
nk_size_t depth, nk_size_t stride_elements,
|
|
745
745
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
746
746
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -751,8 +751,8 @@ NK_INTERNAL void nk_euclideans_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t co
|
|
|
751
751
|
|
|
752
752
|
// Phase 2: 256-column chunks with cached norms
|
|
753
753
|
nk_f32_t column_norms_cache[256];
|
|
754
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
755
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
754
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
755
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
756
756
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
757
757
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
|
|
758
758
|
|
|
@@ -769,22 +769,22 @@ NK_INTERNAL void nk_euclideans_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t co
|
|
|
769
769
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
770
770
|
}
|
|
771
771
|
|
|
772
|
-
NK_PUBLIC void nk_euclideans_symmetric_e5m2_sapphireamx(
|
|
773
|
-
nk_e5m2_t const *vectors, nk_size_t
|
|
774
|
-
nk_f32_t *result, nk_size_t
|
|
775
|
-
nk_size_t const stride_elements =
|
|
776
|
-
nk_size_t const result_stride_elements =
|
|
777
|
-
nk_dots_symmetric_e5m2_sapphireamx(vectors,
|
|
778
|
-
row_count);
|
|
779
|
-
nk_euclideans_symmetric_e5m2_sapphireamx_finalize_(vectors,
|
|
772
|
+
NK_PUBLIC void nk_euclideans_symmetric_e5m2_sapphireamx( //
|
|
773
|
+
nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
774
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
775
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e5m2_t);
|
|
776
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
777
|
+
nk_dots_symmetric_e5m2_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
778
|
+
result_stride_in_bytes, row_start, row_count);
|
|
779
|
+
nk_euclideans_symmetric_e5m2_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
780
780
|
result_stride_elements, row_start, row_count);
|
|
781
781
|
}
|
|
782
782
|
|
|
783
|
-
#pragma endregion
|
|
783
|
+
#pragma endregion E5M2 Symmetric
|
|
784
784
|
|
|
785
785
|
#pragma region E4M3 Symmetric
|
|
786
786
|
|
|
787
|
-
NK_INTERNAL void nk_angulars_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t const *vectors, nk_size_t
|
|
787
|
+
NK_INTERNAL void nk_angulars_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t const *vectors, nk_size_t vectors_count,
|
|
788
788
|
nk_size_t depth, nk_size_t stride_elements,
|
|
789
789
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
790
790
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -795,8 +795,8 @@ NK_INTERNAL void nk_angulars_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t cons
|
|
|
795
795
|
|
|
796
796
|
// Phase 2: 256-column chunks with cached norms
|
|
797
797
|
nk_f32_t column_norms_cache[256];
|
|
798
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
799
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
798
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
799
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
800
800
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
801
801
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
|
|
802
802
|
|
|
@@ -813,18 +813,18 @@ NK_INTERNAL void nk_angulars_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t cons
|
|
|
813
813
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
814
814
|
}
|
|
815
815
|
|
|
816
|
-
NK_PUBLIC void nk_angulars_symmetric_e4m3_sapphireamx(
|
|
817
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
818
|
-
nk_f32_t *result, nk_size_t
|
|
819
|
-
nk_size_t const stride_elements =
|
|
820
|
-
nk_size_t const result_stride_elements =
|
|
821
|
-
nk_dots_symmetric_e4m3_sapphireamx(vectors,
|
|
822
|
-
row_count);
|
|
823
|
-
nk_angulars_symmetric_e4m3_sapphireamx_finalize_(vectors,
|
|
816
|
+
NK_PUBLIC void nk_angulars_symmetric_e4m3_sapphireamx( //
|
|
817
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
818
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
819
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
820
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
821
|
+
nk_dots_symmetric_e4m3_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
822
|
+
result_stride_in_bytes, row_start, row_count);
|
|
823
|
+
nk_angulars_symmetric_e4m3_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
824
824
|
result_stride_elements, row_start, row_count);
|
|
825
825
|
}
|
|
826
826
|
|
|
827
|
-
NK_INTERNAL void nk_euclideans_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t const *vectors, nk_size_t
|
|
827
|
+
NK_INTERNAL void nk_euclideans_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t const *vectors, nk_size_t vectors_count,
|
|
828
828
|
nk_size_t depth, nk_size_t stride_elements,
|
|
829
829
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
830
830
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -835,8 +835,8 @@ NK_INTERNAL void nk_euclideans_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t co
|
|
|
835
835
|
|
|
836
836
|
// Phase 2: 256-column chunks with cached norms
|
|
837
837
|
nk_f32_t column_norms_cache[256];
|
|
838
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
839
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
838
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
839
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
840
840
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
841
841
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
|
|
842
842
|
|
|
@@ -853,18 +853,18 @@ NK_INTERNAL void nk_euclideans_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t co
|
|
|
853
853
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
854
854
|
}
|
|
855
855
|
|
|
856
|
-
NK_PUBLIC void nk_euclideans_symmetric_e4m3_sapphireamx(
|
|
857
|
-
nk_e4m3_t const *vectors, nk_size_t
|
|
858
|
-
nk_f32_t *result, nk_size_t
|
|
859
|
-
nk_size_t const stride_elements =
|
|
860
|
-
nk_size_t const result_stride_elements =
|
|
861
|
-
nk_dots_symmetric_e4m3_sapphireamx(vectors,
|
|
862
|
-
row_count);
|
|
863
|
-
nk_euclideans_symmetric_e4m3_sapphireamx_finalize_(vectors,
|
|
856
|
+
NK_PUBLIC void nk_euclideans_symmetric_e4m3_sapphireamx( //
|
|
857
|
+
nk_e4m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
858
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
859
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e4m3_t);
|
|
860
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
861
|
+
nk_dots_symmetric_e4m3_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
862
|
+
result_stride_in_bytes, row_start, row_count);
|
|
863
|
+
nk_euclideans_symmetric_e4m3_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
864
864
|
result_stride_elements, row_start, row_count);
|
|
865
865
|
}
|
|
866
866
|
|
|
867
|
-
#pragma endregion
|
|
867
|
+
#pragma endregion E4M3 Symmetric
|
|
868
868
|
|
|
869
869
|
#pragma region E2M3 Packed
|
|
870
870
|
|
|
@@ -914,11 +914,11 @@ NK_PUBLIC void nk_euclideans_packed_e2m3_sapphireamx( //
|
|
|
914
914
|
c_stride_elements);
|
|
915
915
|
}
|
|
916
916
|
|
|
917
|
-
#pragma endregion
|
|
917
|
+
#pragma endregion E2M3 Packed
|
|
918
918
|
|
|
919
919
|
#pragma region E2M3 Symmetric
|
|
920
920
|
|
|
921
|
-
NK_INTERNAL void nk_angulars_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t const *vectors, nk_size_t
|
|
921
|
+
NK_INTERNAL void nk_angulars_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t const *vectors, nk_size_t vectors_count,
|
|
922
922
|
nk_size_t depth, nk_size_t stride_elements,
|
|
923
923
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
924
924
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -929,8 +929,8 @@ NK_INTERNAL void nk_angulars_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t cons
|
|
|
929
929
|
|
|
930
930
|
// Phase 2: 256-column chunks with cached norms
|
|
931
931
|
nk_f32_t column_norms_cache[256];
|
|
932
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
933
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
932
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
933
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
934
934
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
935
935
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
|
|
936
936
|
|
|
@@ -947,18 +947,18 @@ NK_INTERNAL void nk_angulars_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t cons
|
|
|
947
947
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
948
948
|
}
|
|
949
949
|
|
|
950
|
-
NK_PUBLIC void nk_angulars_symmetric_e2m3_sapphireamx(
|
|
951
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
952
|
-
nk_f32_t *result, nk_size_t
|
|
953
|
-
nk_size_t const stride_elements =
|
|
954
|
-
nk_size_t const result_stride_elements =
|
|
955
|
-
nk_dots_symmetric_e2m3_sapphireamx(vectors,
|
|
956
|
-
row_count);
|
|
957
|
-
nk_angulars_symmetric_e2m3_sapphireamx_finalize_(vectors,
|
|
950
|
+
NK_PUBLIC void nk_angulars_symmetric_e2m3_sapphireamx( //
|
|
951
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
952
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
953
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
954
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
955
|
+
nk_dots_symmetric_e2m3_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
956
|
+
result_stride_in_bytes, row_start, row_count);
|
|
957
|
+
nk_angulars_symmetric_e2m3_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
958
958
|
result_stride_elements, row_start, row_count);
|
|
959
959
|
}
|
|
960
960
|
|
|
961
|
-
NK_INTERNAL void nk_euclideans_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t const *vectors, nk_size_t
|
|
961
|
+
NK_INTERNAL void nk_euclideans_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t const *vectors, nk_size_t vectors_count,
|
|
962
962
|
nk_size_t depth, nk_size_t stride_elements,
|
|
963
963
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
964
964
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -969,8 +969,8 @@ NK_INTERNAL void nk_euclideans_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t co
|
|
|
969
969
|
|
|
970
970
|
// Phase 2: 256-column chunks with cached norms
|
|
971
971
|
nk_f32_t column_norms_cache[256];
|
|
972
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
973
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
972
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
973
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
974
974
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
975
975
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
|
|
976
976
|
|
|
@@ -987,18 +987,18 @@ NK_INTERNAL void nk_euclideans_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t co
|
|
|
987
987
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
988
988
|
}
|
|
989
989
|
|
|
990
|
-
NK_PUBLIC void nk_euclideans_symmetric_e2m3_sapphireamx(
|
|
991
|
-
nk_e2m3_t const *vectors, nk_size_t
|
|
992
|
-
nk_f32_t *result, nk_size_t
|
|
993
|
-
nk_size_t const stride_elements =
|
|
994
|
-
nk_size_t const result_stride_elements =
|
|
995
|
-
nk_dots_symmetric_e2m3_sapphireamx(vectors,
|
|
996
|
-
row_count);
|
|
997
|
-
nk_euclideans_symmetric_e2m3_sapphireamx_finalize_(vectors,
|
|
990
|
+
NK_PUBLIC void nk_euclideans_symmetric_e2m3_sapphireamx( //
|
|
991
|
+
nk_e2m3_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
992
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
993
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e2m3_t);
|
|
994
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
995
|
+
nk_dots_symmetric_e2m3_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
996
|
+
result_stride_in_bytes, row_start, row_count);
|
|
997
|
+
nk_euclideans_symmetric_e2m3_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
998
998
|
result_stride_elements, row_start, row_count);
|
|
999
999
|
}
|
|
1000
1000
|
|
|
1001
|
-
#pragma endregion
|
|
1001
|
+
#pragma endregion E2M3 Symmetric
|
|
1002
1002
|
|
|
1003
1003
|
#pragma region E3M2 Packed
|
|
1004
1004
|
|
|
@@ -1048,11 +1048,11 @@ NK_PUBLIC void nk_euclideans_packed_e3m2_sapphireamx( //
|
|
|
1048
1048
|
c_stride_elements);
|
|
1049
1049
|
}
|
|
1050
1050
|
|
|
1051
|
-
#pragma endregion
|
|
1051
|
+
#pragma endregion E3M2 Packed
|
|
1052
1052
|
|
|
1053
1053
|
#pragma region E3M2 Symmetric
|
|
1054
1054
|
|
|
1055
|
-
NK_INTERNAL void nk_angulars_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t const *vectors, nk_size_t
|
|
1055
|
+
NK_INTERNAL void nk_angulars_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t const *vectors, nk_size_t vectors_count,
|
|
1056
1056
|
nk_size_t depth, nk_size_t stride_elements,
|
|
1057
1057
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
1058
1058
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -1063,8 +1063,8 @@ NK_INTERNAL void nk_angulars_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t cons
|
|
|
1063
1063
|
|
|
1064
1064
|
// Phase 2: 256-column chunks with cached norms
|
|
1065
1065
|
nk_f32_t column_norms_cache[256];
|
|
1066
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1067
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1066
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1067
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1068
1068
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
1069
1069
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
|
|
1070
1070
|
|
|
@@ -1081,18 +1081,18 @@ NK_INTERNAL void nk_angulars_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t cons
|
|
|
1081
1081
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
1082
1082
|
}
|
|
1083
1083
|
|
|
1084
|
-
NK_PUBLIC void nk_angulars_symmetric_e3m2_sapphireamx(
|
|
1085
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
1086
|
-
nk_f32_t *result, nk_size_t
|
|
1087
|
-
nk_size_t const stride_elements =
|
|
1088
|
-
nk_size_t const result_stride_elements =
|
|
1089
|
-
nk_dots_symmetric_e3m2_sapphireamx(vectors,
|
|
1090
|
-
row_count);
|
|
1091
|
-
nk_angulars_symmetric_e3m2_sapphireamx_finalize_(vectors,
|
|
1084
|
+
NK_PUBLIC void nk_angulars_symmetric_e3m2_sapphireamx( //
|
|
1085
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1086
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1087
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1088
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1089
|
+
nk_dots_symmetric_e3m2_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
1090
|
+
result_stride_in_bytes, row_start, row_count);
|
|
1091
|
+
nk_angulars_symmetric_e3m2_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1092
1092
|
result_stride_elements, row_start, row_count);
|
|
1093
1093
|
}
|
|
1094
1094
|
|
|
1095
|
-
NK_INTERNAL void nk_euclideans_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t const *vectors, nk_size_t
|
|
1095
|
+
NK_INTERNAL void nk_euclideans_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t const *vectors, nk_size_t vectors_count,
|
|
1096
1096
|
nk_size_t depth, nk_size_t stride_elements,
|
|
1097
1097
|
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
1098
1098
|
nk_size_t row_start, nk_size_t row_count) {
|
|
@@ -1103,8 +1103,8 @@ NK_INTERNAL void nk_euclideans_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t co
|
|
|
1103
1103
|
|
|
1104
1104
|
// Phase 2: 256-column chunks with cached norms
|
|
1105
1105
|
nk_f32_t column_norms_cache[256];
|
|
1106
|
-
for (nk_size_t chunk_start = 0; chunk_start <
|
|
1107
|
-
nk_size_t chunk_end = chunk_start + 256 <
|
|
1106
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
1107
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
1108
1108
|
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
1109
1109
|
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
|
|
1110
1110
|
|
|
@@ -1121,18 +1121,18 @@ NK_INTERNAL void nk_euclideans_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t co
|
|
|
1121
1121
|
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
1122
1122
|
}
|
|
1123
1123
|
|
|
1124
|
-
NK_PUBLIC void nk_euclideans_symmetric_e3m2_sapphireamx(
|
|
1125
|
-
nk_e3m2_t const *vectors, nk_size_t
|
|
1126
|
-
nk_f32_t *result, nk_size_t
|
|
1127
|
-
nk_size_t const stride_elements =
|
|
1128
|
-
nk_size_t const result_stride_elements =
|
|
1129
|
-
nk_dots_symmetric_e3m2_sapphireamx(vectors,
|
|
1130
|
-
row_count);
|
|
1131
|
-
nk_euclideans_symmetric_e3m2_sapphireamx_finalize_(vectors,
|
|
1124
|
+
NK_PUBLIC void nk_euclideans_symmetric_e3m2_sapphireamx( //
|
|
1125
|
+
nk_e3m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
1126
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1127
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1128
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
1129
|
+
nk_dots_symmetric_e3m2_sapphireamx(vectors, vectors_count, depth, stride_in_bytes, (nk_f32_t *)result,
|
|
1130
|
+
result_stride_in_bytes, row_start, row_count);
|
|
1131
|
+
nk_euclideans_symmetric_e3m2_sapphireamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
1132
1132
|
result_stride_elements, row_start, row_count);
|
|
1133
1133
|
}
|
|
1134
1134
|
|
|
1135
|
-
#pragma endregion
|
|
1135
|
+
#pragma endregion E3M2 Symmetric
|
|
1136
1136
|
|
|
1137
1137
|
#if defined(__clang__)
|
|
1138
1138
|
#pragma clang attribute pop
|