numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1149 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Batched Spatial Distances for Sapphire Rapids (AMX) with AVX-512 Finalization.
|
|
3
|
+
* @file include/numkong/spatials/sapphireamx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatials.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPATIALS_SAPPHIREAMX_H
|
|
10
|
+
#define NK_SPATIALS_SAPPHIREAMX_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_X86_
|
|
13
|
+
#if NK_TARGET_SAPPHIREAMX
|
|
14
|
+
|
|
15
|
+
#include "numkong/spatial/skylake.h"
|
|
16
|
+
#include "numkong/spatial/serial.h"
|
|
17
|
+
#include "numkong/dots/sapphireamx.h"
|
|
18
|
+
|
|
19
|
+
#if defined(__cplusplus)
|
|
20
|
+
extern "C" {
|
|
21
|
+
#endif
|
|
22
|
+
|
|
23
|
+
#if defined(__clang__)
|
|
24
|
+
#pragma clang attribute push( \
|
|
25
|
+
__attribute__((target( \
|
|
26
|
+
"avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8"))), \
|
|
27
|
+
apply_to = function)
|
|
28
|
+
#elif defined(__GNUC__)
|
|
29
|
+
#pragma GCC push_options
|
|
30
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
|
|
31
|
+
"bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8")
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
#pragma region Row Finalize Helpers
|
|
35
|
+
|
|
36
|
+
NK_INTERNAL void nk_angulars_row_f32dots_sapphireamx_(nk_f32_t *results, nk_f32_t const *norms, nk_f32_t query_norm_sq,
|
|
37
|
+
nk_size_t count) {
|
|
38
|
+
__m512 query_norm_sq_f32x16 = _mm512_set1_ps(query_norm_sq);
|
|
39
|
+
nk_size_t i = 0;
|
|
40
|
+
for (; i + 16 <= count; i += 16) {
|
|
41
|
+
__m512 dots_f32x16 = _mm512_loadu_ps(results + i);
|
|
42
|
+
__m512 norms_f32x16 = _mm512_loadu_ps(norms + i);
|
|
43
|
+
__m512 products_f32x16 = _mm512_mul_ps(query_norm_sq_f32x16, norms_f32x16);
|
|
44
|
+
__m512 rsqrt_f32x16 = nk_rsqrt_f32x16_skylake_(products_f32x16);
|
|
45
|
+
__m512 normalized_f32x16 = _mm512_mul_ps(dots_f32x16, rsqrt_f32x16);
|
|
46
|
+
__m512 angular_f32x16 = _mm512_sub_ps(_mm512_set1_ps(1.0f), normalized_f32x16);
|
|
47
|
+
_mm512_storeu_ps(results + i, _mm512_max_ps(angular_f32x16, _mm512_setzero_ps()));
|
|
48
|
+
}
|
|
49
|
+
if (i < count) {
|
|
50
|
+
__mmask16 tail = (__mmask16)((1u << (count - i)) - 1);
|
|
51
|
+
__m512 dots_f32x16 = _mm512_maskz_loadu_ps(tail, results + i);
|
|
52
|
+
__m512 norms_f32x16 = _mm512_maskz_loadu_ps(tail, norms + i);
|
|
53
|
+
__m512 products_f32x16 = _mm512_mul_ps(query_norm_sq_f32x16, norms_f32x16);
|
|
54
|
+
__m512 rsqrt_f32x16 = nk_rsqrt_f32x16_skylake_(products_f32x16);
|
|
55
|
+
__m512 normalized_f32x16 = _mm512_mul_ps(dots_f32x16, rsqrt_f32x16);
|
|
56
|
+
__m512 angular_f32x16 = _mm512_sub_ps(_mm512_set1_ps(1.0f), normalized_f32x16);
|
|
57
|
+
_mm512_mask_storeu_ps(results + i, tail, _mm512_max_ps(angular_f32x16, _mm512_setzero_ps()));
|
|
58
|
+
}
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
/**
 *  @brief  Rewrites a row of f32 dot products as Euclidean distances, in place.
 *
 *  For every column `j < count`, computes
 *  `sqrt(max(0, query_norm_sq + norms[j] - 2 * results[j]))`.
 *  Processes 16 lanes per step. The last, partial step uses a lane mask, so reads and writes never
 *  touch memory beyond `count` elements.
 *
 *  @param[in,out] results      On entry: `count` dot products. On exit: `count` Euclidean distances.
 *  @param[in] norms            Per-column squared norms, `count` entries.
 *  @param[in] query_norm_sq    Squared norm of the row (query) vector.
 *  @param[in] count            Number of columns to finalize.
 */
NK_INTERNAL void nk_euclideans_row_f32dots_sapphireamx_(nk_f32_t *results, nk_f32_t const *norms,
                                                        nk_f32_t query_norm_sq, nk_size_t count) {
    __m512 const query_f32x16 = _mm512_set1_ps(query_norm_sq);
    __m512 const two_f32x16 = _mm512_set1_ps(2.0f);
    __m512 const zeros_f32x16 = _mm512_setzero_ps();
    for (nk_size_t offset = 0; offset < count; offset += 16) {
        nk_size_t const remaining = count - offset;
        // Full steps enable all 16 lanes; the final partial step enables only the leftover ones.
        __mmask16 const lanes = remaining >= 16 ? (__mmask16)0xFFFF : (__mmask16)((1u << remaining) - 1);
        __m512 const dot_f32x16 = _mm512_maskz_loadu_ps(lanes, results + offset);
        __m512 const norm_f32x16 = _mm512_maskz_loadu_ps(lanes, norms + offset);
        // |q - b|^2 = |q|^2 + |b|^2 - 2 <q, b>, with the last term folded into one fused negative multiply-add.
        __m512 squared_f32x16 = _mm512_fnmadd_ps(two_f32x16, dot_f32x16, _mm512_add_ps(query_f32x16, norm_f32x16));
        // Cancellation can push near-zero distances below zero; clamp before taking the root.
        squared_f32x16 = _mm512_max_ps(squared_f32x16, zeros_f32x16);
        _mm512_mask_storeu_ps(results + offset, lanes, _mm512_sqrt_ps(squared_f32x16));
    }
}
|
|
84
|
+
|
|
85
|
+
NK_INTERNAL void nk_angulars_row_i32dots_sapphireamx_(nk_f32_t *results, nk_u32_t const *norms, nk_f32_t query_norm_sq,
|
|
86
|
+
nk_size_t count) {
|
|
87
|
+
nk_i32_t *results_i32 = (nk_i32_t *)results;
|
|
88
|
+
__m512 query_norm_sq_f32x16 = _mm512_set1_ps(query_norm_sq);
|
|
89
|
+
nk_size_t i = 0;
|
|
90
|
+
for (; i + 16 <= count; i += 16) {
|
|
91
|
+
__m512 dots_f32x16 = _mm512_cvtepi32_ps(_mm512_loadu_si512(results_i32 + i));
|
|
92
|
+
__m512 norms_f32x16 = _mm512_cvtepu32_ps(_mm512_loadu_si512((__m512i const *)(norms + i)));
|
|
93
|
+
__m512 products_f32x16 = _mm512_mul_ps(query_norm_sq_f32x16, norms_f32x16);
|
|
94
|
+
__m512 rsqrt_f32x16 = nk_rsqrt_f32x16_skylake_(products_f32x16);
|
|
95
|
+
__m512 normalized_f32x16 = _mm512_mul_ps(dots_f32x16, rsqrt_f32x16);
|
|
96
|
+
__m512 angular_f32x16 = _mm512_sub_ps(_mm512_set1_ps(1.0f), normalized_f32x16);
|
|
97
|
+
_mm512_storeu_ps(results + i, _mm512_max_ps(angular_f32x16, _mm512_setzero_ps()));
|
|
98
|
+
}
|
|
99
|
+
if (i < count) {
|
|
100
|
+
__mmask16 tail = (__mmask16)((1u << (count - i)) - 1);
|
|
101
|
+
__m512 dots_f32x16 = _mm512_cvtepi32_ps(_mm512_maskz_loadu_epi32(tail, results_i32 + i));
|
|
102
|
+
__m512 norms_f32x16 = _mm512_cvtepu32_ps(_mm512_maskz_loadu_epi32(tail, norms + i));
|
|
103
|
+
__m512 products_f32x16 = _mm512_mul_ps(query_norm_sq_f32x16, norms_f32x16);
|
|
104
|
+
__m512 rsqrt_f32x16 = nk_rsqrt_f32x16_skylake_(products_f32x16);
|
|
105
|
+
__m512 normalized_f32x16 = _mm512_mul_ps(dots_f32x16, rsqrt_f32x16);
|
|
106
|
+
__m512 angular_f32x16 = _mm512_sub_ps(_mm512_set1_ps(1.0f), normalized_f32x16);
|
|
107
|
+
_mm512_mask_storeu_ps(results + i, tail, _mm512_max_ps(angular_f32x16, _mm512_setzero_ps()));
|
|
108
|
+
}
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
/**
 *  @brief  Rewrites a row of signed 32-bit integer dot products as f32 Euclidean distances, in place.
 *
 *  On entry, the storage behind `results` holds `count` `nk_i32_t` dot products. On exit, it holds
 *  `count` `nk_f32_t` values `sqrt(max(0, query_norm_sq + norms[j] - 2 * dot))`. The last, partial
 *  16-lane step is masked, so nothing beyond `count` elements is accessed.
 *
 *  @param[in,out] results      In: i32 dot products (reinterpreted storage). Out: f32 Euclidean distances.
 *  @param[in] norms            Per-column squared norms as u32, `count` entries.
 *  @param[in] query_norm_sq    Squared norm of the row (query) vector.
 *  @param[in] count            Number of columns to finalize.
 */
NK_INTERNAL void nk_euclideans_row_i32dots_sapphireamx_(nk_f32_t *results, nk_u32_t const *norms,
                                                        nk_f32_t query_norm_sq, nk_size_t count) {
    nk_i32_t const *dots_i32 = (nk_i32_t const *)results;
    __m512 const query_f32x16 = _mm512_set1_ps(query_norm_sq);
    __m512 const two_f32x16 = _mm512_set1_ps(2.0f);
    __m512 const zeros_f32x16 = _mm512_setzero_ps();
    for (nk_size_t offset = 0; offset < count; offset += 16) {
        nk_size_t const remaining = count - offset;
        __mmask16 const lanes = remaining >= 16 ? (__mmask16)0xFFFF : (__mmask16)((1u << remaining) - 1);
        __m512 const dot_f32x16 = _mm512_cvtepi32_ps(_mm512_maskz_loadu_epi32(lanes, dots_i32 + offset));
        __m512 const norm_f32x16 = _mm512_cvtepu32_ps(_mm512_maskz_loadu_epi32(lanes, norms + offset));
        // |q - b|^2 = |q|^2 + |b|^2 - 2 <q, b>, clamped at zero to absorb rounding error.
        __m512 squared_f32x16 = _mm512_fnmadd_ps(two_f32x16, dot_f32x16, _mm512_add_ps(query_f32x16, norm_f32x16));
        squared_f32x16 = _mm512_max_ps(squared_f32x16, zeros_f32x16);
        _mm512_mask_storeu_ps(results + offset, lanes, _mm512_sqrt_ps(squared_f32x16));
    }
}
|
|
135
|
+
|
|
136
|
+
NK_INTERNAL void nk_angulars_row_u32dots_sapphireamx_(nk_f32_t *results, nk_u32_t const *norms, nk_f32_t query_norm_sq,
|
|
137
|
+
nk_size_t count) {
|
|
138
|
+
nk_u32_t *results_u32 = (nk_u32_t *)results;
|
|
139
|
+
__m512 query_norm_sq_f32x16 = _mm512_set1_ps(query_norm_sq);
|
|
140
|
+
nk_size_t i = 0;
|
|
141
|
+
for (; i + 16 <= count; i += 16) {
|
|
142
|
+
__m512 dots_f32x16 = _mm512_cvtepu32_ps(_mm512_loadu_si512((__m512i const *)(results_u32 + i)));
|
|
143
|
+
__m512 norms_f32x16 = _mm512_cvtepu32_ps(_mm512_loadu_si512((__m512i const *)(norms + i)));
|
|
144
|
+
__m512 products_f32x16 = _mm512_mul_ps(query_norm_sq_f32x16, norms_f32x16);
|
|
145
|
+
__m512 rsqrt_f32x16 = nk_rsqrt_f32x16_skylake_(products_f32x16);
|
|
146
|
+
__m512 normalized_f32x16 = _mm512_mul_ps(dots_f32x16, rsqrt_f32x16);
|
|
147
|
+
__m512 angular_f32x16 = _mm512_sub_ps(_mm512_set1_ps(1.0f), normalized_f32x16);
|
|
148
|
+
_mm512_storeu_ps(results + i, _mm512_max_ps(angular_f32x16, _mm512_setzero_ps()));
|
|
149
|
+
}
|
|
150
|
+
if (i < count) {
|
|
151
|
+
__mmask16 tail = (__mmask16)((1u << (count - i)) - 1);
|
|
152
|
+
__m512 dots_f32x16 = _mm512_cvtepu32_ps(_mm512_maskz_loadu_epi32(tail, results_u32 + i));
|
|
153
|
+
__m512 norms_f32x16 = _mm512_cvtepu32_ps(_mm512_maskz_loadu_epi32(tail, norms + i));
|
|
154
|
+
__m512 products_f32x16 = _mm512_mul_ps(query_norm_sq_f32x16, norms_f32x16);
|
|
155
|
+
__m512 rsqrt_f32x16 = nk_rsqrt_f32x16_skylake_(products_f32x16);
|
|
156
|
+
__m512 normalized_f32x16 = _mm512_mul_ps(dots_f32x16, rsqrt_f32x16);
|
|
157
|
+
__m512 angular_f32x16 = _mm512_sub_ps(_mm512_set1_ps(1.0f), normalized_f32x16);
|
|
158
|
+
_mm512_mask_storeu_ps(results + i, tail, _mm512_max_ps(angular_f32x16, _mm512_setzero_ps()));
|
|
159
|
+
}
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
/**
 *  @brief  Rewrites a row of unsigned 32-bit integer dot products as f32 Euclidean distances, in place.
 *
 *  On entry, the storage behind `results` holds `count` `nk_u32_t` dot products. On exit, it holds
 *  `count` `nk_f32_t` values `sqrt(max(0, query_norm_sq + norms[j] - 2 * dot))`. The last, partial
 *  16-lane step is masked, so nothing beyond `count` elements is accessed.
 *
 *  @param[in,out] results      In: u32 dot products (reinterpreted storage). Out: f32 Euclidean distances.
 *  @param[in] norms            Per-column squared norms as u32, `count` entries.
 *  @param[in] query_norm_sq    Squared norm of the row (query) vector.
 *  @param[in] count            Number of columns to finalize.
 */
NK_INTERNAL void nk_euclideans_row_u32dots_sapphireamx_(nk_f32_t *results, nk_u32_t const *norms,
                                                        nk_f32_t query_norm_sq, nk_size_t count) {
    nk_u32_t const *dots_u32 = (nk_u32_t const *)results;
    __m512 const query_f32x16 = _mm512_set1_ps(query_norm_sq);
    __m512 const two_f32x16 = _mm512_set1_ps(2.0f);
    __m512 const zeros_f32x16 = _mm512_setzero_ps();
    for (nk_size_t offset = 0; offset < count; offset += 16) {
        nk_size_t const remaining = count - offset;
        __mmask16 const lanes = remaining >= 16 ? (__mmask16)0xFFFF : (__mmask16)((1u << remaining) - 1);
        __m512 const dot_f32x16 = _mm512_cvtepu32_ps(_mm512_maskz_loadu_epi32(lanes, dots_u32 + offset));
        __m512 const norm_f32x16 = _mm512_cvtepu32_ps(_mm512_maskz_loadu_epi32(lanes, norms + offset));
        // |q - b|^2 = |q|^2 + |b|^2 - 2 <q, b>, clamped at zero to absorb rounding error.
        __m512 squared_f32x16 = _mm512_fnmadd_ps(two_f32x16, dot_f32x16, _mm512_add_ps(query_f32x16, norm_f32x16));
        squared_f32x16 = _mm512_max_ps(squared_f32x16, zeros_f32x16);
        _mm512_mask_storeu_ps(results + offset, lanes, _mm512_sqrt_ps(squared_f32x16));
    }
}
|
|
186
|
+
|
|
187
|
+
#pragma endregion // Row Finalize Helpers
|
|
188
|
+
|
|
189
|
+
#pragma region BF16 Packed
|
|
190
|
+
|
|
191
|
+
/**
 *  @brief  Converts a `rows x columns` block of bf16 dot products in `c` into angular distances.
 *
 *  Column squared norms are read from the packed B buffer at `header->norms_byte_offset`. Each query
 *  row's squared norm is recomputed from `a` with `nk_dots_reduce_sumsq_bf16_`.
 *
 *  @param[in] a                   Query rows, `a_stride_elements` bf16 values apart.
 *  @param[in] b_packed            AMX-packed B buffer, starting with an `nk_dots_amx_packed_header_t`.
 *  @param[in,out] c               Dot products on entry, angular distances on exit.
 *  @param[in] a_stride_elements   Row stride of `a`, in elements.
 *  @param[in] c_stride_elements   Row stride of `c`, in elements.
 */
NK_INTERNAL void nk_angulars_packed_bf16_sapphireamx_finalize_(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
                                                               nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                               nk_size_t a_stride_elements,
                                                               nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_bytes = (char const *)b_packed;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_bytes + packed_header->norms_byte_offset);
    for (nk_size_t row_index = 0; row_index != rows; ++row_index) {
        nk_bf16_t const *query_row = a + row_index * a_stride_elements;
        nk_f32_t *output_row = c + row_index * c_stride_elements;
        nk_f32_t const query_norm_sq = nk_dots_reduce_sumsq_bf16_(query_row, depth);
        nk_angulars_row_f32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
202
|
+
|
|
203
|
+
NK_PUBLIC void nk_angulars_packed_bf16_sapphireamx( //
|
|
204
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
205
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
206
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
207
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
208
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
209
|
+
nk_dots_packed_bf16_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
210
|
+
nk_angulars_packed_bf16_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
211
|
+
c_stride_elements);
|
|
212
|
+
}
|
|
213
|
+
|
|
214
|
+
/**
 *  @brief  Converts a `rows x columns` block of bf16 dot products in `c` into Euclidean distances.
 *
 *  Column squared norms are read from the packed B buffer at `header->norms_byte_offset`. Each query
 *  row's squared norm is recomputed from `a` with `nk_dots_reduce_sumsq_bf16_`.
 *
 *  @param[in] a                   Query rows, `a_stride_elements` bf16 values apart.
 *  @param[in] b_packed            AMX-packed B buffer, starting with an `nk_dots_amx_packed_header_t`.
 *  @param[in,out] c               Dot products on entry, Euclidean distances on exit.
 *  @param[in] a_stride_elements   Row stride of `a`, in elements.
 *  @param[in] c_stride_elements   Row stride of `c`, in elements.
 */
NK_INTERNAL void nk_euclideans_packed_bf16_sapphireamx_finalize_(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
                                                                 nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                                 nk_size_t a_stride_elements,
                                                                 nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_bytes = (char const *)b_packed;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_bytes + packed_header->norms_byte_offset);
    for (nk_size_t row_index = 0; row_index != rows; ++row_index) {
        nk_bf16_t const *query_row = a + row_index * a_stride_elements;
        nk_f32_t *output_row = c + row_index * c_stride_elements;
        nk_f32_t const query_norm_sq = nk_dots_reduce_sumsq_bf16_(query_row, depth);
        nk_euclideans_row_f32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
225
|
+
|
|
226
|
+
NK_PUBLIC void nk_euclideans_packed_bf16_sapphireamx( //
|
|
227
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
228
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
229
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
230
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_bf16_t);
|
|
231
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
232
|
+
nk_dots_packed_bf16_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
233
|
+
nk_euclideans_packed_bf16_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
234
|
+
c_stride_elements);
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
#pragma endregion // BF16 Packed
|
|
238
|
+
|
|
239
|
+
#pragma region BF16 Symmetric
|
|
240
|
+
|
|
241
|
+
NK_INTERNAL void nk_angulars_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t const *vectors, nk_size_t n_vectors,
|
|
242
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
243
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
244
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
245
|
+
|
|
246
|
+
// Phase 1: Cache row norms on diagonal
|
|
247
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
248
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_bf16_(vectors + row * stride_elements, depth);
|
|
249
|
+
|
|
250
|
+
// Phase 2: 256-column chunks with cached norms
|
|
251
|
+
nk_f32_t column_norms_cache[256];
|
|
252
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
253
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
254
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
255
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
|
|
256
|
+
|
|
257
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
258
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
259
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
260
|
+
if (col_start >= chunk_end) continue;
|
|
261
|
+
nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
262
|
+
r_row[row], chunk_end - col_start);
|
|
263
|
+
}
|
|
264
|
+
}
|
|
265
|
+
|
|
266
|
+
// Phase 3: Zero diagonal
|
|
267
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
NK_PUBLIC void nk_angulars_symmetric_bf16_sapphireamx( //
|
|
271
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
272
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
273
|
+
nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
|
|
274
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
275
|
+
nk_dots_symmetric_bf16_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
276
|
+
row_count);
|
|
277
|
+
nk_angulars_symmetric_bf16_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
278
|
+
result_stride_elements, row_start, row_count);
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
NK_INTERNAL void nk_euclideans_symmetric_bf16_sapphireamx_finalize_(nk_bf16_t const *vectors, nk_size_t n_vectors,
|
|
282
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
283
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
284
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
285
|
+
|
|
286
|
+
// Phase 1: Cache row norms on diagonal
|
|
287
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
288
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_bf16_(vectors + row * stride_elements, depth);
|
|
289
|
+
|
|
290
|
+
// Phase 2: 256-column chunks with cached norms
|
|
291
|
+
nk_f32_t column_norms_cache[256];
|
|
292
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
293
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
294
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
295
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_bf16_(vectors + col * stride_elements, depth);
|
|
296
|
+
|
|
297
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
298
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
299
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
300
|
+
if (col_start >= chunk_end) continue;
|
|
301
|
+
nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
302
|
+
r_row[row], chunk_end - col_start);
|
|
303
|
+
}
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
// Phase 3: Zero diagonal
|
|
307
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
NK_PUBLIC void nk_euclideans_symmetric_bf16_sapphireamx( //
|
|
311
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
312
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
313
|
+
nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
|
|
314
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
315
|
+
nk_dots_symmetric_bf16_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
316
|
+
row_count);
|
|
317
|
+
nk_euclideans_symmetric_bf16_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
318
|
+
result_stride_elements, row_start, row_count);
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
#pragma endregion // BF16 Symmetric
|
|
322
|
+
|
|
323
|
+
#pragma region Signed 8-bit Integer Packed
|
|
324
|
+
|
|
325
|
+
/**
 *  In-place finalizer for the packed i8 angular kernel.
 *
 *  For each of the `rows` query rows of `a`, it computes the query's squared norm. It passes that norm, the
 *  per-column u32 norms stored inside the packed B buffer (found via the header's `norms_byte_offset`), and
 *  the matching row of `c` to `nk_angulars_row_i32dots_sapphireamx_`.
 *  `c` is expected to hold the i32 dot products written by `nk_dots_packed_i8_sapphireamx`.
 *  Strides are in elements.
 */
