numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,490 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SWAR-accelerated MaxSim (ColBERT late-interaction) for SIMD-free CPUs.
|
|
3
|
+
* @file include/numkong/maxsim/serial.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 17, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/maxsim.h
|
|
8
|
+
*
|
|
9
|
+
* Defines the packed buffer header and per-vector metadata structures used by all MaxSim ISA backends,
|
|
10
|
+
* plus scalar reference implementations for correctness validation.
|
|
11
|
+
*
|
|
12
|
+
* MaxSim computes: result = Σᵢ minⱼ angular(qᵢ, dⱼ) — angular distance late-interaction scoring.
|
|
13
|
+
*
|
|
14
|
+
* Strategy: coarse i8-quantized screening with running argmax (dot as proxy for argmin angular),
|
|
15
|
+
* then full-precision refinement of the winning (query, document) pairs via existing nk_dot_* primitives,
|
|
16
|
+
* finalized with angular distance: 1 - dot / sqrt(||q||² × ||d||²).
|
|
17
|
+
*
|
|
18
|
+
* @section packed_layout Packed Buffer Layout
|
|
19
|
+
*
|
|
20
|
+
* [Header 64B] [i8 vectors, 64B-aligned] [metadata, 64B-aligned] [originals row-major, 64B-aligned]
|
|
21
|
+
*
|
|
22
|
+
* - i8 region: row-major with padded depth for SIMD alignment
|
|
23
|
+
* - Metadata region: vector_count x 12 bytes (scale + i8 sum + inverse norm per vector)
|
|
24
|
+
* - Originals region: row-major bf16, f16, or f32 copies, each row zero-padded to a 64B-multiple stride for nk_dot_* calls
|
|
25
|
+
*/
|
|
26
|
+
#ifndef NK_MAXSIM_SERIAL_H
|
|
27
|
+
#define NK_MAXSIM_SERIAL_H
|
|
28
|
+
|
|
29
|
+
#include "numkong/types.h"
|
|
30
|
+
#include "numkong/cast/serial.h" // `nk_bf16_to_f32_serial`
|
|
31
|
+
#include "numkong/dot.h" // `nk_dot_bf16`, `nk_dot_f32`, `nk_dot_f16`
|
|
32
|
+
#include "numkong/spatial/serial.h" // `nk_f32_rsqrt_serial`
|
|
33
|
+
|
|
34
|
+
#if defined(__cplusplus)
|
|
35
|
+
extern "C" {
|
|
36
|
+
#endif
|
|
37
|
+
|
|
38
|
+
/**
|
|
39
|
+
* @brief Packed buffer header (64 bytes, cache-line aligned).
|
|
40
|
+
* Stored at the beginning of every maxsim packed buffer.
|
|
41
|
+
*/
|
|
42
|
+
typedef struct {
|
|
43
|
+
nk_u32_t vector_count; ///< Number of vectors packed
|
|
44
|
+
nk_u32_t depth_dimensions; ///< Logical depth (number of elements per vector)
|
|
45
|
+
nk_u32_t depth_i8_padded; ///< Padded i8 depth in bytes (SIMD-aligned)
|
|
46
|
+
nk_u32_t original_element_bytes; ///< 2 for bf16, 4 for f32
|
|
47
|
+
nk_u32_t offset_i8_data; ///< Byte offset from buffer start to i8 region
|
|
48
|
+
nk_u32_t offset_metadata; ///< Byte offset from buffer start to metadata region
|
|
49
|
+
nk_u32_t offset_original_data; ///< Byte offset from buffer start to originals region
|
|
50
|
+
nk_u32_t original_stride_bytes; ///< Row stride in bytes for originals region
|
|
51
|
+
nk_u32_t reserved[8]; ///< Padding to 64 bytes
|
|
52
|
+
} nk_maxsim_packed_header_t;
|
|
53
|
+
|
|
54
|
+
NK_STATIC_ASSERT(sizeof(nk_maxsim_packed_header_t) == 64, nk_maxsim_packed_header_must_be_64_bytes);
|
|
55
|
+
|
|
56
|
+
/**
|
|
57
|
+
* @brief Per-vector quantization metadata (12 bytes).
|
|
58
|
+
* Stored in the metadata region of the packed buffer, one per vector.
|
|
59
|
+
*/
|
|
60
|
+
typedef struct {
|
|
61
|
+
nk_f32_t scale_f32; ///< Quantization scale: absmax / range_limit
|
|
62
|
+
nk_i32_t sum_i8_i32; ///< Sum of all i8 quantized elements (for VPDPBUSD/VPMADDUBSW bias correction)
|
|
63
|
+
nk_f32_t inverse_norm_f32; ///< 1/sqrt(||v||^2), 0 if zero-vector — precomputed for angular finalization
|
|
64
|
+
} nk_maxsim_vector_metadata_t;
|
|
65
|
+
|
|
66
|
+
NK_STATIC_ASSERT(sizeof(nk_maxsim_vector_metadata_t) == 12, nk_maxsim_vector_metadata_must_be_12_bytes);
|
|
67
|
+
|
|
68
|
+
/**
|
|
69
|
+
* @brief Conversion function pointer type for element-to-f32 conversion.
|
|
70
|
+
* Each conversion reads one element from `source` and writes one f32 to `destination`.
|
|
71
|
+
*/
|
|
72
|
+
typedef void (*nk_maxsim_to_f32_t)(void const *source, nk_f32_t *destination);
|
|
73
|
+
|
|
74
|
+
/** @brief Identity conversion for f32 sources — just a typed memcpy. */
|
|
75
|
+
NK_INTERNAL void nk_f32_to_f32_(void const *source, nk_f32_t *destination) { *destination = *(nk_f32_t const *)source; }
|
|
76
|
+
|
|
77
|
+
/**
|
|
78
|
+
* @brief Fills the packed buffer header and returns the padded i8 depth.
|
|
79
|
+
* Consolidates header/offset computation duplicated in every pack function.
|
|
80
|
+
*/
|
|
81
|
+
NK_INTERNAL nk_size_t nk_maxsim_packed_header_setup_( //
|
|
82
|
+
void *packed, nk_size_t vector_count, nk_size_t depth, //
|
|
83
|
+
nk_size_t depth_simd_dimensions, nk_size_t original_element_bytes) {
|
|
84
|
+
|
|
85
|
+
nk_size_t depth_i8_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions);
|
|
86
|
+
if ((depth_i8_padded & (depth_i8_padded - 1)) == 0 && depth_i8_padded > 0) depth_i8_padded += depth_simd_dimensions;
|
|
87
|
+
|
|
88
|
+
nk_size_t const header_size = sizeof(nk_maxsim_packed_header_t);
|
|
89
|
+
nk_size_t const i8_region_size = nk_size_round_up_to_multiple_(vector_count * depth_i8_padded, 64);
|
|
90
|
+
nk_size_t const metadata_region_size = nk_size_round_up_to_multiple_(
|
|
91
|
+
vector_count * sizeof(nk_maxsim_vector_metadata_t), 64);
|
|
92
|
+
nk_size_t const original_stride = nk_size_round_up_to_multiple_(depth * original_element_bytes, 64);
|
|
93
|
+
|
|
94
|
+
nk_maxsim_packed_header_t *header = (nk_maxsim_packed_header_t *)packed;
|
|
95
|
+
header->vector_count = (nk_u32_t)vector_count;
|
|
96
|
+
header->depth_dimensions = (nk_u32_t)depth;
|
|
97
|
+
header->depth_i8_padded = (nk_u32_t)depth_i8_padded;
|
|
98
|
+
header->original_element_bytes = (nk_u32_t)original_element_bytes;
|
|
99
|
+
header->offset_i8_data = (nk_u32_t)header_size;
|
|
100
|
+
header->offset_metadata = (nk_u32_t)(header_size + i8_region_size);
|
|
101
|
+
header->offset_original_data = (nk_u32_t)(header_size + i8_region_size + metadata_region_size);
|
|
102
|
+
header->original_stride_bytes = (nk_u32_t)original_stride;
|
|
103
|
+
for (nk_size_t reserved_index = 0; reserved_index < 8; reserved_index++) header->reserved[reserved_index] = 0;
|
|
104
|
+
|
|
105
|
+
return depth_i8_padded;
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
/**
|
|
109
|
+
* @brief Quantizes a single source vector to i8, computing metadata.
|
|
110
|
+
* Iterates element-by-element, calling the conversion callback for each f32 value.
|
|
111
|
+
* No temp buffer needed — works for arbitrary depth.
|
|
112
|
+
*/
|
|
113
|
+
NK_INTERNAL void nk_maxsim_quantize_vector_( //
|
|
114
|
+
void const *source_vector, nk_size_t element_bytes, nk_size_t depth, //
|
|
115
|
+
nk_size_t depth_i8_padded, nk_f32_t scale_limit, //
|
|
116
|
+
nk_maxsim_to_f32_t convert_to_f32, //
|
|
117
|
+
nk_i8_t *destination_i8, nk_maxsim_vector_metadata_t *metadata, //
|
|
118
|
+
nk_f32_t *norm_squared_ptr) {
|
|
119
|
+
|
|
120
|
+
char const *source_bytes = (char const *)source_vector;
|
|
121
|
+
|
|
122
|
+
// Pass 1: Find absmax, compute norm_squared
|
|
123
|
+
nk_f32_t absmax_f32 = 0.0f;
|
|
124
|
+
nk_f32_t norm_squared_f32 = 0.0f;
|
|
125
|
+
for (nk_size_t dim_index = 0; dim_index < depth; dim_index++) {
|
|
126
|
+
nk_f32_t value_f32;
|
|
127
|
+
convert_to_f32(source_bytes + dim_index * element_bytes, &value_f32);
|
|
128
|
+
nk_f32_t abs_value = nk_f32_abs_(value_f32);
|
|
129
|
+
if (abs_value > absmax_f32) absmax_f32 = abs_value;
|
|
130
|
+
norm_squared_f32 += value_f32 * value_f32;
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
nk_f32_t scale_f32 = absmax_f32 / scale_limit;
|
|
134
|
+
if (scale_f32 == 0.0f) scale_f32 = 1.0f;
|
|
135
|
+
|
|
136
|
+
// Pass 2: Quantize to i8 and compute sum
|
|
137
|
+
nk_i32_t sum_quantized_i32 = 0;
|
|
138
|
+
for (nk_size_t dim_index = 0; dim_index < depth; dim_index++) {
|
|
139
|
+
nk_f32_t value_f32;
|
|
140
|
+
convert_to_f32(source_bytes + dim_index * element_bytes, &value_f32);
|
|
141
|
+
nk_f32_t scaled = value_f32 / scale_f32;
|
|
142
|
+
nk_i32_t quantized_value;
|
|
143
|
+
if (scaled >= 0.0f) quantized_value = (nk_i32_t)(scaled + 0.5f);
|
|
144
|
+
else quantized_value = (nk_i32_t)(scaled - 0.5f);
|
|
145
|
+
if (quantized_value > (nk_i32_t)scale_limit) quantized_value = (nk_i32_t)scale_limit;
|
|
146
|
+
if (quantized_value < -(nk_i32_t)scale_limit) quantized_value = -(nk_i32_t)scale_limit;
|
|
147
|
+
|
|
148
|
+
destination_i8[dim_index] = (nk_i8_t)quantized_value;
|
|
149
|
+
sum_quantized_i32 += quantized_value;
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
// Zero-pad remaining bytes
|
|
153
|
+
for (nk_size_t dim_index = depth; dim_index < depth_i8_padded; dim_index++) destination_i8[dim_index] = 0;
|
|
154
|
+
|
|
155
|
+
metadata->scale_f32 = scale_f32;
|
|
156
|
+
metadata->sum_i8_i32 = sum_quantized_i32;
|
|
157
|
+
*norm_squared_ptr = norm_squared_f32;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
/**
|
|
161
|
+
* @brief Region pointers extracted from two packed buffers.
|
|
162
|
+
* Eliminates ~15 lines of boilerplate per compute function.
|
|
163
|
+
*/
|
|
164
|
+
typedef struct {
|
|
165
|
+
nk_size_t depth_i8_padded;
|
|
166
|
+
nk_i8_t const *query_quantized;
|
|
167
|
+
nk_i8_t const *document_quantized;
|
|
168
|
+
nk_maxsim_vector_metadata_t const *query_metadata;
|
|
169
|
+
nk_maxsim_vector_metadata_t const *document_metadata;
|
|
170
|
+
char const *query_originals;
|
|
171
|
+
char const *document_originals;
|
|
172
|
+
nk_size_t query_original_stride;
|
|
173
|
+
nk_size_t document_original_stride;
|
|
174
|
+
} nk_maxsim_packed_regions_t;
|
|
175
|
+
|
|
176
|
+
/**
 * @brief Resolves the region pointers of a (query, document) pair of packed buffers.
 *
 * Only the query header's `depth_i8_padded` is read. NOTE(review): this assumes the document
 * buffer was packed with the same depth and padding granularity; that is not checked here.
 *
 * @param query_packed    Query buffer laid out by a `nk_maxsim_pack_*` routine.
 * @param document_packed Document buffer laid out by a `nk_maxsim_pack_*` routine.
 * @return Pointers to both buffers' i8, metadata and originals regions, plus originals strides.
 */
NK_INTERNAL nk_maxsim_packed_regions_t nk_maxsim_extract_packed_regions_( //
    void const *query_packed, void const *document_packed) {

    char const *query_base = (char const *)query_packed;
    char const *document_base = (char const *)document_packed;
    nk_maxsim_packed_header_t const *query_header = (nk_maxsim_packed_header_t const *)query_base;
    nk_maxsim_packed_header_t const *document_header = (nk_maxsim_packed_header_t const *)document_base;

    nk_maxsim_packed_regions_t result;
    result.depth_i8_padded = query_header->depth_i8_padded;

    result.query_quantized = (nk_i8_t const *)(query_base + query_header->offset_i8_data);
    result.query_metadata = (nk_maxsim_vector_metadata_t const *)(query_base + query_header->offset_metadata);
    result.query_originals = query_base + query_header->offset_original_data;
    result.query_original_stride = query_header->original_stride_bytes;

    result.document_quantized = (nk_i8_t const *)(document_base + document_header->offset_i8_data);
    result.document_metadata =
        (nk_maxsim_vector_metadata_t const *)(document_base + document_header->offset_metadata);
    result.document_originals = document_base + document_header->offset_original_data;
    result.document_original_stride = document_header->original_stride_bytes;

    return result;
}
|
|
196
|
+
|
|
197
|
+
/**
|
|
198
|
+
* @brief Computes padded i8 depth and total packed buffer size for maxsim.
|
|
199
|
+
*
|
|
200
|
+
* Layout: header + i8 data (64B-aligned) + metadata (64B-aligned) + originals (64B-aligned)
|
|
201
|
+
*
|
|
202
|
+
* @param vector_count Number of vectors to pack.
|
|
203
|
+
* @param depth Number of elements per vector.
|
|
204
|
+
* @param original_element_bytes Size of each original element (2 for bf16, 4 for f32).
|
|
205
|
+
* @param depth_simd_dimensions SIMD width for i8 depth padding (1 for serial).
|
|
206
|
+
*/
|
|
207
|
+
NK_INTERNAL nk_size_t nk_maxsim_packed_size_( //
|
|
208
|
+
nk_size_t vector_count, nk_size_t depth, //
|
|
209
|
+
nk_size_t original_element_bytes, nk_size_t depth_simd_dimensions) {
|
|
210
|
+
|
|
211
|
+
// Step 1: Pad i8 depth to SIMD width
|
|
212
|
+
nk_size_t depth_i8_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions);
|
|
213
|
+
|
|
214
|
+
// Step 2: Break power-of-2 strides for cache associativity
|
|
215
|
+
if ((depth_i8_padded & (depth_i8_padded - 1)) == 0 && depth_i8_padded > 0) depth_i8_padded += depth_simd_dimensions;
|
|
216
|
+
|
|
217
|
+
// Step 3: Calculate region sizes
|
|
218
|
+
nk_size_t const header_size = sizeof(nk_maxsim_packed_header_t);
|
|
219
|
+
nk_size_t const i8_region_size = nk_size_round_up_to_multiple_(vector_count * depth_i8_padded, 64);
|
|
220
|
+
nk_size_t const metadata_region_size = nk_size_round_up_to_multiple_(
|
|
221
|
+
vector_count * sizeof(nk_maxsim_vector_metadata_t), 64);
|
|
222
|
+
nk_size_t const original_stride = nk_size_round_up_to_multiple_(depth * original_element_bytes, 64);
|
|
223
|
+
nk_size_t const originals_region_size = vector_count * original_stride;
|
|
224
|
+
|
|
225
|
+
return header_size + i8_region_size + metadata_region_size + originals_region_size;
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
/** @brief Bytes needed to pack `vector_count` bf16 vectors of `depth` elements (serial backend). */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_serial(nk_size_t vector_count, nk_size_t depth) {
    nk_size_t const serial_simd_width = 1; // no SIMD rounding of i8 rows on the serial backend
    nk_size_t const bytes = nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_bf16_t), serial_simd_width);
    return bytes;
}
|
|
231
|
+
|
|
232
|
+
/** @brief Bytes needed to pack `vector_count` f32 vectors of `depth` elements (serial backend). */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_serial(nk_size_t vector_count, nk_size_t depth) {
    nk_size_t const serial_simd_width = 1; // no SIMD rounding of i8 rows on the serial backend
    nk_size_t const bytes = nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f32_t), serial_simd_width);
    return bytes;
}
|
|
235
|
+
|
|
236
|
+
NK_PUBLIC void nk_maxsim_pack_bf16_serial( //
|
|
237
|
+
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
|
|
238
|
+
|
|
239
|
+
nk_size_t const element_bytes = sizeof(nk_bf16_t);
|
|
240
|
+
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
|
|
241
|
+
|
|
242
|
+
nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
|
|
243
|
+
nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
|
|
244
|
+
nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
|
|
245
|
+
char *originals = (char *)packed + header->offset_original_data;
|
|
246
|
+
nk_size_t const original_stride = header->original_stride_bytes;
|
|
247
|
+
|
|
248
|
+
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
249
|
+
char const *source_row = (char const *)vectors + vector_index * stride;
|
|
250
|
+
nk_f32_t norm_sq;
|
|
251
|
+
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
252
|
+
(nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
|
|
253
|
+
&quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
|
|
254
|
+
metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? nk_f32_rsqrt_serial(norm_sq) : 0.0f;
|
|
255
|
+
char *destination_original = originals + vector_index * original_stride;
|
|
256
|
+
nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
|
|
257
|
+
for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
|
|
258
|
+
destination_original[byte_index] = 0;
|
|
259
|
+
}
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
NK_PUBLIC void nk_maxsim_pack_f32_serial( //
|
|
263
|
+
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
|
|
264
|
+
|
|
265
|
+
nk_size_t const element_bytes = sizeof(nk_f32_t);
|
|
266
|
+
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
|
|
267
|
+
|
|
268
|
+
nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
|
|
269
|
+
nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
|
|
270
|
+
nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
|
|
271
|
+
char *originals = (char *)packed + header->offset_original_data;
|
|
272
|
+
nk_size_t const original_stride = header->original_stride_bytes;
|
|
273
|
+
|
|
274
|
+
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
275
|
+
char const *source_row = (char const *)vectors + vector_index * stride;
|
|
276
|
+
nk_f32_t norm_sq;
|
|
277
|
+
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
|
|
278
|
+
&quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
|
|
279
|
+
metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? nk_f32_rsqrt_serial(norm_sq) : 0.0f;
|
|
280
|
+
char *destination_original = originals + vector_index * original_stride;
|
|
281
|
+
nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
|
|
282
|
+
for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
|
|
283
|
+
destination_original[byte_index] = 0;
|
|
284
|
+
}
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
/** @brief Bytes needed to pack `vector_count` f16 vectors of `depth` elements (serial backend). */
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_serial(nk_size_t vector_count, nk_size_t depth) {
    nk_size_t const serial_simd_width = 1; // no SIMD rounding of i8 rows on the serial backend
    nk_size_t const bytes = nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f16_t), serial_simd_width);
    return bytes;
}
|
|
290
|
+
|
|
291
|
+
NK_PUBLIC void nk_maxsim_pack_f16_serial( //
|
|
292
|
+
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride, void *packed) {
|
|
293
|
+
|
|
294
|
+
nk_size_t const element_bytes = sizeof(nk_f16_t);
|
|
295
|
+
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 1, element_bytes);
|
|
296
|
+
|
|
297
|
+
nk_maxsim_packed_header_t const *header = (nk_maxsim_packed_header_t const *)packed;
|
|
298
|
+
nk_i8_t *quantized_i8 = (nk_i8_t *)((char *)packed + header->offset_i8_data);
|
|
299
|
+
nk_maxsim_vector_metadata_t *metadata = (nk_maxsim_vector_metadata_t *)((char *)packed + header->offset_metadata);
|
|
300
|
+
char *originals = (char *)packed + header->offset_original_data;
|
|
301
|
+
nk_size_t const original_stride = header->original_stride_bytes;
|
|
302
|
+
|
|
303
|
+
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
304
|
+
char const *source_row = (char const *)vectors + vector_index * stride;
|
|
305
|
+
nk_f32_t norm_sq;
|
|
306
|
+
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
307
|
+
(nk_maxsim_to_f32_t)nk_f16_to_f32_serial,
|
|
308
|
+
&quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
|
|
309
|
+
metadata[vector_index].inverse_norm_f32 = norm_sq > 0.0f ? nk_f32_rsqrt_serial(norm_sq) : 0.0f;
|
|
310
|
+
char *destination_original = originals + vector_index * original_stride;
|
|
311
|
+
nk_copy_bytes_(destination_original, source_row, depth * element_bytes);
|
|
312
|
+
for (nk_size_t byte_index = depth * element_bytes; byte_index < original_stride; byte_index++)
|
|
313
|
+
destination_original[byte_index] = 0;
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
/**
|
|
318
|
+
* @brief Dtype-agnostic coarse i8 argmax kernel for the serial backend.
|
|
319
|
+
* Produces per-query best document indices using signed i8×i8 dot products.
|
|
320
|
+
* No bias correction needed — serial uses native signed×signed multiplication.
|
|
321
|
+
*/
|
|
322
|
+
NK_INTERNAL void nk_maxsim_coarse_argmax_serial_( //
|
|
323
|
+
nk_i8_t const *query_i8, nk_i8_t const *document_i8, nk_size_t query_count, nk_size_t document_count,
|
|
324
|
+
nk_size_t depth_i8_padded, nk_u32_t *best_document_indices) {
|
|
325
|
+
|
|
326
|
+
// Primary path: 4-query grouping
|
|
327
|
+
nk_size_t query_block_start_index = 0;
|
|
328
|
+
for (; query_block_start_index + 4 <= query_count; query_block_start_index += 4) {
|
|
329
|
+
nk_i32_t running_max_i32[4] = {NK_I32_MIN, NK_I32_MIN, NK_I32_MIN, NK_I32_MIN};
|
|
330
|
+
nk_u32_t running_argmax_u32[4] = {0, 0, 0, 0};
|
|
331
|
+
|
|
332
|
+
for (nk_size_t document_index = 0; document_index < document_count; document_index++) {
|
|
333
|
+
nk_i8_t const *document_i8_row = document_i8 + document_index * depth_i8_padded;
|
|
334
|
+
nk_i32_t accumulator_i32[4] = {0, 0, 0, 0};
|
|
335
|
+
|
|
336
|
+
for (nk_size_t dim_index = 0; dim_index < depth_i8_padded; dim_index++) {
|
|
337
|
+
nk_i32_t document_value = (nk_i32_t)document_i8_row[dim_index];
|
|
338
|
+
accumulator_i32[0] += (nk_i32_t)query_i8[(query_block_start_index + 0) * depth_i8_padded + dim_index] *
|
|
339
|
+
document_value;
|
|
340
|
+
accumulator_i32[1] += (nk_i32_t)query_i8[(query_block_start_index + 1) * depth_i8_padded + dim_index] *
|
|
341
|
+
document_value;
|
|
342
|
+
accumulator_i32[2] += (nk_i32_t)query_i8[(query_block_start_index + 2) * depth_i8_padded + dim_index] *
|
|
343
|
+
document_value;
|
|
344
|
+
accumulator_i32[3] += (nk_i32_t)query_i8[(query_block_start_index + 3) * depth_i8_padded + dim_index] *
|
|
345
|
+
document_value;
|
|
346
|
+
}
|
|
347
|
+
|
|
348
|
+
for (nk_size_t query_tile_index = 0; query_tile_index < 4; query_tile_index++) {
|
|
349
|
+
if (accumulator_i32[query_tile_index] > running_max_i32[query_tile_index]) {
|
|
350
|
+
running_max_i32[query_tile_index] = accumulator_i32[query_tile_index];
|
|
351
|
+
running_argmax_u32[query_tile_index] = (nk_u32_t)document_index;
|
|
352
|
+
}
|
|
353
|
+
}
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
for (nk_size_t query_tile_index = 0; query_tile_index < 4; query_tile_index++)
|
|
357
|
+
best_document_indices[query_block_start_index + query_tile_index] = running_argmax_u32[query_tile_index];
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
// Edge path: remaining 1-3 queries
|
|
361
|
+
for (nk_size_t query_index = query_block_start_index; query_index < query_count; query_index++) {
|
|
362
|
+
nk_i8_t const *query_i8_row = query_i8 + query_index * depth_i8_padded;
|
|
363
|
+
nk_i32_t running_max_i32 = NK_I32_MIN;
|
|
364
|
+
nk_u32_t running_argmax_u32 = 0;
|
|
365
|
+
|
|
366
|
+
for (nk_size_t document_index = 0; document_index < document_count; document_index++) {
|
|
367
|
+
nk_i8_t const *document_i8_row = document_i8 + document_index * depth_i8_padded;
|
|
368
|
+
nk_i32_t accumulator_i32 = 0;
|
|
369
|
+
|
|
370
|
+
for (nk_size_t dim_index = 0; dim_index < depth_i8_padded; dim_index++)
|
|
371
|
+
accumulator_i32 += (nk_i32_t)query_i8_row[dim_index] * (nk_i32_t)document_i8_row[dim_index];
|
|
372
|
+
|
|
373
|
+
if (accumulator_i32 > running_max_i32) {
|
|
374
|
+
running_max_i32 = accumulator_i32;
|
|
375
|
+
running_argmax_u32 = (nk_u32_t)document_index;
|
|
376
|
+
}
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
best_document_indices[query_index] = running_argmax_u32;
|
|
380
|
+
}
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
/**
 * @brief MaxSim angular distance between two packed bf16 sets (serial backend).
 *
 * Queries are processed in chunks of up to 256. For each query, `nk_maxsim_coarse_argmax_serial_` picks a
 * document using the quantized i8 rows. The exact bf16 dot product with that document is then computed
 * through `nk_dot_bf16` and turned into a cosine using both stored inverse norms. The result is the sum
 * over queries of `max(0, 1 - cosine)`, accumulated in f64 and narrowed to f32.
 *
 * @param query_packed    Packed queries (presumably from the matching bf16 pack routine — confirm).
 * @param document_packed Packed documents, same layout.
 * @param query_count     Number of packed queries.
 * @param document_count  Number of packed documents.
 * @param depth           Elements per original bf16 row.
 * @param result          Output sum. Set to 0 when either set is empty.
 */
NK_PUBLIC void nk_maxsim_packed_bf16_serial( //
    void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
    nk_size_t depth, nk_f32_t *result) {

    enum { query_chunk_capacity = 256 };

    // With no documents the coarse pass would report index 0 for every query. The refinement below would
    // then read document metadata and original rows that were never packed.
    if (query_count == 0 || document_count == 0) {
        *result = 0.0f;
        return;
    }

    nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
    nk_f64_t total_angular_distance = 0.0;

    for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += query_chunk_capacity) {
        nk_size_t remaining = query_count - chunk_start;
        nk_size_t chunk_size = remaining < query_chunk_capacity ? remaining : query_chunk_capacity;
        nk_u32_t best_document_indices[query_chunk_capacity];

        nk_maxsim_coarse_argmax_serial_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
                                        regions.document_quantized, chunk_size, document_count, regions.depth_i8_padded,
                                        best_document_indices);

        for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
            nk_size_t global_query_index = chunk_start + query_index;
            nk_u32_t best_document_index = best_document_indices[query_index];
            nk_bf16_t const *query_row =
                (nk_bf16_t const *)(regions.query_originals + global_query_index * regions.query_original_stride);
            nk_bf16_t const *document_row = (nk_bf16_t const *)(regions.document_originals +
                                                                best_document_index * regions.document_original_stride);

            nk_f32_t dot_result;
            nk_dot_bf16(query_row, document_row, depth, &dot_result);
            nk_f32_t cosine = dot_result * regions.query_metadata[global_query_index].inverse_norm_f32 *
                              regions.document_metadata[best_document_index].inverse_norm_f32;
            nk_f32_t angular = 1.0f - cosine;
            if (angular < 0.0f) angular = 0.0f; // rounding can push cosine slightly above 1
            total_angular_distance += (nk_f64_t)angular;
        }
    }

    *result = (nk_f32_t)total_angular_distance;
}
|
|
416
|
+
|
|
417
|
+
/**
 * @brief MaxSim angular distance between two packed f32 sets (serial backend).
 *
 * Queries are processed in chunks of up to 256. For each query, `nk_maxsim_coarse_argmax_serial_` picks a
 * document using the quantized i8 rows. The exact f32 dot product with that document is then computed
 * through `nk_dot_f32` (f64 output) and turned into a cosine in f64 using both stored inverse norms. The
 * result is the sum over queries of `max(0, 1 - cosine)`.
 *
 * @param query_packed    Packed queries (presumably from the matching f32 pack routine — confirm).
 * @param document_packed Packed documents, same layout.
 * @param query_count     Number of packed queries.
 * @param document_count  Number of packed documents.
 * @param depth           Elements per original f32 row.
 * @param result          Output sum. Set to 0 when either set is empty.
 */
NK_PUBLIC void nk_maxsim_packed_f32_serial( //
    void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
    nk_size_t depth, nk_f64_t *result) {

    enum { query_chunk_capacity = 256 };

    // With no documents the coarse pass would report index 0 for every query. The refinement below would
    // then read document metadata and original rows that were never packed.
    if (query_count == 0 || document_count == 0) {
        *result = 0.0;
        return;
    }

    nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
    nk_f64_t total_angular_distance = 0.0;

    for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += query_chunk_capacity) {
        nk_size_t remaining = query_count - chunk_start;
        nk_size_t chunk_size = remaining < query_chunk_capacity ? remaining : query_chunk_capacity;
        nk_u32_t best_document_indices[query_chunk_capacity];

        nk_maxsim_coarse_argmax_serial_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
                                        regions.document_quantized, chunk_size, document_count, regions.depth_i8_padded,
                                        best_document_indices);

        for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
            nk_size_t global_query_index = chunk_start + query_index;
            nk_u32_t best_document_index = best_document_indices[query_index];
            nk_f32_t const *query_row =
                (nk_f32_t const *)(regions.query_originals + global_query_index * regions.query_original_stride);
            nk_f32_t const *document_row = (nk_f32_t const *)(regions.document_originals +
                                                              best_document_index * regions.document_original_stride);

            nk_f64_t dot_result;
            nk_dot_f32(query_row, document_row, depth, &dot_result);
            nk_f64_t cosine = dot_result * (nk_f64_t)regions.query_metadata[global_query_index].inverse_norm_f32 *
                              (nk_f64_t)regions.document_metadata[best_document_index].inverse_norm_f32;
            nk_f64_t angular = 1.0 - cosine;
            if (angular < 0.0) angular = 0.0; // rounding can push cosine slightly above 1
            total_angular_distance += angular;
        }
    }

    *result = total_angular_distance;
}
|
|
451
|
+
|
|
452
|
+
/**
 * @brief MaxSim angular distance between two packed f16 sets (serial backend).
 *
 * Queries are processed in chunks of up to 256. For each query, `nk_maxsim_coarse_argmax_serial_` picks a
 * document using the quantized i8 rows. The exact f16 dot product with that document is then computed
 * through `nk_dot_f16` and turned into a cosine using both stored inverse norms. The result is the sum
 * over queries of `max(0, 1 - cosine)`, accumulated in f64 and narrowed to f32.
 *
 * @param query_packed    Packed queries (presumably from `nk_maxsim_pack_f16_serial` — confirm).
 * @param document_packed Packed documents, same layout.
 * @param query_count     Number of packed queries.
 * @param document_count  Number of packed documents.
 * @param depth           Elements per original f16 row.
 * @param result          Output sum. Set to 0 when either set is empty.
 */
