numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -70,7 +70,7 @@ extern "C" {
|
|
|
70
70
|
"avx512fp16", "f16c", "fma", "bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8")
|
|
71
71
|
#endif
|
|
72
72
|
|
|
73
|
-
#pragma region
|
|
73
|
+
#pragma region I8 Header
|
|
74
74
|
|
|
75
75
|
/**
|
|
76
76
|
* i8 packed buffer header for AMX coarse+refine MaxSim (64 bytes).
|
|
@@ -92,9 +92,9 @@ typedef struct {
|
|
|
92
92
|
|
|
93
93
|
NK_STATIC_ASSERT(sizeof(nk_maxsim_sapphireamx_i8_header_t) == 64, nk_maxsim_sapphireamx_i8_header_must_be_64_bytes);
|
|
94
94
|
|
|
95
|
-
#pragma endregion
|
|
95
|
+
#pragma endregion I8 Header
|
|
96
96
|
|
|
97
|
-
#pragma region
|
|
97
|
+
#pragma region F32 Floats
|
|
98
98
|
|
|
99
99
|
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
|
|
100
100
|
nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
|
|
@@ -108,7 +108,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_sapphireamx(nk_size_t vector_count
|
|
|
108
108
|
}
|
|
109
109
|
|
|
110
110
|
NK_PUBLIC void nk_maxsim_pack_f32_sapphireamx( //
|
|
111
|
-
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
111
|
+
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
112
112
|
|
|
113
113
|
nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
|
|
114
114
|
nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
|
|
@@ -147,7 +147,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_sapphireamx( //
|
|
|
147
147
|
|
|
148
148
|
// Quantize vectors and scatter into A-side tiles, copy originals, compute inverse norms
|
|
149
149
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
150
|
-
nk_f32_t const *source_vector = (nk_f32_t const *)((char const *)vectors + vector_index *
|
|
150
|
+
nk_f32_t const *source_vector = (nk_f32_t const *)((char const *)vectors + vector_index * stride_in_bytes);
|
|
151
151
|
|
|
152
152
|
// Pass 1: find absmax and norm_squared
|
|
153
153
|
nk_f32_t absmax_f32 = 0.0f;
|
|
@@ -347,9 +347,9 @@ NK_PUBLIC void nk_maxsim_packed_f32_sapphireamx( //
|
|
|
347
347
|
*result = total_angular_distance_f64;
|
|
348
348
|
}
|
|
349
349
|
|
|
350
|
-
#pragma endregion
|
|
350
|
+
#pragma endregion F32 Floats
|
|
351
351
|
|
|
352
|
-
#pragma region
|
|
352
|
+
#pragma region F16 Floats
|
|
353
353
|
|
|
354
354
|
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sapphireamx(nk_size_t vector_count, nk_size_t depth) {
|
|
355
355
|
nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
|
|
@@ -363,7 +363,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_sapphireamx(nk_size_t vector_count
|
|
|
363
363
|
}
|
|
364
364
|
|
|
365
365
|
NK_PUBLIC void nk_maxsim_pack_f16_sapphireamx( //
|
|
366
|
-
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
366
|
+
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
367
367
|
|
|
368
368
|
nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
|
|
369
369
|
nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 64);
|
|
@@ -401,7 +401,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_sapphireamx( //
|
|
|
401
401
|
}
|
|
402
402
|
|
|
403
403
|
// Quantize vectors and scatter into A-side tiles, copy originals, compute inverse norms
|
|
404
|
-
nk_size_t const stride_elements =
|
|
404
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
405
405
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
406
406
|
nk_f16_t const *source_vector = vectors + vector_index * stride_elements;
|
|
407
407
|
|
|
@@ -602,9 +602,9 @@ NK_PUBLIC void nk_maxsim_packed_f16_sapphireamx( //
|
|
|
602
602
|
*result = (nk_f32_t)total_angular_distance_f64;
|
|
603
603
|
}
|
|
604
604
|
|
|
605
|
-
#pragma endregion
|
|
605
|
+
#pragma endregion F16 Floats
|
|
606
606
|
|
|
607
|
-
#pragma region
|
|
607
|
+
#pragma region BF16 Floats
|
|
608
608
|
|
|
609
609
|
/**
|
|
610
610
|
* BF16 packed buffer header for AMX fused MaxSim (64 bytes).
|
|
@@ -635,10 +635,10 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sapphireamx(nk_size_t vector_coun
|
|
|
635
635
|
}
|
|
636
636
|
|
|
637
637
|
NK_PUBLIC void nk_maxsim_pack_bf16_sapphireamx( //
|
|
638
|
-
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
638
|
+
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
639
639
|
|
|
640
640
|
nk_size_t const tile_bytes = 1024;
|
|
641
|
-
nk_size_t const stride_elements =
|
|
641
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_bf16_t);
|
|
642
642
|
nk_size_t column_tile_count = nk_size_divide_round_up_(vector_count, 16);
|
|
643
643
|
nk_size_t depth_tile_count = nk_size_divide_round_up_(depth, 32);
|
|
644
644
|
|
|
@@ -860,7 +860,7 @@ NK_PUBLIC void nk_maxsim_packed_bf16_sapphireamx( //
|
|
|
860
860
|
*result = (nk_f32_t)total_angular_distance_f64;
|
|
861
861
|
}
|
|
862
862
|
|
|
863
|
-
#pragma endregion
|
|
863
|
+
#pragma endregion BF16 Floats
|
|
864
864
|
|
|
865
865
|
#if defined(__clang__)
|
|
866
866
|
#pragma clang attribute pop
|
|
@@ -234,7 +234,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_serial(nk_size_t vector_count, nk_
|
|
|
234
234
|
}
|
|
235
235
|
|
|
236
236
|
NK_PUBLIC void nk_maxsim_pack_bf16_serial( //
|
|
237
|
-
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
237
|
+
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
238
238
|
|
|
239
239
|
nk_size_t const element_bytes = sizeof(nk_bf16_t);
|
|
240
240
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
|
|
@@ -246,7 +246,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_serial( //
|
|
|
246
246
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
247
247
|
|
|
248
248
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
249
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
249
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
250
250
|
nk_f32_t norm_sq;
|
|
251
251
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
252
252
|
(nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
|
|
@@ -260,7 +260,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_serial( //
|
|
|
260
260
|
}
|
|
261
261
|
|
|
262
262
|
NK_PUBLIC void nk_maxsim_pack_f32_serial( //
|
|
263
|
-
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
263
|
+
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
264
264
|
|
|
265
265
|
nk_size_t const element_bytes = sizeof(nk_f32_t);
|
|
266
266
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
|
|
@@ -272,7 +272,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_serial( //
|
|
|
272
272
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
273
273
|
|
|
274
274
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
275
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
275
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
276
276
|
nk_f32_t norm_sq;
|
|
277
277
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
|
|
278
278
|
&quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
|
|
@@ -289,7 +289,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_serial(nk_size_t vector_count, nk_
|
|
|
289
289
|
}
|
|
290
290
|
|
|
291
291
|
NK_PUBLIC void nk_maxsim_pack_f16_serial( //
|
|
292
|
-
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
292
|
+
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
293
293
|
|
|
294
294
|
nk_size_t const element_bytes = sizeof(nk_f16_t);
|
|
295
295
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
|
|
@@ -301,7 +301,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_serial( //
|
|
|
301
301
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
302
302
|
|
|
303
303
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
304
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
304
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
305
305
|
nk_f32_t norm_sq;
|
|
306
306
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
307
307
|
(nk_maxsim_to_f32_t)nk_f16_to_f32_serial,
|