NK_INTERNAL void nk_angulars_packed_i8_sapphireamx_finalize_(nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
                                                             nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                             nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_base = (char const *)b_packed;
    nk_u32_t const *column_norms = (nk_u32_t const *)(packed_base + packed_header->norms_byte_offset);

    for (nk_size_t query_index = 0; query_index != rows; ++query_index) {
        nk_i8_t const *query = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_norm_sq = (nk_f32_t)nk_dots_reduce_sumsq_i8_(query, depth);
        nk_angulars_row_i32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
335
|
+
|
|
336
|
+
NK_PUBLIC void nk_angulars_packed_i8_sapphireamx( //
|
|
337
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
338
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
339
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
340
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
341
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
342
|
+
nk_dots_packed_i8_sapphireamx(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_in_bytes,
|
|
343
|
+
c_stride_in_bytes);
|
|
344
|
+
nk_angulars_packed_i8_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
345
|
+
c_stride_elements);
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
/**
 *  In-place finalizer for the packed i8 Euclidean kernel.
 *
 *  For each of the `rows` query rows of `a`, it computes the query's squared norm. It passes that norm, the
 *  per-column u32 norms stored inside the packed B buffer (found via the header's `norms_byte_offset`), and
 *  the matching row of `c` to `nk_euclideans_row_i32dots_sapphireamx_`.
 *  `c` is expected to hold the i32 dot products written by `nk_dots_packed_i8_sapphireamx`.
 *  Strides are in elements.
 */
NK_INTERNAL void nk_euclideans_packed_i8_sapphireamx_finalize_(nk_i8_t const *a, void const *b_packed, nk_f32_t *c,
                                                               nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                               nk_size_t a_stride_elements,
                                                               nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_base = (char const *)b_packed;
    nk_u32_t const *column_norms = (nk_u32_t const *)(packed_base + packed_header->norms_byte_offset);

    for (nk_size_t query_index = 0; query_index != rows; ++query_index) {
        nk_i8_t const *query = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_norm_sq = (nk_f32_t)nk_dots_reduce_sumsq_i8_(query, depth);
        nk_euclideans_row_i32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
359
|
+
|
|
360
|
+
NK_PUBLIC void nk_euclideans_packed_i8_sapphireamx( //
|
|
361
|
+
nk_i8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
362
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
363
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
364
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_i8_t);
|
|
365
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
366
|
+
nk_dots_packed_i8_sapphireamx(a, b_packed, (nk_i32_t *)c, rows, columns, depth, a_stride_in_bytes,
|
|
367
|
+
c_stride_in_bytes);
|
|
368
|
+
nk_euclideans_packed_i8_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
369
|
+
c_stride_elements);
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
#pragma endregion // Signed 8-bit Integer Packed
|
|
373
|
+
|
|
374
|
+
#pragma region Signed 8-bit Integer Symmetric
|
|
375
|
+
|
|
376
|
+
NK_INTERNAL void nk_angulars_symmetric_i8_sapphireamx_finalize_(nk_i8_t const *vectors, nk_size_t n_vectors,
|
|
377
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
378
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
379
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
380
|
+
|
|
381
|
+
// Phase 1: Cache row norms on diagonal (stored as u32 reinterpreted in f32 slot)
|
|
382
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
383
|
+
((nk_u32_t *)(result + row * result_stride_elements))[row] = nk_dots_reduce_sumsq_i8_(
|
|
384
|
+
vectors + row * stride_elements, depth);
|
|
385
|
+
|
|
386
|
+
// Phase 2: 256-column chunks with cached norms
|
|
387
|
+
nk_u32_t column_norms_cache[256];
|
|
388
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
389
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
390
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
391
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
|
|
392
|
+
|
|
393
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
394
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
395
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
396
|
+
if (col_start >= chunk_end) continue;
|
|
397
|
+
nk_f32_t query_norm_sq_f32 = (nk_f32_t)((nk_u32_t *)r_row)[row];
|
|
398
|
+
nk_angulars_row_i32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
399
|
+
query_norm_sq_f32, chunk_end - col_start);
|
|
400
|
+
}
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
// Phase 3: Zero diagonal
|
|
404
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
NK_PUBLIC void nk_angulars_symmetric_i8_sapphireamx( //
|
|
408
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
409
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
410
|
+
nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
|
|
411
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
412
|
+
nk_dots_symmetric_i8_sapphireamx(vectors, n_vectors, depth, stride, (nk_i32_t *)result, result_stride, row_start,
|
|
413
|
+
row_count);
|
|
414
|
+
nk_angulars_symmetric_i8_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
415
|
+
result_stride_elements, row_start, row_count);
|
|
416
|
+
}
|
|
417
|
+
|
|
418
|
+
NK_INTERNAL void nk_euclideans_symmetric_i8_sapphireamx_finalize_(nk_i8_t const *vectors, nk_size_t n_vectors,
|
|
419
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
420
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
421
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
422
|
+
|
|
423
|
+
// Phase 1: Cache row norms on diagonal (stored as u32 reinterpreted in f32 slot)
|
|
424
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
425
|
+
((nk_u32_t *)(result + row * result_stride_elements))[row] = nk_dots_reduce_sumsq_i8_(
|
|
426
|
+
vectors + row * stride_elements, depth);
|
|
427
|
+
|
|
428
|
+
// Phase 2: 256-column chunks with cached norms
|
|
429
|
+
nk_u32_t column_norms_cache[256];
|
|
430
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
431
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
432
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
433
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_i8_(vectors + col * stride_elements, depth);
|
|
434
|
+
|
|
435
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
436
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
437
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
438
|
+
if (col_start >= chunk_end) continue;
|
|
439
|
+
nk_f32_t query_norm_sq_f32 = (nk_f32_t)((nk_u32_t *)r_row)[row];
|
|
440
|
+
nk_euclideans_row_i32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
441
|
+
query_norm_sq_f32, chunk_end - col_start);
|
|
442
|
+
}
|
|
443
|
+
}
|
|
444
|
+
|
|
445
|
+
// Phase 3: Zero diagonal
|
|
446
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
NK_PUBLIC void nk_euclideans_symmetric_i8_sapphireamx( //
|
|
450
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
451
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
452
|
+
nk_size_t const stride_elements = stride / sizeof(nk_i8_t);
|
|
453
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
454
|
+
nk_dots_symmetric_i8_sapphireamx(vectors, n_vectors, depth, stride, (nk_i32_t *)result, result_stride, row_start,
|
|
455
|
+
row_count);
|
|
456
|
+
nk_euclideans_symmetric_i8_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
457
|
+
result_stride_elements, row_start, row_count);
|
|
458
|
+
}
|
|
459
|
+
|
|
460
|
+
#pragma endregion // Signed 8-bit Integer Symmetric
|
|
461
|
+
|
|
462
|
+
#pragma region Unsigned 8-bit Integer Packed
|
|
463
|
+
|
|
464
|
+
/**
 *  In-place finalizer for the packed u8 angular kernel.
 *
 *  For each of the `rows` query rows of `a`, it computes the query's squared norm. It passes that norm, the
 *  per-column u32 norms stored inside the packed B buffer (found via the header's `norms_byte_offset`), and
 *  the matching row of `c` to `nk_angulars_row_u32dots_sapphireamx_`.
 *  `c` is expected to hold the u32 dot products written by `nk_dots_packed_u8_sapphireamx`.
 *  Strides are in elements.
 */
