numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -78,15 +78,15 @@
|
|
|
78
78
|
* but only execute once per point-pair. The polynomial trig approximations use FMA chains.
|
|
79
79
|
* Note: ZMM sqrt is faster on Genoa (15cy) than Ice Lake (19cy) due to better 512-bit support.
|
|
80
80
|
*
|
|
81
|
-
* Intrinsic
|
|
82
|
-
* _mm256_sqrt_ps
|
|
83
|
-
* _mm256_sqrt_pd
|
|
84
|
-
* _mm512_sqrt_ps
|
|
85
|
-
* _mm512_sqrt_pd
|
|
86
|
-
* _mm256_div_ps
|
|
87
|
-
* _mm256_div_pd
|
|
88
|
-
* _mm256_fmadd_ps
|
|
89
|
-
* _mm256_fmadd_pd
|
|
81
|
+
* Intrinsic Instruction Icelake Genoa
|
|
82
|
+
* _mm256_sqrt_ps VSQRTPS (YMM, YMM) 12cy @ p0 15cy @ p01
|
|
83
|
+
* _mm256_sqrt_pd VSQRTPD (YMM, YMM) 13cy @ p0 21cy @ p01
|
|
84
|
+
* _mm512_sqrt_ps VSQRTPS (ZMM, ZMM) 19cy @ p0+p0+p05 15cy @ p01
|
|
85
|
+
* _mm512_sqrt_pd VSQRTPD (ZMM, ZMM) 23cy @ p0+p0+p05 21cy @ p01
|
|
86
|
+
* _mm256_div_ps VDIVPS (YMM, YMM, YMM) 11cy @ p0 11cy @ p01
|
|
87
|
+
* _mm256_div_pd VDIVPD (YMM, YMM, YMM) 13cy @ p0 13cy @ p01
|
|
88
|
+
* _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4cy @ p01 4cy @ p01
|
|
89
|
+
* _mm256_fmadd_pd VFMADD231PD (YMM, YMM, YMM) 4cy @ p01 4cy @ p01
|
|
90
90
|
*
|
|
91
91
|
* @section arm_instructions Relevant ARM NEON/SVE Instructions
|
|
92
92
|
*
|
|
@@ -94,21 +94,21 @@
|
|
|
94
94
|
* acceptable since sqrt only appears once per distance calculation. FMA chains for trig
|
|
95
95
|
* polynomial evaluation pipeline well across all 4 V-units.
|
|
96
96
|
*
|
|
97
|
-
* Intrinsic
|
|
98
|
-
* vfmaq_f32
|
|
99
|
-
* vfmaq_f64
|
|
100
|
-
* vsqrtq_f32
|
|
101
|
-
* vsqrtq_f64
|
|
97
|
+
* Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
|
|
98
|
+
* vfmaq_f32 FMLA.S (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
|
|
99
|
+
* vfmaq_f64 FMLA.D (vec) 4cy @ V0123 4cy @ V0123 4cy @ V0123
|
|
100
|
+
* vsqrtq_f32 FSQRT.S (vec) 10cy @ V02 10cy @ V02 9cy @ V02
|
|
101
|
+
* vsqrtq_f64 FSQRT.D (vec) 13cy @ V02 16cy @ V02 16cy @ V02
|
|
102
102
|
*
|
|
103
103
|
* @section references References
|
|
104
104
|
*
|
|
105
|
-
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
105
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
|
|
106
106
|
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
107
107
|
* - Earth Ellipsoid: https://en.wikipedia.org/wiki/Earth_ellipsoid
|
|
108
108
|
* - Oblate Spheroid Geodesic: https://mathworld.wolfram.com/OblateSpheroidGeodesic.html
|
|
109
|
-
* - Staging experiments: https://github.com/ashvardanian/HaversineMathKong
|
|
110
109
|
* - Speeding up atan2f by 50x: https://mazzo.li/posts/vectorized-atan2.html
|
|
111
|
-
* - Simplifying the GNU C Sine Function:
|
|
110
|
+
* - Simplifying the GNU C Sine Function:
|
|
111
|
+
* https://web.archive.org/web/20230605051610/https://www.awelm.com/posts/simplifying-the-gnu-c-sine-function/
|
|
112
112
|
*
|
|
113
113
|
*/
|
|
114
114
|
#ifndef NK_GEOSPATIAL_H
|
|
@@ -26,7 +26,7 @@
|
|
|
26
26
|
|
|
27
27
|
namespace ashvardanian::numkong {
|
|
28
28
|
|
|
29
|
-
#pragma region
|
|
29
|
+
#pragma region Packing Utilities
|
|
30
30
|
|
|
31
31
|
/**
|
|
32
32
|
* @brief Estimates the memory requirements for packed B matrix.
|
|
@@ -155,9 +155,9 @@ NK_PUBLIC void maxsim_pack(typename in_type_::raw_t const *vectors, std::size_t
|
|
|
155
155
|
}
|
|
156
156
|
}
|
|
157
157
|
|
|
158
|
-
#pragma endregion
|
|
158
|
+
#pragma endregion Packing Utilities
|
|
159
159
|
|
|
160
|
-
#pragma region
|
|
160
|
+
#pragma region Packed Containers
|
|
161
161
|
|
|
162
162
|
/**
|
|
163
163
|
* @brief Owning, move-only, pre-packed matrix for efficient GEMM.
|
|
@@ -329,7 +329,7 @@ class packed_maxsim {
|
|
|
329
329
|
std::size_t size_bytes() const noexcept { return size_bytes_; }
|
|
330
330
|
};
|
|
331
331
|
|
|
332
|
-
#pragma endregion
|
|
332
|
+
#pragma endregion Packed Containers
|
|
333
333
|
|
|
334
334
|
} // namespace ashvardanian::numkong
|
|
335
335
|
|
|
@@ -4,21 +4,21 @@ NumKong implements ColBERT-style late-interaction scoring: the MaxSim score sums
|
|
|
4
4
|
|
|
5
5
|
MaxSim score:
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
$$
|
|
8
8
|
\text{MaxSim}(Q, D) = \sum_{i=0}^{m-1} \min_{j=0}^{n-1} \text{angular}(q_i, d_j)
|
|
9
|
-
|
|
9
|
+
$$
|
|
10
10
|
|
|
11
11
|
Coarse screening finds the best document via i8 dot products as a proxy for argmin angular:
|
|
12
12
|
|
|
13
|
-
|
|
13
|
+
$$
|
|
14
14
|
j^* = \arg\max_j \text{dot}_{\text{i8}}(q_i, d_j)
|
|
15
|
-
|
|
15
|
+
$$
|
|
16
16
|
|
|
17
17
|
Full-precision refinement:
|
|
18
18
|
|
|
19
|
-
|
|
19
|
+
$$
|
|
20
20
|
\text{angular}(q_i, d_{j^*}) = 1 - \frac{\text{dot}(q_i, d_{j^*})}{\|q_i\| \cdot \|d_{j^*}\|}
|
|
21
|
-
|
|
21
|
+
$$
|
|
22
22
|
|
|
23
23
|
Reformulating as Python pseudocode:
|
|
24
24
|
|
|
@@ -46,7 +46,7 @@ def maxsim(queries: np.ndarray, documents: np.ndarray) -> float:
|
|
|
46
46
|
|
|
47
47
|
## Optimizations
|
|
48
48
|
|
|
49
|
-
### Dual Pre-Packing
|
|
49
|
+
### Dual Pre-Packing
|
|
50
50
|
|
|
51
51
|
`nk_maxsim_packed_bf16_sme`, `nk_maxsim_packed_f32_sme` benefit from having _both_ query and document matrices pre-packed into identical contiguous formats, unlike the `nk_dots_packed_*` family where only B is pre-packed and A is accessed with arbitrary stride.
|
|
52
52
|
In the dots GEMM, one ZA tile must be reserved for A-side staging (loading unpacked A rows into the tile array), leaving 3 ZA tiles for accumulation.
|
|
@@ -172,16 +172,16 @@ Measured with Wasmtime v42 (Cranelift backend).
|
|
|
172
172
|
|
|
173
173
|
#### WASM
|
|
174
174
|
|
|
175
|
-
Measured with Wasmtime
|
|
175
|
+
Measured with Wasmtime v43 (Cranelift backend).
|
|
176
176
|
|
|
177
177
|
| Kernel | 256³ | 1024³ | 4096³ |
|
|
178
178
|
| :---------------------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
179
179
|
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
180
|
-
| `nk_maxsim_packed_f32_serial` |
|
|
181
|
-
| `nk_maxsim_packed_f32_v128relaxed` |
|
|
180
|
+
| `nk_maxsim_packed_f32_serial` | 33.7 gso/s, 46.8K ulp | 35.0 gso/s, 46.8K ulp | 35.8 gso/s, 46.8K ulp |
|
|
181
|
+
| `nk_maxsim_packed_f32_v128relaxed` | 88.5 gso/s, 46.0K ulp | 98.1 gso/s, 46.0K ulp | 82.7 gso/s, 46.0K ulp |
|
|
182
182
|
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
183
|
-
| `nk_maxsim_packed_bf16_serial` |
|
|
184
|
-
| `nk_maxsim_packed_bf16_v128relaxed` |
|
|
183
|
+
| `nk_maxsim_packed_bf16_serial` | 34.4 gso/s, 49.2K ulp | 35.1 gso/s, 49.2K ulp | 35.7 gso/s, 49.2K ulp |
|
|
184
|
+
| `nk_maxsim_packed_bf16_v128relaxed` | 92.3 gso/s, 49.4K ulp | 100 gso/s, 49.4K ulp | 83.2 gso/s, 49.4K ulp |
|
|
185
185
|
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
186
|
-
| `nk_maxsim_packed_f16_serial` |
|
|
187
|
-
| `nk_maxsim_packed_f16_v128relaxed` |
|
|
186
|
+
| `nk_maxsim_packed_f16_serial` | 33.8 gso/s, 49.5K ulp | 35.0 gso/s, 49.5K ulp | 35.7 gso/s, 49.5K ulp |
|
|
187
|
+
| `nk_maxsim_packed_f16_v128relaxed` | 87.0 gso/s, 49.3K ulp | 95.8 gso/s, 49.3K ulp | 82.3 gso/s, 49.3K ulp |
|
|
@@ -57,7 +57,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_alder(nk_size_t vector_count, nk_s
|
|
|
57
57
|
}
|
|
58
58
|
|
|
59
59
|
NK_PUBLIC void nk_maxsim_pack_bf16_alder( //
|
|
60
|
-
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
60
|
+
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
61
61
|
|
|
62
62
|
nk_size_t const element_bytes = sizeof(nk_bf16_t);
|
|
63
63
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
|
|
@@ -69,7 +69,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_alder( //
|
|
|
69
69
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
70
70
|
|
|
71
71
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
72
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
72
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
73
73
|
nk_f32_t norm_sq;
|
|
74
74
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
75
75
|
(nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
|
|
@@ -83,7 +83,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_alder( //
|
|
|
83
83
|
}
|
|
84
84
|
|
|
85
85
|
NK_PUBLIC void nk_maxsim_pack_f32_alder( //
|
|
86
|
-
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
86
|
+
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
87
87
|
|
|
88
88
|
nk_size_t const element_bytes = sizeof(nk_f32_t);
|
|
89
89
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
|
|
@@ -95,7 +95,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_alder( //
|
|
|
95
95
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
96
96
|
|
|
97
97
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
98
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
98
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
99
99
|
nk_f32_t norm_sq;
|
|
100
100
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
|
|
101
101
|
&quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
|
|
@@ -108,7 +108,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_alder( //
|
|
|
108
108
|
}
|
|
109
109
|
|
|
110
110
|
NK_PUBLIC void nk_maxsim_pack_f16_alder( //
|
|
111
|
-
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
111
|
+
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
112
112
|
|
|
113
113
|
nk_size_t const element_bytes = sizeof(nk_f16_t);
|
|
114
114
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
|
|
@@ -120,7 +120,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_alder( //
|
|
|
120
120
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
121
121
|
|
|
122
122
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
123
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
123
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
124
124
|
nk_f32_t norm_sq;
|
|
125
125
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
126
126
|
(nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
|
|
@@ -9,8 +9,8 @@
|
|
|
9
9
|
* Uses AVX-512 VNNI (VPDPBUSD) for coarse i8 screening via icelake.h, and VDPBF16PS for bf16 refinement.
|
|
10
10
|
* f32/f16 MaxSim variants live in icelake.h — this file only provides bf16 pack and compute.
|
|
11
11
|
*
|
|
12
|
-
* Intrinsic
|
|
13
|
-
* _mm512_dpbf16_ps
|
|
12
|
+
* Intrinsic Instruction Genoa
|
|
13
|
+
* _mm512_dpbf16_ps VDPBF16PS 6cy @ p01
|
|
14
14
|
*/
|
|
15
15
|
#ifndef NK_MAXSIM_GENOA_H
|
|
16
16
|
#define NK_MAXSIM_GENOA_H
|
|
@@ -41,7 +41,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_genoa(nk_size_t vector_count, nk_
|
|
|
41
41
|
}
|
|
42
42
|
|
|
43
43
|
NK_PUBLIC void nk_maxsim_pack_bf16_genoa( //
|
|
44
|
-
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
44
|
+
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
45
45
|
|
|
46
46
|
nk_size_t const element_bytes = sizeof(nk_bf16_t);
|
|
47
47
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
|
|
@@ -53,7 +53,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_genoa( //
|
|
|
53
53
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
54
54
|
|
|
55
55
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
56
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
56
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
57
57
|
nk_f32_t norm_sq;
|
|
58
58
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
59
59
|
(nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
|
|
@@ -49,7 +49,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_haswell(nk_size_t vector_count, nk
|
|
|
49
49
|
}
|
|
50
50
|
|
|
51
51
|
NK_PUBLIC void nk_maxsim_pack_bf16_haswell( //
|
|
52
|
-
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
52
|
+
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
53
53
|
|
|
54
54
|
nk_size_t const element_bytes = sizeof(nk_bf16_t);
|
|
55
55
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
|
|
@@ -61,7 +61,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_haswell( //
|
|
|
61
61
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
62
62
|
|
|
63
63
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
64
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
64
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
65
65
|
nk_f32_t norm_sq;
|
|
66
66
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f,
|
|
67
67
|
(nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
|
|
@@ -75,7 +75,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_haswell( //
|
|
|
75
75
|
}
|
|
76
76
|
|
|
77
77
|
NK_PUBLIC void nk_maxsim_pack_f32_haswell( //
|
|
78
|
-
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
78
|
+
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
79
79
|
|
|
80
80
|
nk_size_t const element_bytes = sizeof(nk_f32_t);
|
|
81
81
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
|
|
@@ -87,7 +87,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_haswell( //
|
|
|
87
87
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
88
88
|
|
|
89
89
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
90
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
90
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
91
91
|
nk_f32_t norm_sq;
|
|
92
92
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f, nk_f32_to_f32_,
|
|
93
93
|
&quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
|
|
@@ -100,7 +100,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_haswell( //
|
|
|
100
100
|
}
|
|
101
101
|
|
|
102
102
|
NK_PUBLIC void nk_maxsim_pack_f16_haswell( //
|
|
103
|
-
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
103
|
+
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
104
104
|
|
|
105
105
|
nk_size_t const element_bytes = sizeof(nk_f16_t);
|
|
106
106
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 32, element_bytes);
|
|
@@ -112,7 +112,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_haswell( //
|
|
|
112
112
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
113
113
|
|
|
114
114
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
115
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
115
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
116
116
|
nk_f32_t norm_sq;
|
|
117
117
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 79.0f,
|
|
118
118
|
(nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
|
|
@@ -9,15 +9,15 @@
|
|
|
9
9
|
* Uses AVX-512 VNNI (VPDPBUSD) for coarse i8 screening. The coarse argmax kernel and reduce helper
|
|
10
10
|
* are shared with genoa.h — genoa.h imports them from this file for its bf16 compute path.
|
|
11
11
|
*
|
|
12
|
-
* VPDPBUSD computes 4 groups of (u8
|
|
12
|
+
* VPDPBUSD computes 4 groups of (u8 × i8) → i32 per 128-bit lane, processing 64 i8 pairs
|
|
13
13
|
* per ZMM register operation. Bias correction via XOR with 0x80 converts signed queries
|
|
14
14
|
* to unsigned, then subtracts 128 * sum(document_i8) after the depth loop.
|
|
15
15
|
*
|
|
16
|
-
* 4x4 register tiling: 4 queries
|
|
16
|
+
* 4x4 register tiling: 4 queries × 4 documents = 16 ZMM accumulators per depth loop.
|
|
17
17
|
* Each document load is amortized across 4 VPDPBUSDs, and each query load across 4 documents.
|
|
18
18
|
*
|
|
19
|
-
* Intrinsic
|
|
20
|
-
* _mm512_dpbusd_epi32
|
|
19
|
+
* Intrinsic Instruction Icelake Genoa
|
|
20
|
+
* _mm512_dpbusd_epi32 VPDPBUSD 5cy @ p0 4cy @ p01
|
|
21
21
|
*/
|
|
22
22
|
#ifndef NK_MAXSIM_ICELAKE_H
|
|
23
23
|
#define NK_MAXSIM_ICELAKE_H
|
|
@@ -44,14 +44,14 @@ extern "C" {
|
|
|
44
44
|
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512vnni", "f16c", "fma", "bmi", "bmi2")
|
|
45
45
|
#endif
|
|
46
46
|
|
|
47
|
-
#pragma region
|
|
47
|
+
#pragma region F32 Floats
|
|
48
48
|
|
|
49
49
|
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f32_icelake(nk_size_t vector_count, nk_size_t depth) {
|
|
50
50
|
return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f32_t), 64);
|
|
51
51
|
}
|
|
52
52
|
|
|
53
53
|
NK_PUBLIC void nk_maxsim_pack_f32_icelake( //
|
|
54
|
-
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
54
|
+
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
55
55
|
|
|
56
56
|
nk_size_t const element_bytes = sizeof(nk_f32_t);
|
|
57
57
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
|
|
@@ -63,7 +63,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_icelake( //
|
|
|
63
63
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
64
64
|
|
|
65
65
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
66
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
66
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
67
67
|
nk_f32_t norm_sq;
|
|
68
68
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
|
|
69
69
|
&quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
|
|
@@ -75,16 +75,16 @@ NK_PUBLIC void nk_maxsim_pack_f32_icelake( //
|
|
|
75
75
|
}
|
|
76
76
|
}
|
|
77
77
|
|
|
78
|
-
#pragma endregion
|
|
78
|
+
#pragma endregion F32 Floats
|
|
79
79
|
|
|
80
|
-
#pragma region
|
|
80
|
+
#pragma region F16 Floats
|
|
81
81
|
|
|
82
82
|
NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_icelake(nk_size_t vector_count, nk_size_t depth) {
|
|
83
83
|
return nk_maxsim_packed_size_(vector_count, depth, sizeof(nk_f16_t), 64);
|
|
84
84
|
}
|
|
85
85
|
|
|
86
86
|
NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
|
|
87
|
-
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
87
|
+
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
88
88
|
|
|
89
89
|
nk_size_t const element_bytes = sizeof(nk_f16_t);
|
|
90
90
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 64, element_bytes);
|
|
@@ -96,7 +96,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
|
|
|
96
96
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
97
97
|
|
|
98
98
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
99
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
99
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
100
100
|
nk_f32_t norm_sq;
|
|
101
101
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
102
102
|
(nk_maxsim_to_f32_t)nk_f16_to_f32_haswell,
|
|
@@ -109,7 +109,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
|
|
|
109
109
|
}
|
|
110
110
|
}
|
|
111
111
|
|
|
112
|
-
#pragma endregion
|
|
112
|
+
#pragma endregion F16 Floats
|
|
113
113
|
|
|
114
114
|
#pragma region Coarse Argmax
|
|
115
115
|
|
|
@@ -117,7 +117,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_icelake( //
|
|
|
117
117
|
NK_INTERNAL __m128i nk_maxsim_reduce_i32x16x4_icelake_( //
|
|
118
118
|
__m512i accumulator_a_i32x16, __m512i accumulator_b_i32x16, //
|
|
119
119
|
__m512i accumulator_c_i32x16, __m512i accumulator_d_i32x16) {
|
|
120
|
-
// Step 1: 16
|
|
120
|
+
// Step 1: 16 → 8 (extract high 256-bit half and add to low half)
|
|
121
121
|
__m256i sum_a_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_a_i32x16),
|
|
122
122
|
_mm512_extracti32x8_epi32(accumulator_a_i32x16, 1));
|
|
123
123
|
__m256i sum_b_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_b_i32x16),
|
|
@@ -126,12 +126,12 @@ NK_INTERNAL __m128i nk_maxsim_reduce_i32x16x4_icelake_( //
|
|
|
126
126
|
_mm512_extracti32x8_epi32(accumulator_c_i32x16, 1));
|
|
127
127
|
__m256i sum_d_i32x8 = _mm256_add_epi32(_mm512_castsi512_si256(accumulator_d_i32x16),
|
|
128
128
|
_mm512_extracti32x8_epi32(accumulator_d_i32x16, 1));
|
|
129
|
-
// Step 2: 8
|
|
129
|
+
// Step 2: 8 → 4 (extract high 128-bit half and add to low half)
|
|
130
130
|
__m128i sum_a_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_a_i32x8), _mm256_extracti128_si256(sum_a_i32x8, 1));
|
|
131
131
|
__m128i sum_b_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_b_i32x8), _mm256_extracti128_si256(sum_b_i32x8, 1));
|
|
132
132
|
__m128i sum_c_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_c_i32x8), _mm256_extracti128_si256(sum_c_i32x8, 1));
|
|
133
133
|
__m128i sum_d_i32x4 = _mm_add_epi32(_mm256_castsi256_si128(sum_d_i32x8), _mm256_extracti128_si256(sum_d_i32x8, 1));
|
|
134
|
-
// Step 3: 4x4 transpose + reduce
|
|
134
|
+
// Step 3: 4x4 transpose + reduce → [sum_a, sum_b, sum_c, sum_d]
|
|
135
135
|
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
136
136
|
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(sum_c_i32x4, sum_d_i32x4);
|
|
137
137
|
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(sum_a_i32x4, sum_b_i32x4);
|
|
@@ -258,7 +258,7 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_icelake_( //
|
|
|
258
258
|
query_2_coarse_dots_i32x4 = _mm_sub_epi32(query_2_coarse_dots_i32x4, bias_correction_i32x4);
|
|
259
259
|
query_3_coarse_dots_i32x4 = _mm_sub_epi32(query_3_coarse_dots_i32x4, bias_correction_i32x4);
|
|
260
260
|
|
|
261
|
-
// 4x4 transpose: [query][doc]
|
|
261
|
+
// 4x4 transpose: [query][doc] → [doc][query] for vectorized argmax
|
|
262
262
|
__m128i transpose_queries_01_low_i32x4 = _mm_unpacklo_epi32(query_0_coarse_dots_i32x4,
|
|
263
263
|
query_1_coarse_dots_i32x4);
|
|
264
264
|
__m128i transpose_queries_23_low_i32x4 = _mm_unpacklo_epi32(query_2_coarse_dots_i32x4,
|
|
@@ -390,7 +390,7 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_icelake_( //
|
|
|
390
390
|
}
|
|
391
391
|
}
|
|
392
392
|
|
|
393
|
-
#pragma endregion
|
|
393
|
+
#pragma endregion Coarse Argmax
|
|
394
394
|
|
|
395
395
|
#pragma region Compute Functions
|
|
396
396
|
|
|
@@ -463,7 +463,7 @@ NK_PUBLIC void nk_maxsim_packed_f16_icelake( //
|
|
|
463
463
|
*result = (nk_f32_t)total_angular_distance;
|
|
464
464
|
}
|
|
465
465
|
|
|
466
|
-
#pragma endregion
|
|
466
|
+
#pragma endregion Compute Functions
|
|
467
467
|
|
|
468
468
|
#if defined(__clang__)
|
|
469
469
|
#pragma clang attribute pop
|
|
@@ -46,7 +46,7 @@ NK_PUBLIC nk_size_t nk_maxsim_packed_size_f16_neonsdot(nk_size_t vector_count, n
|
|
|
46
46
|
}
|
|
47
47
|
|
|
48
48
|
NK_PUBLIC void nk_maxsim_pack_bf16_neonsdot( //
|
|
49
|
-
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
49
|
+
nk_bf16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
50
50
|
|
|
51
51
|
nk_size_t const element_bytes = sizeof(nk_bf16_t);
|
|
52
52
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 16, element_bytes);
|
|
@@ -58,7 +58,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_neonsdot( //
|
|
|
58
58
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
59
59
|
|
|
60
60
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
61
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
61
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
62
62
|
nk_f32_t norm_sq;
|
|
63
63
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
64
64
|
(nk_maxsim_to_f32_t)nk_bf16_to_f32_serial,
|
|
@@ -72,7 +72,7 @@ NK_PUBLIC void nk_maxsim_pack_bf16_neonsdot( //
|
|
|
72
72
|
}
|
|
73
73
|
|
|
74
74
|
NK_PUBLIC void nk_maxsim_pack_f32_neonsdot( //
|
|
75
|
-
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
75
|
+
nk_f32_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
76
76
|
|
|
77
77
|
nk_size_t const element_bytes = sizeof(nk_f32_t);
|
|
78
78
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 16, element_bytes);
|
|
@@ -84,7 +84,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_neonsdot( //
|
|
|
84
84
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
85
85
|
|
|
86
86
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
87
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
87
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
88
88
|
nk_f32_t norm_sq;
|
|
89
89
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f, nk_f32_to_f32_,
|
|
90
90
|
&quantized_i8[vector_index * depth_i8_padded], &metadata[vector_index], &norm_sq);
|
|
@@ -97,7 +97,7 @@ NK_PUBLIC void nk_maxsim_pack_f32_neonsdot( //
|
|
|
97
97
|
}
|
|
98
98
|
|
|
99
99
|
NK_PUBLIC void nk_maxsim_pack_f16_neonsdot( //
|
|
100
|
-
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t
|
|
100
|
+
nk_f16_t const *vectors, nk_size_t vector_count, nk_size_t depth, nk_size_t stride_in_bytes, void *packed) {
|
|
101
101
|
|
|
102
102
|
nk_size_t const element_bytes = sizeof(nk_f16_t);
|
|
103
103
|
nk_size_t depth_i8_padded = nk_maxsim_packed_header_setup_(packed, vector_count, depth, 16, element_bytes);
|
|
@@ -109,7 +109,7 @@ NK_PUBLIC void nk_maxsim_pack_f16_neonsdot( //
|
|
|
109
109
|
nk_size_t const original_stride = header->original_stride_bytes;
|
|
110
110
|
|
|
111
111
|
for (nk_size_t vector_index = 0; vector_index < vector_count; vector_index++) {
|
|
112
|
-
char const *source_row = (char const *)vectors + vector_index *
|
|
112
|
+
char const *source_row = (char const *)vectors + vector_index * stride_in_bytes;
|
|
113
113
|
nk_f32_t norm_sq;
|
|
114
114
|
nk_maxsim_quantize_vector_(source_row, element_bytes, depth, depth_i8_padded, 127.0f,
|
|
115
115
|
(nk_maxsim_to_f32_t)nk_f16_to_f32_neon,
|
|
@@ -149,39 +149,39 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_neonsdot_(
|
|
|
149
149
|
// Depth loop: 16 bytes per step
|
|
150
150
|
for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 16) {
|
|
151
151
|
int8x16_t query_i8x16_0 = vld1q_s8(
|
|
152
|
-
(
|
|
152
|
+
(nk_i8_t const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index));
|
|
153
153
|
int8x16_t query_i8x16_1 = vld1q_s8(
|
|
154
|
-
(
|
|
154
|
+
(nk_i8_t const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index));
|
|
155
155
|
int8x16_t query_i8x16_2 = vld1q_s8(
|
|
156
|
-
(
|
|
156
|
+
(nk_i8_t const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index));
|
|
157
157
|
int8x16_t query_i8x16_3 = vld1q_s8(
|
|
158
|
-
(
|
|
158
|
+
(nk_i8_t const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index));
|
|
159
159
|
|
|
160
160
|
int8x16_t document_i8x16;
|
|
161
161
|
|
|
162
162
|
document_i8x16 = vld1q_s8(
|
|
163
|
-
(
|
|
163
|
+
(nk_i8_t const *)(document_i8 + (document_block_start_index + 0) * depth_i8_padded + depth_index));
|
|
164
164
|
accumulator_tiles_i32x4[0][0] = vdotq_s32(accumulator_tiles_i32x4[0][0], query_i8x16_0, document_i8x16);
|
|
165
165
|
accumulator_tiles_i32x4[1][0] = vdotq_s32(accumulator_tiles_i32x4[1][0], query_i8x16_1, document_i8x16);
|
|
166
166
|
accumulator_tiles_i32x4[2][0] = vdotq_s32(accumulator_tiles_i32x4[2][0], query_i8x16_2, document_i8x16);
|
|
167
167
|
accumulator_tiles_i32x4[3][0] = vdotq_s32(accumulator_tiles_i32x4[3][0], query_i8x16_3, document_i8x16);
|
|
168
168
|
|
|
169
169
|
document_i8x16 = vld1q_s8(
|
|
170
|
-
(
|
|
170
|
+
(nk_i8_t const *)(document_i8 + (document_block_start_index + 1) * depth_i8_padded + depth_index));
|
|
171
171
|
accumulator_tiles_i32x4[0][1] = vdotq_s32(accumulator_tiles_i32x4[0][1], query_i8x16_0, document_i8x16);
|
|
172
172
|
accumulator_tiles_i32x4[1][1] = vdotq_s32(accumulator_tiles_i32x4[1][1], query_i8x16_1, document_i8x16);
|
|
173
173
|
accumulator_tiles_i32x4[2][1] = vdotq_s32(accumulator_tiles_i32x4[2][1], query_i8x16_2, document_i8x16);
|
|
174
174
|
accumulator_tiles_i32x4[3][1] = vdotq_s32(accumulator_tiles_i32x4[3][1], query_i8x16_3, document_i8x16);
|
|
175
175
|
|
|
176
176
|
document_i8x16 = vld1q_s8(
|
|
177
|
-
(
|
|
177
|
+
(nk_i8_t const *)(document_i8 + (document_block_start_index + 2) * depth_i8_padded + depth_index));
|
|
178
178
|
accumulator_tiles_i32x4[0][2] = vdotq_s32(accumulator_tiles_i32x4[0][2], query_i8x16_0, document_i8x16);
|
|
179
179
|
accumulator_tiles_i32x4[1][2] = vdotq_s32(accumulator_tiles_i32x4[1][2], query_i8x16_1, document_i8x16);
|
|
180
180
|
accumulator_tiles_i32x4[2][2] = vdotq_s32(accumulator_tiles_i32x4[2][2], query_i8x16_2, document_i8x16);
|
|
181
181
|
accumulator_tiles_i32x4[3][2] = vdotq_s32(accumulator_tiles_i32x4[3][2], query_i8x16_3, document_i8x16);
|
|
182
182
|
|
|
183
183
|
document_i8x16 = vld1q_s8(
|
|
184
|
-
(
|
|
184
|
+
(nk_i8_t const *)(document_i8 + (document_block_start_index + 3) * depth_i8_padded + depth_index));
|
|
185
185
|
accumulator_tiles_i32x4[0][3] = vdotq_s32(accumulator_tiles_i32x4[0][3], query_i8x16_0, document_i8x16);
|
|
186
186
|
accumulator_tiles_i32x4[1][3] = vdotq_s32(accumulator_tiles_i32x4[1][3], query_i8x16_1, document_i8x16);
|
|
187
187
|
accumulator_tiles_i32x4[2][3] = vdotq_s32(accumulator_tiles_i32x4[2][3], query_i8x16_2, document_i8x16);
|
|
@@ -211,27 +211,27 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_neonsdot_(
|
|
|
211
211
|
int32x4_t accumulator_i32x4_3 = vdupq_n_s32(0);
|
|
212
212
|
|
|
213
213
|
for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 16) {
|
|
214
|
-
int8x16_t document_i8x16 = vld1q_s8((
|
|
214
|
+
int8x16_t document_i8x16 = vld1q_s8((nk_i8_t const *)(document_i8_row + depth_index));
|
|
215
215
|
|
|
216
216
|
accumulator_i32x4_0 = vdotq_s32(
|
|
217
217
|
accumulator_i32x4_0,
|
|
218
218
|
vld1q_s8(
|
|
219
|
-
(
|
|
219
|
+
(nk_i8_t const *)(query_i8 + (query_block_start_index + 0) * depth_i8_padded + depth_index)),
|
|
220
220
|
document_i8x16);
|
|
221
221
|
accumulator_i32x4_1 = vdotq_s32(
|
|
222
222
|
accumulator_i32x4_1,
|
|
223
223
|
vld1q_s8(
|
|
224
|
-
(
|
|
224
|
+
(nk_i8_t const *)(query_i8 + (query_block_start_index + 1) * depth_i8_padded + depth_index)),
|
|
225
225
|
document_i8x16);
|
|
226
226
|
accumulator_i32x4_2 = vdotq_s32(
|
|
227
227
|
accumulator_i32x4_2,
|
|
228
228
|
vld1q_s8(
|
|
229
|
-
(
|
|
229
|
+
(nk_i8_t const *)(query_i8 + (query_block_start_index + 2) * depth_i8_padded + depth_index)),
|
|
230
230
|
document_i8x16);
|
|
231
231
|
accumulator_i32x4_3 = vdotq_s32(
|
|
232
232
|
accumulator_i32x4_3,
|
|
233
233
|
vld1q_s8(
|
|
234
|
-
(
|
|
234
|
+
(nk_i8_t const *)(query_i8 + (query_block_start_index + 3) * depth_i8_padded + depth_index)),
|
|
235
235
|
document_i8x16);
|
|
236
236
|
}
|
|
237
237
|
|
|
@@ -260,8 +260,8 @@ NK_INTERNAL void nk_maxsim_coarse_argmax_neonsdot_(
|
|
|
260
260
|
int32x4_t accumulator_i32x4 = vdupq_n_s32(0);
|
|
261
261
|
|
|
262
262
|
for (nk_size_t depth_index = 0; depth_index < depth_i8_padded; depth_index += 16) {
|
|
263
|
-
int8x16_t query_i8x16 = vld1q_s8((
|
|
264
|
-
int8x16_t document_i8x16 = vld1q_s8((
|
|
263
|
+
int8x16_t query_i8x16 = vld1q_s8((nk_i8_t const *)(query_i8_row + depth_index));
|
|
264
|
+
int8x16_t document_i8x16 = vld1q_s8((nk_i8_t const *)(document_i8_row + depth_index));
|
|
265
265
|
accumulator_i32x4 = vdotq_s32(accumulator_i32x4, query_i8x16, document_i8x16);
|
|
266
266
|
}
|
|
267
267
|
|