NK_PUBLIC void nk_maxsim_packed_f16_serial( //
    void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
    nk_size_t depth, nk_f32_t *result) {

    enum { query_chunk_capacity = 256 };

    // With no documents the coarse pass would report index 0 for every query. The refinement below would
    // then read document metadata and original rows that were never packed.
    if (query_count == 0 || document_count == 0) {
        *result = 0.0f;
        return;
    }

    nk_maxsim_packed_regions_t regions = nk_maxsim_extract_packed_regions_(query_packed, document_packed);
    nk_f64_t total_angular_distance = 0.0;

    for (nk_size_t chunk_start = 0; chunk_start < query_count; chunk_start += query_chunk_capacity) {
        nk_size_t remaining = query_count - chunk_start;
        nk_size_t chunk_size = remaining < query_chunk_capacity ? remaining : query_chunk_capacity;
        nk_u32_t best_document_indices[query_chunk_capacity];

        nk_maxsim_coarse_argmax_serial_(regions.query_quantized + chunk_start * regions.depth_i8_padded,
                                        regions.document_quantized, chunk_size, document_count, regions.depth_i8_padded,
                                        best_document_indices);

        for (nk_size_t query_index = 0; query_index < chunk_size; query_index++) {
            nk_size_t global_query_index = chunk_start + query_index;
            nk_u32_t best_document_index = best_document_indices[query_index];
            nk_f16_t const *query_row =
                (nk_f16_t const *)(regions.query_originals + global_query_index * regions.query_original_stride);
            nk_f16_t const *document_row = (nk_f16_t const *)(regions.document_originals +
                                                              best_document_index * regions.document_original_stride);

            nk_f32_t dot_result;
            nk_dot_f16(query_row, document_row, depth, &dot_result);
            nk_f32_t cosine = dot_result * regions.query_metadata[global_query_index].inverse_norm_f32 *
                              regions.document_metadata[best_document_index].inverse_norm_f32;
            nk_f32_t angular = 1.0f - cosine;
            if (angular < 0.0f) angular = 0.0f; // rounding can push cosine slightly above 1
            total_angular_distance += (nk_f64_t)angular;
        }
    }

    *result = (nk_f32_t)total_angular_distance;
}
|
|
485
|
+
|
|
486
|
+
#if defined(__cplusplus)
|
|
487
|
+
} // extern "C"
|
|
488
|
+
#endif
|
|
489
|
+
|
|
490
|
+
#endif // NK_MAXSIM_SERIAL_H
|