NK_INTERNAL void nk_angulars_packed_u8_sapphireamx_finalize_(nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
                                                             nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                             nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_base = (char const *)b_packed;
    nk_u32_t const *column_norms = (nk_u32_t const *)(packed_base + packed_header->norms_byte_offset);

    for (nk_size_t query_index = 0; query_index != rows; ++query_index) {
        nk_u8_t const *query = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_norm_sq = (nk_f32_t)nk_dots_reduce_sumsq_u8_(query, depth);
        nk_angulars_row_u32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
474
|
+
|
|
475
|
+
NK_PUBLIC void nk_angulars_packed_u8_sapphireamx( //
|
|
476
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
477
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
478
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
479
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
480
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
481
|
+
nk_dots_packed_u8_sapphireamx(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_in_bytes,
|
|
482
|
+
c_stride_in_bytes);
|
|
483
|
+
nk_angulars_packed_u8_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
484
|
+
c_stride_elements);
|
|
485
|
+
}
|
|
486
|
+
|
|
487
|
+
/**
 *  In-place finalizer for the packed u8 Euclidean kernel.
 *
 *  For each of the `rows` query rows of `a`, it computes the query's squared norm. It passes that norm, the
 *  per-column u32 norms stored inside the packed B buffer (found via the header's `norms_byte_offset`), and
 *  the matching row of `c` to `nk_euclideans_row_u32dots_sapphireamx_`.
 *  `c` is expected to hold the u32 dot products written by `nk_dots_packed_u8_sapphireamx`.
 *  Strides are in elements.
 */
NK_INTERNAL void nk_euclideans_packed_u8_sapphireamx_finalize_(nk_u8_t const *a, void const *b_packed, nk_f32_t *c,
                                                               nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                               nk_size_t a_stride_elements,
                                                               nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_base = (char const *)b_packed;
    nk_u32_t const *column_norms = (nk_u32_t const *)(packed_base + packed_header->norms_byte_offset);

    for (nk_size_t query_index = 0; query_index != rows; ++query_index) {
        nk_u8_t const *query = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_norm_sq = (nk_f32_t)nk_dots_reduce_sumsq_u8_(query, depth);
        nk_euclideans_row_u32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
498
|
+
|
|
499
|
+
NK_PUBLIC void nk_euclideans_packed_u8_sapphireamx( //
|
|
500
|
+
nk_u8_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
501
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
502
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
503
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_u8_t);
|
|
504
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
505
|
+
nk_dots_packed_u8_sapphireamx(a, b_packed, (nk_u32_t *)c, rows, columns, depth, a_stride_in_bytes,
|
|
506
|
+
c_stride_in_bytes);
|
|
507
|
+
nk_euclideans_packed_u8_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
508
|
+
c_stride_elements);
|
|
509
|
+
}
|
|
510
|
+
|
|
511
|
+
#pragma endregion // Unsigned 8-bit Integer Packed
|
|
512
|
+
|
|
513
|
+
#pragma region Unsigned 8-bit Integer Symmetric
|
|
514
|
+
|
|
515
|
+
NK_INTERNAL void nk_angulars_symmetric_u8_sapphireamx_finalize_(nk_u8_t const *vectors, nk_size_t n_vectors,
|
|
516
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
517
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
518
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
519
|
+
|
|
520
|
+
// Phase 1: Cache row norms on diagonal (stored as u32 reinterpreted in f32 slot)
|
|
521
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
522
|
+
((nk_u32_t *)(result + row * result_stride_elements))[row] = nk_dots_reduce_sumsq_u8_(
|
|
523
|
+
vectors + row * stride_elements, depth);
|
|
524
|
+
|
|
525
|
+
// Phase 2: 256-column chunks with cached norms
|
|
526
|
+
nk_u32_t column_norms_cache[256];
|
|
527
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
528
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
529
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
530
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
|
|
531
|
+
|
|
532
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
533
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
534
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
535
|
+
if (col_start >= chunk_end) continue;
|
|
536
|
+
nk_f32_t query_norm_sq_f32 = (nk_f32_t)((nk_u32_t *)r_row)[row];
|
|
537
|
+
nk_angulars_row_u32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
538
|
+
query_norm_sq_f32, chunk_end - col_start);
|
|
539
|
+
}
|
|
540
|
+
}
|
|
541
|
+
|
|
542
|
+
// Phase 3: Zero diagonal
|
|
543
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
NK_PUBLIC void nk_angulars_symmetric_u8_sapphireamx( //
|
|
547
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
548
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
549
|
+
nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
|
|
550
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
551
|
+
nk_dots_symmetric_u8_sapphireamx(vectors, n_vectors, depth, stride, (nk_u32_t *)result, result_stride, row_start,
|
|
552
|
+
row_count);
|
|
553
|
+
nk_angulars_symmetric_u8_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
554
|
+
result_stride_elements, row_start, row_count);
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
NK_INTERNAL void nk_euclideans_symmetric_u8_sapphireamx_finalize_(nk_u8_t const *vectors, nk_size_t n_vectors,
|
|
558
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
559
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
560
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
561
|
+
|
|
562
|
+
// Phase 1: Cache row norms on diagonal (stored as u32 reinterpreted in f32 slot)
|
|
563
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
564
|
+
((nk_u32_t *)(result + row * result_stride_elements))[row] = nk_dots_reduce_sumsq_u8_(
|
|
565
|
+
vectors + row * stride_elements, depth);
|
|
566
|
+
|
|
567
|
+
// Phase 2: 256-column chunks with cached norms
|
|
568
|
+
nk_u32_t column_norms_cache[256];
|
|
569
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
570
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
571
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
572
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_u8_(vectors + col * stride_elements, depth);
|
|
573
|
+
|
|
574
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
575
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
576
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
577
|
+
if (col_start >= chunk_end) continue;
|
|
578
|
+
nk_f32_t query_norm_sq_f32 = (nk_f32_t)((nk_u32_t *)r_row)[row];
|
|
579
|
+
nk_euclideans_row_u32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
580
|
+
query_norm_sq_f32, chunk_end - col_start);
|
|
581
|
+
}
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
// Phase 3: Zero diagonal
|
|
585
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
586
|
+
}
|
|
587
|
+
|
|
588
|
+
NK_PUBLIC void nk_euclideans_symmetric_u8_sapphireamx( //
|
|
589
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
590
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
591
|
+
nk_size_t const stride_elements = stride / sizeof(nk_u8_t);
|
|
592
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
593
|
+
nk_dots_symmetric_u8_sapphireamx(vectors, n_vectors, depth, stride, (nk_u32_t *)result, result_stride, row_start,
|
|
594
|
+
row_count);
|
|
595
|
+
nk_euclideans_symmetric_u8_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
596
|
+
result_stride_elements, row_start, row_count);
|
|
597
|
+
}
|
|
598
|
+
|
|
599
|
+
#pragma endregion // Unsigned 8-bit Integer Symmetric
|
|
600
|
+
|
|
601
|
+
#pragma region E4M3 Packed
|
|
602
|
+
|
|
603
|
+
/**
 *  In-place finalizer for the packed e4m3 angular kernel.
 *
 *  For each of the `rows` query rows of `a`, it computes the query's squared norm. It passes that norm, the
 *  per-column f32 norms stored inside the packed B buffer (found via the header's `norms_byte_offset`), and
 *  the matching row of `c` to `nk_angulars_row_f32dots_sapphireamx_`.
 *  `c` is expected to hold the f32 dot products written by `nk_dots_packed_e4m3_sapphireamx`.
 *  Strides are in elements.
 */
NK_INTERNAL void nk_angulars_packed_e4m3_sapphireamx_finalize_(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                               nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                               nk_size_t a_stride_elements,
                                                               nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_base = (char const *)b_packed;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_base + packed_header->norms_byte_offset);

    for (nk_size_t query_index = 0; query_index != rows; ++query_index) {
        nk_e4m3_t const *query = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_norm_sq = nk_dots_reduce_sumsq_e4m3_(query, depth);
        nk_angulars_row_f32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
614
|
+
|
|
615
|
+
NK_PUBLIC void nk_angulars_packed_e4m3_sapphireamx( //
|
|
616
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
617
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
618
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
619
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
620
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
621
|
+
nk_dots_packed_e4m3_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
622
|
+
nk_angulars_packed_e4m3_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
623
|
+
c_stride_elements);
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
/**
 *  In-place finalizer for the packed e4m3 Euclidean kernel.
 *
 *  For each of the `rows` query rows of `a`, it computes the query's squared norm. It passes that norm, the
 *  per-column f32 norms stored inside the packed B buffer (found via the header's `norms_byte_offset`), and
 *  the matching row of `c` to `nk_euclideans_row_f32dots_sapphireamx_`.
 *  `c` is expected to hold the f32 dot products written by `nk_dots_packed_e4m3_sapphireamx`.
 *  Strides are in elements.
 */
NK_INTERNAL void nk_euclideans_packed_e4m3_sapphireamx_finalize_(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                                 nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                                 nk_size_t a_stride_elements,
                                                                 nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_base = (char const *)b_packed;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_base + packed_header->norms_byte_offset);

    for (nk_size_t query_index = 0; query_index != rows; ++query_index) {
        nk_e4m3_t const *query = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_norm_sq = nk_dots_reduce_sumsq_e4m3_(query, depth);
        nk_euclideans_row_f32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
637
|
+
|
|
638
|
+
NK_PUBLIC void nk_euclideans_packed_e4m3_sapphireamx( //
|
|
639
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
640
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
641
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
642
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e4m3_t);
|
|
643
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
644
|
+
nk_dots_packed_e4m3_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
645
|
+
nk_euclideans_packed_e4m3_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
646
|
+
c_stride_elements);
|
|
647
|
+
}
|
|
648
|
+
|
|
649
|
+
#pragma endregion // E4M3 Packed
|
|
650
|
+
|
|
651
|
+
#pragma region E5M2 Packed
|
|
652
|
+
|
|
653
|
+
/**
 *  In-place finalizer for the packed e5m2 angular kernel.
 *
 *  For each of the `rows` query rows of `a`, it computes the query's squared norm. It passes that norm, the
 *  per-column f32 norms stored inside the packed B buffer (found via the header's `norms_byte_offset`), and
 *  the matching row of `c` to `nk_angulars_row_f32dots_sapphireamx_`.
 *  `c` is expected to hold the f32 dot products written by `nk_dots_packed_e5m2_sapphireamx`.
 *  Strides are in elements.
 */
NK_INTERNAL void nk_angulars_packed_e5m2_sapphireamx_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                               nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                               nk_size_t a_stride_elements,
                                                               nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_base = (char const *)b_packed;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_base + packed_header->norms_byte_offset);

    for (nk_size_t query_index = 0; query_index != rows; ++query_index) {
        nk_e5m2_t const *query = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_norm_sq = nk_dots_reduce_sumsq_e5m2_(query, depth);
        nk_angulars_row_f32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
664
|
+
|
|
665
|
+
NK_PUBLIC void nk_angulars_packed_e5m2_sapphireamx( //
|
|
666
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
667
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
668
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
669
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
670
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
671
|
+
nk_dots_packed_e5m2_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
672
|
+
nk_angulars_packed_e5m2_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
673
|
+
c_stride_elements);
|
|
674
|
+
}
|
|
675
|
+
|
|
676
|
+
/**
 *  In-place finalizer for the packed e5m2 Euclidean kernel.
 *
 *  For each of the `rows` query rows of `a`, it computes the query's squared norm. It passes that norm, the
 *  per-column f32 norms stored inside the packed B buffer (found via the header's `norms_byte_offset`), and
 *  the matching row of `c` to `nk_euclideans_row_f32dots_sapphireamx_`.
 *  `c` is expected to hold the f32 dot products written by `nk_dots_packed_e5m2_sapphireamx`.
 *  Strides are in elements.
 */
NK_INTERNAL void nk_euclideans_packed_e5m2_sapphireamx_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                                 nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                                 nk_size_t a_stride_elements,
                                                                 nk_size_t c_stride_elements) {
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)b_packed;
    char const *packed_base = (char const *)b_packed;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_base + packed_header->norms_byte_offset);

    for (nk_size_t query_index = 0; query_index != rows; ++query_index) {
        nk_e5m2_t const *query = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_norm_sq = nk_dots_reduce_sumsq_e5m2_(query, depth);
        nk_euclideans_row_f32dots_sapphireamx_(output_row, column_norms, query_norm_sq, columns);
    }
}
|
|
687
|
+
|
|
688
|
+
NK_PUBLIC void nk_euclideans_packed_e5m2_sapphireamx( //
|
|
689
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
690
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
691
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
692
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e5m2_t);
|
|
693
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
694
|
+
nk_dots_packed_e5m2_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
695
|
+
nk_euclideans_packed_e5m2_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
696
|
+
c_stride_elements);
|
|
697
|
+
}
|
|
698
|
+
|
|
699
|
+
#pragma endregion // E5M2 Packed
|
|
700
|
+
|
|
701
|
+
#pragma region E5M2 Symmetric
|
|
702
|
+
|
|
703
|
+
NK_INTERNAL void nk_angulars_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t const *vectors, nk_size_t n_vectors,
|
|
704
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
705
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
706
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
707
|
+
|
|
708
|
+
// Phase 1: Cache row norms on diagonal
|
|
709
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
710
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e5m2_(vectors + row * stride_elements, depth);
|
|
711
|
+
|
|
712
|
+
// Phase 2: 256-column chunks with cached norms
|
|
713
|
+
nk_f32_t column_norms_cache[256];
|
|
714
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
715
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
716
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
717
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
|
|
718
|
+
|
|
719
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
720
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
721
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
722
|
+
if (col_start >= chunk_end) continue;
|
|
723
|
+
nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
724
|
+
r_row[row], chunk_end - col_start);
|
|
725
|
+
}
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
// Phase 3: Zero diagonal
|
|
729
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
730
|
+
}
|
|
731
|
+
|
|
732
|
+
NK_PUBLIC void nk_angulars_symmetric_e5m2_sapphireamx( //
|
|
733
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
734
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
735
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
|
|
736
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
737
|
+
nk_dots_symmetric_e5m2_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
738
|
+
row_count);
|
|
739
|
+
nk_angulars_symmetric_e5m2_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
740
|
+
result_stride_elements, row_start, row_count);
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
NK_INTERNAL void nk_euclideans_symmetric_e5m2_sapphireamx_finalize_(nk_e5m2_t const *vectors, nk_size_t n_vectors,
|
|
744
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
745
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
746
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
747
|
+
|
|
748
|
+
// Phase 1: Cache row norms on diagonal
|
|
749
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
750
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e5m2_(vectors + row * stride_elements, depth);
|
|
751
|
+
|
|
752
|
+
// Phase 2: 256-column chunks with cached norms
|
|
753
|
+
nk_f32_t column_norms_cache[256];
|
|
754
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
755
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
756
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
757
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
|
|
758
|
+
|
|
759
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
760
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
761
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
762
|
+
if (col_start >= chunk_end) continue;
|
|
763
|
+
nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
764
|
+
r_row[row], chunk_end - col_start);
|
|
765
|
+
}
|
|
766
|
+
}
|
|
767
|
+
|
|
768
|
+
// Phase 3: Zero diagonal
|
|
769
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
770
|
+
}
|
|
771
|
+
|
|
772
|
+
NK_PUBLIC void nk_euclideans_symmetric_e5m2_sapphireamx( //
|
|
773
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
774
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
775
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e5m2_t);
|
|
776
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
777
|
+
nk_dots_symmetric_e5m2_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
778
|
+
row_count);
|
|
779
|
+
nk_euclideans_symmetric_e5m2_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
780
|
+
result_stride_elements, row_start, row_count);
|
|
781
|
+
}
|
|
782
|
+
|
|
783
|
+
#pragma endregion // E5M2 Symmetric
|
|
784
|
+
|
|
785
|
+
#pragma region E4M3 Symmetric
|
|
786
|
+
|
|
787
|
+
NK_INTERNAL void nk_angulars_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t const *vectors, nk_size_t n_vectors,
|
|
788
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
789
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
790
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
791
|
+
|
|
792
|
+
// Phase 1: Cache row norms on diagonal
|
|
793
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
794
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e4m3_(vectors + row * stride_elements, depth);
|
|
795
|
+
|
|
796
|
+
// Phase 2: 256-column chunks with cached norms
|
|
797
|
+
nk_f32_t column_norms_cache[256];
|
|
798
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
799
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
800
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
801
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
|
|
802
|
+
|
|
803
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
804
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
805
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
806
|
+
if (col_start >= chunk_end) continue;
|
|
807
|
+
nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
808
|
+
r_row[row], chunk_end - col_start);
|
|
809
|
+
}
|
|
810
|
+
}
|
|
811
|
+
|
|
812
|
+
// Phase 3: Zero diagonal
|
|
813
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
814
|
+
}
|
|
815
|
+
|
|
816
|
+
NK_PUBLIC void nk_angulars_symmetric_e4m3_sapphireamx( //
|
|
817
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
818
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
819
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
|
|
820
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
821
|
+
nk_dots_symmetric_e4m3_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
822
|
+
row_count);
|
|
823
|
+
nk_angulars_symmetric_e4m3_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
824
|
+
result_stride_elements, row_start, row_count);
|
|
825
|
+
}
|
|
826
|
+
|
|
827
|
+
NK_INTERNAL void nk_euclideans_symmetric_e4m3_sapphireamx_finalize_(nk_e4m3_t const *vectors, nk_size_t n_vectors,
|
|
828
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
829
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
830
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
831
|
+
|
|
832
|
+
// Phase 1: Cache row norms on diagonal
|
|
833
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
834
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e4m3_(vectors + row * stride_elements, depth);
|
|
835
|
+
|
|
836
|
+
// Phase 2: 256-column chunks with cached norms
|
|
837
|
+
nk_f32_t column_norms_cache[256];
|
|
838
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
839
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
840
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
841
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e4m3_(vectors + col * stride_elements, depth);
|
|
842
|
+
|
|
843
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
844
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
845
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
846
|
+
if (col_start >= chunk_end) continue;
|
|
847
|
+
nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
848
|
+
r_row[row], chunk_end - col_start);
|
|
849
|
+
}
|
|
850
|
+
}
|
|
851
|
+
|
|
852
|
+
// Phase 3: Zero diagonal
|
|
853
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
854
|
+
}
|
|
855
|
+
|
|
856
|
+
NK_PUBLIC void nk_euclideans_symmetric_e4m3_sapphireamx( //
|
|
857
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
858
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
859
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e4m3_t);
|
|
860
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
861
|
+
nk_dots_symmetric_e4m3_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
862
|
+
row_count);
|
|
863
|
+
nk_euclideans_symmetric_e4m3_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
864
|
+
result_stride_elements, row_start, row_count);
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
#pragma endregion // E4M3 Symmetric
|
|
868
|
+
|
|
869
|
+
#pragma region E2M3 Packed
|
|
870
|
+
|
|
871
|
+
/**
 *  @brief In-place finalizer for `nk_angulars_packed_e2m3_sapphireamx`.
 *
 *  For every row of `c` (dot products on entry), it computes the matching query row's sum of squares with
 *  `nk_dots_reduce_sumsq_e2m3_`. It then passes that value, the row's `columns` dot products and the column norms
 *  stored inside `b_packed` to `nk_angulars_row_f32dots_sapphireamx_`.
 *
 *  @param a                  Query matrix, `rows` vectors of `depth` elements.
 *  @param b_packed           Packed buffer; the column norms sit `header->norms_byte_offset` bytes past its start.
 *  @param c                  Output matrix, `rows` x `columns`.
 *  @param a_stride_elements  Distance between consecutive rows of `a`, in `nk_e2m3_t` elements.
 *  @param c_stride_elements  Distance between consecutive rows of `c`, in `nk_f32_t` elements.
 */
NK_INTERNAL void nk_angulars_packed_e2m3_sapphireamx_finalize_(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                              nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                              nk_size_t a_stride_elements,
                                                              nk_size_t c_stride_elements) {
    char const *packed_bytes = (char const *)b_packed;
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)packed_bytes;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_bytes + packed_header->norms_byte_offset);

    nk_size_t query_index = 0;
    while (query_index < rows) {
        nk_e2m3_t const *query_row = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_sumsq = nk_dots_reduce_sumsq_e2m3_(query_row, depth);
        nk_angulars_row_f32dots_sapphireamx_(output_row, column_norms, query_sumsq, columns);
        ++query_index;
    }
}
|
|
882
|
+
|
|
883
|
+
NK_PUBLIC void nk_angulars_packed_e2m3_sapphireamx( //
|
|
884
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
885
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
886
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
887
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
888
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
889
|
+
nk_dots_packed_e2m3_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
890
|
+
nk_angulars_packed_e2m3_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
891
|
+
c_stride_elements);
|
|
892
|
+
}
|
|
893
|
+
|
|
894
|
+
/**
 *  @brief In-place finalizer for `nk_euclideans_packed_e2m3_sapphireamx`.
 *
 *  For every row of `c` (dot products on entry), it computes the matching query row's sum of squares with
 *  `nk_dots_reduce_sumsq_e2m3_`. It then passes that value, the row's `columns` dot products and the column norms
 *  stored inside `b_packed` to `nk_euclideans_row_f32dots_sapphireamx_`.
 *
 *  @param a                  Query matrix, `rows` vectors of `depth` elements.
 *  @param b_packed           Packed buffer; the column norms sit `header->norms_byte_offset` bytes past its start.
 *  @param c                  Output matrix, `rows` x `columns`.
 *  @param a_stride_elements  Distance between consecutive rows of `a`, in `nk_e2m3_t` elements.
 *  @param c_stride_elements  Distance between consecutive rows of `c`, in `nk_f32_t` elements.
 */
NK_INTERNAL void nk_euclideans_packed_e2m3_sapphireamx_finalize_(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
                                                                nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                                nk_size_t a_stride_elements,
                                                                nk_size_t c_stride_elements) {
    char const *packed_bytes = (char const *)b_packed;
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)packed_bytes;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_bytes + packed_header->norms_byte_offset);

    nk_size_t query_index = 0;
    while (query_index < rows) {
        nk_e2m3_t const *query_row = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_sumsq = nk_dots_reduce_sumsq_e2m3_(query_row, depth);
        nk_euclideans_row_f32dots_sapphireamx_(output_row, column_norms, query_sumsq, columns);
        ++query_index;
    }
}
|
|
905
|
+
|
|
906
|
+
NK_PUBLIC void nk_euclideans_packed_e2m3_sapphireamx( //
|
|
907
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
908
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
909
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
910
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e2m3_t);
|
|
911
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
912
|
+
nk_dots_packed_e2m3_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
913
|
+
nk_euclideans_packed_e2m3_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
914
|
+
c_stride_elements);
|
|
915
|
+
}
|
|
916
|
+
|
|
917
|
+
#pragma endregion // E2M3 Packed
|
|
918
|
+
|
|
919
|
+
#pragma region E2M3 Symmetric
|
|
920
|
+
|
|
921
|
+
NK_INTERNAL void nk_angulars_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t const *vectors, nk_size_t n_vectors,
|
|
922
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
923
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
924
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
925
|
+
|
|
926
|
+
// Phase 1: Cache row norms on diagonal
|
|
927
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
928
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e2m3_(vectors + row * stride_elements, depth);
|
|
929
|
+
|
|
930
|
+
// Phase 2: 256-column chunks with cached norms
|
|
931
|
+
nk_f32_t column_norms_cache[256];
|
|
932
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
933
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
934
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
935
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
|
|
936
|
+
|
|
937
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
938
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
939
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
940
|
+
if (col_start >= chunk_end) continue;
|
|
941
|
+
nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
942
|
+
r_row[row], chunk_end - col_start);
|
|
943
|
+
}
|
|
944
|
+
}
|
|
945
|
+
|
|
946
|
+
// Phase 3: Zero diagonal
|
|
947
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
948
|
+
}
|
|
949
|
+
|
|
950
|
+
NK_PUBLIC void nk_angulars_symmetric_e2m3_sapphireamx( //
|
|
951
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
952
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
953
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
|
|
954
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
955
|
+
nk_dots_symmetric_e2m3_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
956
|
+
row_count);
|
|
957
|
+
nk_angulars_symmetric_e2m3_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
958
|
+
result_stride_elements, row_start, row_count);
|
|
959
|
+
}
|
|
960
|
+
|
|
961
|
+
NK_INTERNAL void nk_euclideans_symmetric_e2m3_sapphireamx_finalize_(nk_e2m3_t const *vectors, nk_size_t n_vectors,
|
|
962
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
963
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
964
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
965
|
+
|
|
966
|
+
// Phase 1: Cache row norms on diagonal
|
|
967
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
968
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e2m3_(vectors + row * stride_elements, depth);
|
|
969
|
+
|
|
970
|
+
// Phase 2: 256-column chunks with cached norms
|
|
971
|
+
nk_f32_t column_norms_cache[256];
|
|
972
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
973
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
974
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
975
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e2m3_(vectors + col * stride_elements, depth);
|
|
976
|
+
|
|
977
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
978
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
979
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
980
|
+
if (col_start >= chunk_end) continue;
|
|
981
|
+
nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
982
|
+
r_row[row], chunk_end - col_start);
|
|
983
|
+
}
|
|
984
|
+
}
|
|
985
|
+
|
|
986
|
+
// Phase 3: Zero diagonal
|
|
987
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
988
|
+
}
|
|
989
|
+
|
|
990
|
+
NK_PUBLIC void nk_euclideans_symmetric_e2m3_sapphireamx( //
|
|
991
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
992
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
993
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e2m3_t);
|
|
994
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
995
|
+
nk_dots_symmetric_e2m3_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
996
|
+
row_count);
|
|
997
|
+
nk_euclideans_symmetric_e2m3_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
998
|
+
result_stride_elements, row_start, row_count);
|
|
999
|
+
}
|
|
1000
|
+
|
|
1001
|
+
#pragma endregion // E2M3 Symmetric
|
|
1002
|
+
|
|
1003
|
+
#pragma region E3M2 Packed
|
|
1004
|
+
|
|
1005
|
+
/**
 *  @brief In-place finalizer for `nk_angulars_packed_e3m2_sapphireamx`.
 *
 *  For every row of `c` (dot products on entry), it computes the matching query row's sum of squares with
 *  `nk_dots_reduce_sumsq_e3m2_`. It then passes that value, the row's `columns` dot products and the column norms
 *  stored inside `b_packed` to `nk_angulars_row_f32dots_sapphireamx_`.
 *
 *  @param a                  Query matrix, `rows` vectors of `depth` elements.
 *  @param b_packed           Packed buffer; the column norms sit `header->norms_byte_offset` bytes past its start.
 *  @param c                  Output matrix, `rows` x `columns`.
 *  @param a_stride_elements  Distance between consecutive rows of `a`, in `nk_e3m2_t` elements.
 *  @param c_stride_elements  Distance between consecutive rows of `c`, in `nk_f32_t` elements.
 */
NK_INTERNAL void nk_angulars_packed_e3m2_sapphireamx_finalize_(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                              nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                              nk_size_t a_stride_elements,
                                                              nk_size_t c_stride_elements) {
    char const *packed_bytes = (char const *)b_packed;
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)packed_bytes;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_bytes + packed_header->norms_byte_offset);

    nk_size_t query_index = 0;
    while (query_index < rows) {
        nk_e3m2_t const *query_row = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_sumsq = nk_dots_reduce_sumsq_e3m2_(query_row, depth);
        nk_angulars_row_f32dots_sapphireamx_(output_row, column_norms, query_sumsq, columns);
        ++query_index;
    }
}
|
|
1016
|
+
|
|
1017
|
+
NK_PUBLIC void nk_angulars_packed_e3m2_sapphireamx( //
|
|
1018
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1019
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1020
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1021
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1022
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1023
|
+
nk_dots_packed_e3m2_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1024
|
+
nk_angulars_packed_e3m2_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1025
|
+
c_stride_elements);
|
|
1026
|
+
}
|
|
1027
|
+
|
|
1028
|
+
/**
 *  @brief In-place finalizer for `nk_euclideans_packed_e3m2_sapphireamx`.
 *
 *  For every row of `c` (dot products on entry), it computes the matching query row's sum of squares with
 *  `nk_dots_reduce_sumsq_e3m2_`. It then passes that value, the row's `columns` dot products and the column norms
 *  stored inside `b_packed` to `nk_euclideans_row_f32dots_sapphireamx_`.
 *
 *  @param a                  Query matrix, `rows` vectors of `depth` elements.
 *  @param b_packed           Packed buffer; the column norms sit `header->norms_byte_offset` bytes past its start.
 *  @param c                  Output matrix, `rows` x `columns`.
 *  @param a_stride_elements  Distance between consecutive rows of `a`, in `nk_e3m2_t` elements.
 *  @param c_stride_elements  Distance between consecutive rows of `c`, in `nk_f32_t` elements.
 */
NK_INTERNAL void nk_euclideans_packed_e3m2_sapphireamx_finalize_(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
                                                                nk_size_t rows, nk_size_t columns, nk_size_t depth,
                                                                nk_size_t a_stride_elements,
                                                                nk_size_t c_stride_elements) {
    char const *packed_bytes = (char const *)b_packed;
    nk_dots_amx_packed_header_t const *packed_header = (nk_dots_amx_packed_header_t const *)packed_bytes;
    nk_f32_t const *column_norms = (nk_f32_t const *)(packed_bytes + packed_header->norms_byte_offset);

    nk_size_t query_index = 0;
    while (query_index < rows) {
        nk_e3m2_t const *query_row = a + query_index * a_stride_elements;
        nk_f32_t *output_row = c + query_index * c_stride_elements;
        nk_f32_t const query_sumsq = nk_dots_reduce_sumsq_e3m2_(query_row, depth);
        nk_euclideans_row_f32dots_sapphireamx_(output_row, column_norms, query_sumsq, columns);
        ++query_index;
    }
}
|
|
1039
|
+
|
|
1040
|
+
NK_PUBLIC void nk_euclideans_packed_e3m2_sapphireamx( //
|
|
1041
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
1042
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
1043
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1044
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_e3m2_t);
|
|
1045
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
1046
|
+
nk_dots_packed_e3m2_sapphireamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
1047
|
+
nk_euclideans_packed_e3m2_sapphireamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
1048
|
+
c_stride_elements);
|
|
1049
|
+
}
|
|
1050
|
+
|
|
1051
|
+
#pragma endregion // E3M2 Packed
|
|
1052
|
+
|
|
1053
|
+
#pragma region E3M2 Symmetric
|
|
1054
|
+
|
|
1055
|
+
NK_INTERNAL void nk_angulars_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t const *vectors, nk_size_t n_vectors,
|
|
1056
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
1057
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
1058
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1059
|
+
|
|
1060
|
+
// Phase 1: Cache row norms on diagonal
|
|
1061
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
1062
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e3m2_(vectors + row * stride_elements, depth);
|
|
1063
|
+
|
|
1064
|
+
// Phase 2: 256-column chunks with cached norms
|
|
1065
|
+
nk_f32_t column_norms_cache[256];
|
|
1066
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1067
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1068
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
1069
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
|
|
1070
|
+
|
|
1071
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
1072
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
1073
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
1074
|
+
if (col_start >= chunk_end) continue;
|
|
1075
|
+
nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
1076
|
+
r_row[row], chunk_end - col_start);
|
|
1077
|
+
}
|
|
1078
|
+
}
|
|
1079
|
+
|
|
1080
|
+
// Phase 3: Zero diagonal
|
|
1081
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
1082
|
+
}
|
|
1083
|
+
|
|
1084
|
+
NK_PUBLIC void nk_angulars_symmetric_e3m2_sapphireamx( //
|
|
1085
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1086
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1087
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
|
|
1088
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1089
|
+
nk_dots_symmetric_e3m2_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
1090
|
+
row_count);
|
|
1091
|
+
nk_angulars_symmetric_e3m2_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
1092
|
+
result_stride_elements, row_start, row_count);
|
|
1093
|
+
}
|
|
1094
|
+
|
|
1095
|
+
NK_INTERNAL void nk_euclideans_symmetric_e3m2_sapphireamx_finalize_(nk_e3m2_t const *vectors, nk_size_t n_vectors,
|
|
1096
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
1097
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
1098
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1099
|
+
|
|
1100
|
+
// Phase 1: Cache row norms on diagonal
|
|
1101
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
1102
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e3m2_(vectors + row * stride_elements, depth);
|
|
1103
|
+
|
|
1104
|
+
// Phase 2: 256-column chunks with cached norms
|
|
1105
|
+
nk_f32_t column_norms_cache[256];
|
|
1106
|
+
for (nk_size_t chunk_start = 0; chunk_start < n_vectors; chunk_start += 256) {
|
|
1107
|
+
nk_size_t chunk_end = chunk_start + 256 < n_vectors ? chunk_start + 256 : n_vectors;
|
|
1108
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
1109
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e3m2_(vectors + col * stride_elements, depth);
|
|
1110
|
+
|
|
1111
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
1112
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
1113
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
1114
|
+
if (col_start >= chunk_end) continue;
|
|
1115
|
+
nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
1116
|
+
r_row[row], chunk_end - col_start);
|
|
1117
|
+
}
|
|
1118
|
+
}
|
|
1119
|
+
|
|
1120
|
+
// Phase 3: Zero diagonal
|
|
1121
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
1122
|
+
}
|
|
1123
|
+
|
|
1124
|
+
NK_PUBLIC void nk_euclideans_symmetric_e3m2_sapphireamx( //
|
|
1125
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, //
|
|
1126
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) {
|
|
1127
|
+
nk_size_t const stride_elements = stride / sizeof(nk_e3m2_t);
|
|
1128
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1129
|
+
nk_dots_symmetric_e3m2_sapphireamx(vectors, n_vectors, depth, stride, (nk_f32_t *)result, result_stride, row_start,
|
|
1130
|
+
row_count);
|
|
1131
|
+
nk_euclideans_symmetric_e3m2_sapphireamx_finalize_(vectors, n_vectors, depth, stride_elements, result,
|
|
1132
|
+
result_stride_elements, row_start, row_count);
|
|
1133
|
+
}
|
|
1134
|
+
|
|
1135
|
+
#pragma endregion // E3M2 Symmetric
|
|
1136
|
+
|
|
1137
|
+
#if defined(__clang__)
|
|
1138
|
+
#pragma clang attribute pop
|
|
1139
|
+
#elif defined(__GNUC__)
|
|
1140
|
+
#pragma GCC pop_options
|
|
1141
|
+
#endif
|
|
1142
|
+
|
|
1143
|
+
#if defined(__cplusplus)
|
|
1144
|
+
} // extern "C"
|
|
1145
|
+
#endif
|
|
1146
|
+
|
|
1147
|
+
#endif // NK_TARGET_SAPPHIREAMX
|
|
1148
|
+
#endif // NK_TARGET_X86_
|
|
1149
|
+
#endif // NK_SPATIALS_SAPPHIREAMX_H
|