numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1425 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures.
|
|
3
|
+
* @file include/numkong/spatial.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 14, 2023
|
|
6
|
+
*
|
|
7
|
+
* Contains the following similarity measures:
|
|
8
|
+
*
|
|
9
|
+
* - L2 (Euclidean) regular and squared distance
|
|
10
|
+
* - Cosine (Angular) distance - @b not similarity!
|
|
11
|
+
*
|
|
12
|
+
* For dtypes:
|
|
13
|
+
*
|
|
14
|
+
* - f64: 64-bit IEEE floating point numbers → 64-bit floats
|
|
15
|
+
* - f32: 32-bit IEEE floating point numbers → 64-bit floats
|
|
16
|
+
* - f16: 16-bit IEEE floating point numbers → 32-bit floats
|
|
17
|
+
* - bf16: 16-bit brain floating point numbers → 32-bit floats
|
|
18
|
+
* - e4m3: 8-bit e4m3 floating point numbers → 32-bit floats
|
|
19
|
+
* - e5m2: 8-bit e5m2 floating point numbers → 32-bit floats
|
|
20
|
+
* - e2m3: 8-bit e2m3 floating point numbers (MX) → 32-bit floats
|
|
21
|
+
* - e3m2: 8-bit e3m2 floating point numbers (MX) → 32-bit floats
|
|
22
|
+
* - i8: 8-bit signed integers → 32-bit floats
|
|
23
|
+
* - u8: 8-bit unsigned integers → 32-bit floats
|
|
24
|
+
* - i4: 4-bit signed integers (packed pairs) → 32-bit floats
|
|
25
|
+
* - u4: 4-bit unsigned integers (packed pairs) → 32-bit floats
|
|
26
|
+
*
|
|
27
|
+
* For hardware architectures:
|
|
28
|
+
*
|
|
29
|
+
* - Arm: NEON, NEON+F16, NEON+BF16, NEON+SDOT, SVE, SVE+F16, SVE+BF16
|
|
30
|
+
* - x86: Haswell, Skylake, Ice Lake, Genoa, Sapphire Rapids, Sierra Forest
|
|
31
|
+
* - RISC-V: RVV, RVV+BF16, RVV+HALF
|
|
32
|
+
* - WASM: V128Relaxed
|
|
33
|
+
*
|
|
34
|
+
* @section numerical_stability Numerical Stability
|
|
35
|
+
*
|
|
36
|
+
* Serial kernels use compensated summation for dot, a_norm_sq, b_norm_sq — O(1) error growth regardless of vector
|
|
37
|
+
* dimension. `f32` public outputs widen to `f64`, so widened paths use `f64` arithmetic and `sqrt64`.
|
|
38
|
+
* Angular finalization uses rsqrt via magic constant + 3 Newton-Raphson iterations (f32,
|
|
39
|
+
* ~34.9 correct bits) or 4 iterations (f64, ~69.3 correct bits), then clamps result ≥ 0.
|
|
40
|
+
* L2 uses conditional `dist_sq > 0 ? sqrt(dist_sq) : 0` to avoid NaN from rounding.
|
|
41
|
+
* Integer types (i8/u8/i4/u4) accumulate squared differences in i32 — overflows at
|
|
42
|
+
* n > 2^31/65,025 ≈ 33K for i8 (max diff² = 255²). Euclidean and angular outputs are cast to f32; squared Euclidean outputs are returned as u32.
|
|
43
|
+
*
|
|
44
|
+
* @section streaming_api Streaming API
|
|
45
|
+
*
|
|
46
|
+
* Angular and L2 distances can be computed from a single dot-product stream and precomputed magnitudes.
|
|
47
|
+
* The streaming helpers operate on 512-bit blocks (`nk_b512_vec_t`) and only accumulate $A*B$.
|
|
48
|
+
* Finalization takes the magnitudes of the full vectors (L2 norms) and computes the distance.
|
|
49
|
+
* Let the following be computed over the full vectors:
|
|
50
|
+
*
|
|
51
|
+
* ab = Σᵢ (aᵢ × bᵢ)
|
|
52
|
+
* ‖a‖ = √(Σᵢ aᵢ²)
|
|
53
|
+
* ‖b‖ = √(Σᵢ bᵢ²)
|
|
54
|
+
*
|
|
55
|
+
* Finalization formulas:
|
|
56
|
+
*
|
|
57
|
+
* angular(a, b) = 1 − ab / (‖a‖ × ‖b‖)
|
|
58
|
+
* l2(a, b) = √( ‖a‖² + ‖b‖² − 2 × ab )
|
|
59
|
+
*
|
|
60
|
+
* The angular distance is clamped to ≥ 0, with a 0 result when both norms are zero and a 1 result when $ab$ is zero.
|
|
61
|
+
* L2 clamps the argument of the square root at 0 to avoid negative values from rounding.
|
|
62
|
+
*
|
|
63
|
+
* @code{.c}
|
|
64
|
+
* nk_b512_vec_t a_block, b_block;
|
|
65
|
+
* nk_f32_t a_norm = ..., b_norm = ...; // Precomputed L2 norms of full vectors
|
|
66
|
+
* nk_angular_f32x8_state_haswell_t state; // Often equivalent to dot-product state
|
|
67
|
+
* nk_angular_f32x8_init_haswell(&state);
|
|
68
|
+
* nk_angular_f32x8_update_haswell(&state, a_block, b_block);
|
|
69
|
+
* nk_angular_f32x8_finalize_haswell(&state, a_norm, b_norm, &distance);
|
|
70
|
+
* @endcode
|
|
71
|
+
*
|
|
72
|
+
* @section rsqrt_notes Reciprocal Square Root and Newton-Raphson Notes
|
|
73
|
+
*
|
|
74
|
+
* Angular distance normalization uses reciprocal square roots to avoid the
|
|
75
|
+
* latency of full sqrt/div pipelines. We refine the rsqrt estimate with one
|
|
76
|
+
* (x86) or two (Arm NEON) Newton-Raphson iterations to reduce error.
|
|
77
|
+
*
|
|
78
|
+
* Relevant instructions and caveats:
|
|
79
|
+
*
|
|
80
|
+
* Intrinsic Instruction Notes
|
|
81
|
+
* _mm_rsqrt_ps VRSQRTPS fast approx; refine with NR
|
|
82
|
+
* _mm_maskz_rsqrt14_pd VRSQRT14PD higher-precision approx; MSVC masked-only
|
|
83
|
+
* _mm_sqrt_ps/_mm_sqrt_pd VSQRTPS/VSQRTPD higher latency, sqrt/div unit
|
|
84
|
+
*
|
|
85
|
+
* Latency/port notes (rule of thumb):
|
|
86
|
+
* - On Intel client cores, sqrt/rsqrt execute on the divide/sqrt unit (often
|
|
87
|
+
* port 0) and can bottleneck tight loops.
|
|
88
|
+
* - NR refinement uses mul/FMA ports and amortizes well when `ab` is reduced
|
|
89
|
+
* to a scalar and reused for finalization.
|
|
90
|
+
* - Arm NEON `rsqrt` accuracy is coarse; we apply two refinement steps to keep
|
|
91
|
+
* angular distance error bounded.
|
|
92
|
+
*
|
|
93
|
+
* @section x86_instructions Relevant x86 Instructions
|
|
94
|
+
*
|
|
95
|
+
* AVX2 lacks signed 8-bit dot products, so Haswell widens to i16 and uses VPMADDWD.
|
|
96
|
+
* AVX-512 VNNI replaces that with VPDPWSSD. BF16 uses VDPBF16PS where available to avoid
|
|
97
|
+
* convert+FMA sequences; if the ISA lacks it, we fall back to f32 FMA in the AVX2/serial backends:
|
|
98
|
+
*
|
|
99
|
+
* Intrinsic Instruction Ice Genoa
|
|
100
|
+
* _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4c @ p01 4c @ p01
|
|
101
|
+
* _mm256_fmadd_pd VFMADD231PD (YMM, YMM, YMM) 4c @ p01 4c @ p01
|
|
102
|
+
* _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 5c @ p01 3c @ p01
|
|
103
|
+
* _mm512_dpwssd_epi32 VPDPWSSD (ZMM, K, ZMM, ZMM) 5c @ p05 4c @ p01
|
|
104
|
+
* _mm512_dpbf16_ps VDPBF16PS (ZMM, K, ZMM, ZMM) n/a 6c @ p01
|
|
105
|
+
* _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5c @ p0 4c @ p01
|
|
106
|
+
* _mm_maskz_rsqrt14_pd VRSQRT14PD (XMM, K, XMM) 4c @ p0 5c @ p01
|
|
107
|
+
* _mm_sqrt_ps VSQRTPS (XMM, XMM) 12c @ p0 15c @ p01
|
|
108
|
+
*
|
|
109
|
+
* @section arm_instructions Relevant Arm Instructions
|
|
110
|
+
*
|
|
111
|
+
* The NEON/SVE kernels in this header are structured around FMLA/SDOT/BFDOT loops,
|
|
112
|
+
* which is why we avoid mul+add splits and keep reductions to scalars before square roots.
|
|
113
|
+
* Dot-product kernels for i8/u8 are only built when the "dotprod+i8mm" target is enabled;
|
|
114
|
+
* otherwise we rely on the serial backends. BF16 kernels are enabled only with BF16 dot
|
|
115
|
+
* instructions, skipping `vbfmlal` and `vbfmlalt` alternatives to limit shuffle overhead
|
|
116
|
+
* and code complexity.
|
|
117
|
+
*
|
|
118
|
+
* Intrinsic Instruction M1 Firestorm
|
|
119
|
+
* vfmaq_f32 FMLA.S (vec) 4c / 4c
|
|
120
|
+
* vfmaq_f64 FMLA.D (vec) 4c / 4c
|
|
121
|
+
* vdotq_s32 SDOT.B (vec) 3c / 4c
|
|
122
|
+
* vbfdotq_f32 BFDOT (vec) n/a
|
|
123
|
+
* vrsqrteq_f32 FRSQRTE.S (vec) 3c / 1c
|
|
124
|
+
* vrsqrtsq_f32 FRSQRTS.S (vec) 4c / 4c
|
|
125
|
+
* vsqrtq_f32 FSQRT.S (vec) 10c / 0.5c
|
|
126
|
+
*
|
|
127
|
+
* @section references References
|
|
128
|
+
*
|
|
129
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
130
|
+
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
131
|
+
*
|
|
132
|
+
*/
|
|
133
|
+
#ifndef NK_SPATIAL_H
|
|
134
|
+
#define NK_SPATIAL_H
|
|
135
|
+
|
|
136
|
+
#include "numkong/types.h"
|
|
137
|
+
|
|
138
|
+
#if defined(__cplusplus)
|
|
139
|
+
extern "C" {
|
|
140
|
+
#endif
|
|
141
|
+
|
|
142
|
+
/**
|
|
143
|
+
* @brief L2 (Euclidean) distance between two vectors.
|
|
144
|
+
*
|
|
145
|
+
* @param[in] a The first vector.
|
|
146
|
+
* @param[in] b The second vector.
|
|
147
|
+
* @param[in] n The number of elements in each vector.
|
|
148
|
+
* @param[out] result The output distance value.
|
|
149
|
+
*
|
|
150
|
+
* @note The output distance value is non-negative.
|
|
151
|
+
* @note The output distance value is zero if and only if the two vectors are identical.
|
|
152
|
+
*/
|
|
153
|
+
NK_DYNAMIC void nk_euclidean_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
154
|
+
/** @copydoc nk_euclidean_f64 */
|
|
155
|
+
NK_DYNAMIC void nk_euclidean_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
156
|
+
/** @copydoc nk_euclidean_f64 */
|
|
157
|
+
NK_DYNAMIC void nk_euclidean_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
158
|
+
/** @copydoc nk_euclidean_f64 */
|
|
159
|
+
NK_DYNAMIC void nk_euclidean_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
160
|
+
/** @copydoc nk_euclidean_f64 */
|
|
161
|
+
NK_DYNAMIC void nk_euclidean_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
162
|
+
/** @copydoc nk_euclidean_f64 */
|
|
163
|
+
NK_DYNAMIC void nk_euclidean_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
164
|
+
/** @copydoc nk_euclidean_f64 */
|
|
165
|
+
NK_DYNAMIC void nk_euclidean_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
166
|
+
/** @copydoc nk_euclidean_f64 */
|
|
167
|
+
NK_DYNAMIC void nk_euclidean_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
168
|
+
/** @copydoc nk_euclidean_f64 */
|
|
169
|
+
NK_DYNAMIC void nk_euclidean_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
170
|
+
/** @copydoc nk_euclidean_f64 */
|
|
171
|
+
NK_DYNAMIC void nk_euclidean_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
172
|
+
/** @copydoc nk_euclidean_f64 */
|
|
173
|
+
NK_DYNAMIC void nk_euclidean_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
174
|
+
/** @copydoc nk_euclidean_f64 */
|
|
175
|
+
NK_DYNAMIC void nk_euclidean_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
176
|
+
|
|
177
|
+
/**
|
|
178
|
+
* @brief Squared L2 (Euclidean) distance between two vectors.
|
|
179
|
+
*
|
|
180
|
+
* @param[in] a The first vector.
|
|
181
|
+
* @param[in] b The second vector.
|
|
182
|
+
* @param[in] n The number of elements in each vector.
|
|
183
|
+
* @param[out] result The output distance value.
|
|
184
|
+
*
|
|
185
|
+
* @note The output distance value is non-negative.
|
|
186
|
+
* @note The output distance value is zero if and only if the two vectors are identical.
|
|
187
|
+
*/
|
|
188
|
+
NK_DYNAMIC void nk_sqeuclidean_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
189
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
190
|
+
NK_DYNAMIC void nk_sqeuclidean_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
191
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
192
|
+
NK_DYNAMIC void nk_sqeuclidean_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
193
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
194
|
+
NK_DYNAMIC void nk_sqeuclidean_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
195
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
196
|
+
NK_DYNAMIC void nk_sqeuclidean_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
197
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
198
|
+
NK_DYNAMIC void nk_sqeuclidean_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
199
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
200
|
+
NK_DYNAMIC void nk_sqeuclidean_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
201
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
202
|
+
NK_DYNAMIC void nk_sqeuclidean_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
203
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
204
|
+
NK_DYNAMIC void nk_sqeuclidean_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
205
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
206
|
+
NK_DYNAMIC void nk_sqeuclidean_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
207
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
208
|
+
NK_DYNAMIC void nk_sqeuclidean_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
209
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
210
|
+
NK_DYNAMIC void nk_sqeuclidean_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
211
|
+
|
|
212
|
+
/**
|
|
213
|
+
* @brief Angular (cosine) distance between two vectors.
|
|
214
|
+
*
|
|
215
|
+
* @param[in] a The first vector.
|
|
216
|
+
* @param[in] b The second vector.
|
|
217
|
+
* @param[in] n The number of elements in each vector.
|
|
218
|
+
* @param[out] result The output distance value.
|
|
219
|
+
*
|
|
220
|
+
* @note The output distance value is non-negative.
|
|
221
|
+
* @note The output distance value is zero when the vectors point in the same direction, not only when they are identical.
|
|
222
|
+
*/
|
|
223
|
+
NK_DYNAMIC void nk_angular_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
224
|
+
/** @copydoc nk_angular_f64 */
|
|
225
|
+
NK_DYNAMIC void nk_angular_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
226
|
+
/** @copydoc nk_angular_f64 */
|
|
227
|
+
NK_DYNAMIC void nk_angular_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
228
|
+
/** @copydoc nk_angular_f64 */
|
|
229
|
+
NK_DYNAMIC void nk_angular_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
230
|
+
/** @copydoc nk_angular_f64 */
|
|
231
|
+
NK_DYNAMIC void nk_angular_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
232
|
+
/** @copydoc nk_angular_f64 */
|
|
233
|
+
NK_DYNAMIC void nk_angular_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
234
|
+
/** @copydoc nk_angular_f64 */
|
|
235
|
+
NK_DYNAMIC void nk_angular_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
236
|
+
/** @copydoc nk_angular_f64 */
|
|
237
|
+
NK_DYNAMIC void nk_angular_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
238
|
+
/** @copydoc nk_angular_f64 */
|
|
239
|
+
NK_DYNAMIC void nk_angular_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
240
|
+
/** @copydoc nk_angular_f64 */
|
|
241
|
+
NK_DYNAMIC void nk_angular_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
242
|
+
/** @copydoc nk_angular_f64 */
|
|
243
|
+
NK_DYNAMIC void nk_angular_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
244
|
+
/** @copydoc nk_angular_f64 */
|
|
245
|
+
NK_DYNAMIC void nk_angular_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
246
|
+
|
|
247
|
+
/* Serial backends for all numeric types.
|
|
248
|
+
* By default they use 32-bit arithmetic; `f64` inputs and the widened `f32` paths (which output `f64`) use 64-bit arithmetic.
|
|
249
|
+
*/
|
|
250
|
+
/** @copydoc nk_euclidean_f64 */
|
|
251
|
+
NK_PUBLIC void nk_euclidean_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
252
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
253
|
+
NK_PUBLIC void nk_sqeuclidean_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
254
|
+
/** @copydoc nk_angular_f64 */
|
|
255
|
+
NK_PUBLIC void nk_angular_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
256
|
+
/** @copydoc nk_euclidean_f64 */
|
|
257
|
+
NK_PUBLIC void nk_euclidean_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
258
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
259
|
+
NK_PUBLIC void nk_sqeuclidean_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
260
|
+
/** @copydoc nk_angular_f64 */
|
|
261
|
+
NK_PUBLIC void nk_angular_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
262
|
+
/** @copydoc nk_euclidean_f64 */
|
|
263
|
+
NK_PUBLIC void nk_euclidean_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
264
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
265
|
+
NK_PUBLIC void nk_sqeuclidean_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
266
|
+
/** @copydoc nk_angular_f64 */
|
|
267
|
+
NK_PUBLIC void nk_angular_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
268
|
+
/** @copydoc nk_euclidean_f64 */
|
|
269
|
+
NK_PUBLIC void nk_euclidean_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
270
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
271
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
272
|
+
/** @copydoc nk_angular_f64 */
|
|
273
|
+
NK_PUBLIC void nk_angular_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
274
|
+
/** @copydoc nk_euclidean_f64 */
|
|
275
|
+
NK_PUBLIC void nk_euclidean_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
276
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
277
|
+
NK_PUBLIC void nk_sqeuclidean_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
278
|
+
/** @copydoc nk_angular_f64 */
|
|
279
|
+
NK_PUBLIC void nk_angular_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
280
|
+
/** @copydoc nk_euclidean_f64 */
|
|
281
|
+
NK_PUBLIC void nk_euclidean_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
282
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
283
|
+
NK_PUBLIC void nk_sqeuclidean_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
284
|
+
/** @copydoc nk_angular_f64 */
|
|
285
|
+
NK_PUBLIC void nk_angular_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
286
|
+
/** @copydoc nk_euclidean_f64 */
|
|
287
|
+
NK_PUBLIC void nk_euclidean_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
288
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
289
|
+
NK_PUBLIC void nk_sqeuclidean_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
290
|
+
/** @copydoc nk_angular_f64 */
|
|
291
|
+
NK_PUBLIC void nk_angular_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
292
|
+
/** @copydoc nk_euclidean_f64 */
|
|
293
|
+
NK_PUBLIC void nk_euclidean_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
294
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
295
|
+
NK_PUBLIC void nk_sqeuclidean_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
296
|
+
/** @copydoc nk_angular_f64 */
|
|
297
|
+
NK_PUBLIC void nk_angular_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
298
|
+
|
|
299
|
+
/** @copydoc nk_euclidean_f64 */
|
|
300
|
+
NK_PUBLIC void nk_euclidean_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
301
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
302
|
+
NK_PUBLIC void nk_sqeuclidean_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
303
|
+
/** @copydoc nk_angular_f64 */
|
|
304
|
+
NK_PUBLIC void nk_angular_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
305
|
+
/** @copydoc nk_euclidean_f64 */
|
|
306
|
+
NK_PUBLIC void nk_euclidean_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
307
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
308
|
+
NK_PUBLIC void nk_sqeuclidean_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
309
|
+
/** @copydoc nk_angular_f64 */
|
|
310
|
+
NK_PUBLIC void nk_angular_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
311
|
+
|
|
312
|
+
/* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words.
|
|
313
|
+
* By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all
|
|
314
|
+
* server CPUs produced before 2023.
|
|
315
|
+
*/
|
|
316
|
+
#if NK_TARGET_NEON
|
|
317
|
+
/** @copydoc nk_euclidean_f64 */
|
|
318
|
+
NK_PUBLIC void nk_euclidean_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
319
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
320
|
+
NK_PUBLIC void nk_sqeuclidean_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
321
|
+
/** @copydoc nk_angular_f64 */
|
|
322
|
+
NK_PUBLIC void nk_angular_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
323
|
+
/** @copydoc nk_euclidean_f64 */
|
|
324
|
+
NK_PUBLIC void nk_euclidean_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
325
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
326
|
+
NK_PUBLIC void nk_sqeuclidean_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
327
|
+
/** @copydoc nk_angular_f64 */
|
|
328
|
+
NK_PUBLIC void nk_angular_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
329
|
+
/** @copydoc nk_euclidean_f64 */
|
|
330
|
+
NK_PUBLIC void nk_euclidean_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
331
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
332
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
333
|
+
/** @copydoc nk_angular_f64 */
|
|
334
|
+
NK_PUBLIC void nk_angular_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
335
|
+
#endif // NK_TARGET_NEON
|
|
336
|
+
|
|
337
|
+
#if NK_TARGET_NEONHALF
|
|
338
|
+
/** @copydoc nk_euclidean_f64 */
|
|
339
|
+
NK_PUBLIC void nk_euclidean_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
340
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
341
|
+
NK_PUBLIC void nk_sqeuclidean_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
342
|
+
/** @copydoc nk_angular_f64 */
|
|
343
|
+
NK_PUBLIC void nk_angular_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
344
|
+
#endif // NK_TARGET_NEONHALF
|
|
345
|
+
|
|
346
|
+
#if NK_TARGET_NEONBFDOT
|
|
347
|
+
/** @copydoc nk_euclidean_f64 */
|
|
348
|
+
NK_PUBLIC void nk_euclidean_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
349
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
350
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
351
|
+
/** @copydoc nk_angular_f64 */
|
|
352
|
+
NK_PUBLIC void nk_angular_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
353
|
+
#endif // NK_TARGET_NEONBFDOT
|
|
354
|
+
|
|
355
|
+
#if NK_TARGET_NEONSDOT
|
|
356
|
+
/** @copydoc nk_euclidean_f64 */
|
|
357
|
+
NK_PUBLIC void nk_euclidean_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
358
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
359
|
+
NK_PUBLIC void nk_sqeuclidean_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
360
|
+
/** @copydoc nk_angular_f64 */
|
|
361
|
+
NK_PUBLIC void nk_angular_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
362
|
+
/** @copydoc nk_euclidean_f64 */
|
|
363
|
+
NK_PUBLIC void nk_euclidean_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
364
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
365
|
+
NK_PUBLIC void nk_sqeuclidean_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
366
|
+
/** @copydoc nk_angular_f64 */
|
|
367
|
+
NK_PUBLIC void nk_angular_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
368
|
+
#endif // NK_TARGET_NEONSDOT
|
|
369
|
+
|
|
370
|
+
/* SIMD-powered backends for Arm SVE, mostly using 32-bit arithmetic over variable-length platform-defined word sizes.
|
|
371
|
+
* Designed for Arm Graviton 3, Microsoft Cobalt, as well as Nvidia Grace and newer Ampere Altra CPUs.
|
|
372
|
+
*/
|
|
373
|
+
#if NK_TARGET_SVE
|
|
374
|
+
/** @copydoc nk_euclidean_f64 */
|
|
375
|
+
NK_PUBLIC void nk_euclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
376
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
377
|
+
NK_PUBLIC void nk_sqeuclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
378
|
+
/** @copydoc nk_angular_f64 */
|
|
379
|
+
NK_PUBLIC void nk_angular_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
380
|
+
/** @copydoc nk_euclidean_f64 */
|
|
381
|
+
NK_PUBLIC void nk_euclidean_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
382
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
383
|
+
NK_PUBLIC void nk_sqeuclidean_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
384
|
+
/** @copydoc nk_angular_f64 */
|
|
385
|
+
NK_PUBLIC void nk_angular_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
386
|
+
#endif // NK_TARGET_SVE
|
|
387
|
+
|
|
388
|
+
#if NK_TARGET_SVEHALF
|
|
389
|
+
/** @copydoc nk_euclidean_f64 */
|
|
390
|
+
NK_PUBLIC void nk_euclidean_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
391
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
392
|
+
NK_PUBLIC void nk_sqeuclidean_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
393
|
+
/** @copydoc nk_angular_f64 */
|
|
394
|
+
NK_PUBLIC void nk_angular_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
395
|
+
#endif // NK_TARGET_SVEHALF
|
|
396
|
+
|
|
397
|
+
#if NK_TARGET_SVEBFDOT
|
|
398
|
+
/** @copydoc nk_euclidean_f64 */
|
|
399
|
+
NK_PUBLIC void nk_euclidean_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
400
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
401
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
402
|
+
/** @copydoc nk_angular_f64 */
|
|
403
|
+
NK_PUBLIC void nk_angular_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
404
|
+
#endif // NK_TARGET_SVEBFDOT
|
|
405
|
+
|
|
406
|
+
/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words.
|
|
407
|
+
* First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420.
|
|
408
|
+
* Practically all modern x86 CPUs support AVX2, FMA, and F16C, making it a perfect baseline for SIMD algorithms.
|
|
409
|
+
* On the other hand, there is no need to implement AVX2 versions of `f32` and `f64` functions, as those are
|
|
410
|
+
* properly vectorized by recent compilers.
|
|
411
|
+
*/
|
|
412
|
+
#if NK_TARGET_HASWELL
|
|
413
|
+
/** @copydoc nk_euclidean_f64 */
|
|
414
|
+
NK_PUBLIC void nk_euclidean_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
415
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
416
|
+
NK_PUBLIC void nk_sqeuclidean_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
417
|
+
/** @copydoc nk_angular_f64 */
|
|
418
|
+
NK_PUBLIC void nk_angular_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
419
|
+
/** @copydoc nk_euclidean_f64 */
|
|
420
|
+
NK_PUBLIC void nk_euclidean_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
421
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
422
|
+
NK_PUBLIC void nk_sqeuclidean_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
423
|
+
/** @copydoc nk_angular_f64 */
|
|
424
|
+
NK_PUBLIC void nk_angular_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
425
|
+
/** @copydoc nk_euclidean_f64 */
|
|
426
|
+
NK_PUBLIC void nk_euclidean_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
427
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
428
|
+
NK_PUBLIC void nk_sqeuclidean_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
429
|
+
/** @copydoc nk_angular_f64 */
|
|
430
|
+
NK_PUBLIC void nk_angular_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
431
|
+
/** @copydoc nk_euclidean_f64 */
|
|
432
|
+
NK_PUBLIC void nk_euclidean_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
433
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
434
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
435
|
+
/** @copydoc nk_angular_f64 */
|
|
436
|
+
NK_PUBLIC void nk_angular_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
437
|
+
/** @copydoc nk_euclidean_f64 */
|
|
438
|
+
NK_PUBLIC void nk_euclidean_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
439
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
440
|
+
NK_PUBLIC void nk_sqeuclidean_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
441
|
+
/** @copydoc nk_angular_f64 */
|
|
442
|
+
NK_PUBLIC void nk_angular_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
443
|
+
/** @copydoc nk_euclidean_f64 */
|
|
444
|
+
NK_PUBLIC void nk_euclidean_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
445
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
446
|
+
NK_PUBLIC void nk_sqeuclidean_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
447
|
+
/** @copydoc nk_angular_f64 */
|
|
448
|
+
NK_PUBLIC void nk_angular_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
449
|
+
#endif // NK_TARGET_HASWELL
|
|
450
|
+
|
|
451
|
+
/* SIMD-powered backends for AVX512 CPUs of Skylake generation and newer, using 32-bit arithmetic over 512-bit words.
|
|
452
|
+
* Skylake was launched in 2015, and discontinued in 2019. Skylake had support for F, CD, VL, DQ, and BW extensions,
|
|
453
|
+
* as well as masked operations. This is enough to supersede auto-vectorization on `f32` and `f64` types.
|
|
454
|
+
*
|
|
455
|
+
* Sadly, we can't effectively interleave different kinds of arithmetic instructions to utilize more ports:
|
|
456
|
+
*
|
|
457
|
+
* > Like Intel server architectures since Skylake-X, SPR cores feature two 512-bit FMA units, and organize them in a
|
|
458
|
+
* similar fashion. One 512-bit FMA unit is created by fusing two 256-bit ones on port 0 and port 1. The other is
|
|
459
|
+
* added to port 5, as a server-specific core extension. The FMA units on port 0 and 1 are configured into
|
|
460
|
+
* 2×256-bit or 1×512-bit mode depending on whether 512-bit FMA instructions are present in the scheduler. That
|
|
461
|
+
* means a mix of 256-bit and 512-bit FMA instructions will not achieve higher IPC than executing 512-bit
|
|
462
|
+
* instructions alone.
|
|
463
|
+
*
|
|
464
|
+
* Source: https://chipsandcheese.com/p/a-peek-at-sapphire-rapids
|
|
465
|
+
*/
|
|
466
|
+
#if NK_TARGET_SKYLAKE
|
|
467
|
+
/** @copydoc nk_euclidean_f64 */
|
|
468
|
+
NK_PUBLIC void nk_euclidean_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
469
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
470
|
+
NK_PUBLIC void nk_sqeuclidean_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
471
|
+
/** @copydoc nk_angular_f64 */
|
|
472
|
+
NK_PUBLIC void nk_angular_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
473
|
+
/** @copydoc nk_euclidean_f64 */
|
|
474
|
+
NK_PUBLIC void nk_euclidean_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
475
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
476
|
+
NK_PUBLIC void nk_sqeuclidean_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
477
|
+
/** @copydoc nk_angular_f64 */
|
|
478
|
+
NK_PUBLIC void nk_angular_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
479
|
+
/** @copydoc nk_euclidean_f64 */
|
|
480
|
+
NK_PUBLIC void nk_euclidean_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
481
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
482
|
+
NK_PUBLIC void nk_sqeuclidean_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
483
|
+
/** @copydoc nk_angular_f64 */
|
|
484
|
+
NK_PUBLIC void nk_angular_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
485
|
+
/** @copydoc nk_euclidean_f64 */
|
|
486
|
+
NK_PUBLIC void nk_euclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
487
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
488
|
+
NK_PUBLIC void nk_sqeuclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
489
|
+
/** @copydoc nk_angular_f64 */
|
|
490
|
+
NK_PUBLIC void nk_angular_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
491
|
+
/** @copydoc nk_euclidean_f64 */
|
|
492
|
+
NK_PUBLIC void nk_euclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
493
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
494
|
+
NK_PUBLIC void nk_sqeuclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
495
|
+
/** @copydoc nk_angular_f64 */
|
|
496
|
+
NK_PUBLIC void nk_angular_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
497
|
+
#endif // NK_TARGET_SKYLAKE
|
|
498
|
+
|
|
499
|
+
/* SIMD-powered backends for AVX512 CPUs of Ice Lake generation and newer, using mixed arithmetic over 512-bit words.
|
|
500
|
+
* Ice Lake added VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ, and other extensions for integral
|
|
501
|
+
* operations. Sapphire Rapids added tiled matrix operations, but we are most interested in the new mixed-precision FMA
|
|
502
|
+
* instructions.
|
|
503
|
+
*/
|
|
504
|
+
#if NK_TARGET_ICELAKE
|
|
505
|
+
/** @copydoc nk_euclidean_f64 */
|
|
506
|
+
NK_PUBLIC void nk_euclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
507
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
508
|
+
NK_PUBLIC void nk_sqeuclidean_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
509
|
+
/** @copydoc nk_angular_f64 */
|
|
510
|
+
NK_PUBLIC void nk_angular_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
511
|
+
/** @copydoc nk_euclidean_f64 */
|
|
512
|
+
NK_PUBLIC void nk_euclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
513
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
514
|
+
NK_PUBLIC void nk_sqeuclidean_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
515
|
+
/** @copydoc nk_angular_f64 */
|
|
516
|
+
NK_PUBLIC void nk_angular_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
517
|
+
/** @copydoc nk_euclidean_f64 */
|
|
518
|
+
NK_PUBLIC void nk_euclidean_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
519
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
520
|
+
NK_PUBLIC void nk_sqeuclidean_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
521
|
+
/** @copydoc nk_angular_f64 */
|
|
522
|
+
NK_PUBLIC void nk_angular_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
523
|
+
/** @copydoc nk_euclidean_f64 */
|
|
524
|
+
NK_PUBLIC void nk_euclidean_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
525
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
526
|
+
NK_PUBLIC void nk_sqeuclidean_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
527
|
+
/** @copydoc nk_angular_f64 */
|
|
528
|
+
NK_PUBLIC void nk_angular_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
529
|
+
#endif // NK_TARGET_ICELAKE
|
|
530
|
+
|
|
531
|
+
#if NK_TARGET_GENOA
|
|
532
|
+
/** @copydoc nk_euclidean_f64 */
|
|
533
|
+
NK_PUBLIC void nk_euclidean_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
534
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
535
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
536
|
+
/** @copydoc nk_angular_f64 */
|
|
537
|
+
NK_PUBLIC void nk_angular_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
538
|
+
/** @copydoc nk_euclidean_f64 */
|
|
539
|
+
NK_PUBLIC void nk_euclidean_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
540
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
541
|
+
NK_PUBLIC void nk_sqeuclidean_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
542
|
+
/** @copydoc nk_angular_f64 */
|
|
543
|
+
NK_PUBLIC void nk_angular_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
544
|
+
/** @copydoc nk_euclidean_f64 */
|
|
545
|
+
NK_PUBLIC void nk_euclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
546
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
547
|
+
NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
548
|
+
/** @copydoc nk_angular_f64 */
|
|
549
|
+
NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
550
|
+
#endif // NK_TARGET_GENOA
|
|
551
|
+
|
|
552
|
+
#if NK_TARGET_SAPPHIRE
|
|
553
|
+
/** @copydoc nk_euclidean_f64 */
|
|
554
|
+
NK_PUBLIC void nk_euclidean_e4m3_sapphire(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
555
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
556
|
+
NK_PUBLIC void nk_sqeuclidean_e4m3_sapphire(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
557
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
558
|
+
NK_PUBLIC void nk_sqeuclidean_e2m3_sapphire(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
559
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
560
|
+
NK_PUBLIC void nk_sqeuclidean_e3m2_sapphire(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
561
|
+
/** @copydoc nk_euclidean_f64 */
|
|
562
|
+
NK_PUBLIC void nk_euclidean_e2m3_sapphire(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
563
|
+
/** @copydoc nk_euclidean_f64 */
|
|
564
|
+
NK_PUBLIC void nk_euclidean_e3m2_sapphire(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
565
|
+
/** @copydoc nk_angular_f64 */
|
|
566
|
+
NK_PUBLIC void nk_angular_e2m3_sapphire(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
567
|
+
/** @copydoc nk_angular_f64 */
|
|
568
|
+
NK_PUBLIC void nk_angular_e3m2_sapphire(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
569
|
+
#endif // NK_TARGET_SAPPHIRE
|
|
570
|
+
|
|
571
|
+
/* SIMD-powered backends for AVX-INT8-VNNI extensions on Xeon 6 CPUs, including Sierra Forest and Granite Rapids.
|
|
572
|
+
* They pack many "efficiency" cores into a single socket, avoiding heavy 512-bit operations, and focusing on
|
|
573
|
+
* 256-bit ones.
|
|
574
|
+
*/
|
|
575
|
+
#if NK_TARGET_SIERRA
|
|
576
|
+
/** @copydoc nk_angular_f64 */
|
|
577
|
+
NK_PUBLIC void nk_angular_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
578
|
+
/** @copydoc nk_euclidean_f64 */
|
|
579
|
+
NK_PUBLIC void nk_euclidean_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
580
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
581
|
+
NK_PUBLIC void nk_sqeuclidean_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
582
|
+
/** @copydoc nk_angular_f64 */
|
|
583
|
+
NK_PUBLIC void nk_angular_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
584
|
+
/** @copydoc nk_euclidean_f64 */
|
|
585
|
+
NK_PUBLIC void nk_euclidean_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
586
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
587
|
+
NK_PUBLIC void nk_sqeuclidean_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
588
|
+
/** @copydoc nk_angular_f64 */
|
|
589
|
+
NK_PUBLIC void nk_angular_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
590
|
+
/** @copydoc nk_euclidean_f64 */
|
|
591
|
+
NK_PUBLIC void nk_euclidean_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
592
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
593
|
+
NK_PUBLIC void nk_sqeuclidean_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
594
|
+
#endif // NK_TARGET_SIERRA
|
|
595
|
+
|
|
596
|
+
#if NK_TARGET_ALDER
|
|
597
|
+
/** @copydoc nk_angular_f64 */
|
|
598
|
+
NK_PUBLIC void nk_angular_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
599
|
+
/** @copydoc nk_euclidean_f64 */
|
|
600
|
+
NK_PUBLIC void nk_euclidean_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
601
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
602
|
+
NK_PUBLIC void nk_sqeuclidean_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
603
|
+
/** @copydoc nk_angular_f64 */
|
|
604
|
+
NK_PUBLIC void nk_angular_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
605
|
+
/** @copydoc nk_euclidean_f64 */
|
|
606
|
+
NK_PUBLIC void nk_euclidean_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
607
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
608
|
+
NK_PUBLIC void nk_sqeuclidean_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
609
|
+
/** @copydoc nk_angular_f64 */
|
|
610
|
+
NK_PUBLIC void nk_angular_e2m3_alder(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
611
|
+
/** @copydoc nk_euclidean_f64 */
|
|
612
|
+
NK_PUBLIC void nk_euclidean_e2m3_alder(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
613
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
614
|
+
NK_PUBLIC void nk_sqeuclidean_e2m3_alder(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
615
|
+
/** @copydoc nk_angular_f64 */
|
|
616
|
+
NK_PUBLIC void nk_angular_e3m2_alder(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
617
|
+
/** @copydoc nk_euclidean_f64 */
|
|
618
|
+
NK_PUBLIC void nk_euclidean_e3m2_alder(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
619
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
620
|
+
NK_PUBLIC void nk_sqeuclidean_e3m2_alder(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
621
|
+
#endif // NK_TARGET_ALDER
|
|
622
|
+
|
|
623
|
+
#if NK_TARGET_V128RELAXED
|
|
624
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
625
|
+
NK_PUBLIC void nk_sqeuclidean_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
626
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
627
|
+
NK_PUBLIC void nk_sqeuclidean_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
628
|
+
/** @copydoc nk_euclidean_f64 */
|
|
629
|
+
NK_PUBLIC void nk_euclidean_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
630
|
+
/** @copydoc nk_euclidean_f64 */
|
|
631
|
+
NK_PUBLIC void nk_euclidean_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
632
|
+
/** @copydoc nk_angular_f64 */
|
|
633
|
+
NK_PUBLIC void nk_angular_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
634
|
+
/** @copydoc nk_angular_f64 */
|
|
635
|
+
NK_PUBLIC void nk_angular_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
636
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
637
|
+
NK_PUBLIC void nk_sqeuclidean_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
638
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
639
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
640
|
+
/** @copydoc nk_euclidean_f64 */
|
|
641
|
+
NK_PUBLIC void nk_euclidean_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
642
|
+
/** @copydoc nk_euclidean_f64 */
|
|
643
|
+
NK_PUBLIC void nk_euclidean_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
644
|
+
/** @copydoc nk_angular_f64 */
|
|
645
|
+
NK_PUBLIC void nk_angular_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
646
|
+
/** @copydoc nk_angular_f64 */
|
|
647
|
+
NK_PUBLIC void nk_angular_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
648
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
649
|
+
NK_PUBLIC void nk_sqeuclidean_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
650
|
+
/** @copydoc nk_euclidean_f64 */
|
|
651
|
+
NK_PUBLIC void nk_euclidean_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
652
|
+
/** @copydoc nk_angular_f64 */
|
|
653
|
+
NK_PUBLIC void nk_angular_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
654
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
655
|
+
NK_PUBLIC void nk_sqeuclidean_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
656
|
+
/** @copydoc nk_euclidean_f64 */
|
|
657
|
+
NK_PUBLIC void nk_euclidean_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
658
|
+
/** @copydoc nk_angular_f64 */
|
|
659
|
+
NK_PUBLIC void nk_angular_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
660
|
+
#endif // NK_TARGET_V128RELAXED
|
|
661
|
+
|
|
662
|
+
/* SIMD-powered backends for RISC-V Vector extension, using scalable vector arithmetic.
|
|
663
|
+
* Designed for SiFive, T-Head, and other RISC-V processors with the V extension.
|
|
664
|
+
*/
|
|
665
|
+
#if NK_TARGET_RVV
|
|
666
|
+
/** @copydoc nk_euclidean_f64 */
|
|
667
|
+
NK_PUBLIC void nk_euclidean_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
668
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
669
|
+
NK_PUBLIC void nk_sqeuclidean_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
670
|
+
/** @copydoc nk_angular_f64 */
|
|
671
|
+
NK_PUBLIC void nk_angular_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
672
|
+
/** @copydoc nk_euclidean_f64 */
|
|
673
|
+
NK_PUBLIC void nk_euclidean_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
674
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
675
|
+
NK_PUBLIC void nk_sqeuclidean_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
676
|
+
/** @copydoc nk_angular_f64 */
|
|
677
|
+
NK_PUBLIC void nk_angular_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
678
|
+
/** @copydoc nk_euclidean_f64 */
|
|
679
|
+
NK_PUBLIC void nk_euclidean_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
680
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
681
|
+
NK_PUBLIC void nk_sqeuclidean_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
682
|
+
/** @copydoc nk_angular_f64 */
|
|
683
|
+
NK_PUBLIC void nk_angular_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
684
|
+
/** @copydoc nk_euclidean_f64 */
|
|
685
|
+
NK_PUBLIC void nk_euclidean_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
686
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
687
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
688
|
+
/** @copydoc nk_angular_f64 */
|
|
689
|
+
NK_PUBLIC void nk_angular_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
690
|
+
/** @copydoc nk_euclidean_f64 */
|
|
691
|
+
NK_PUBLIC void nk_euclidean_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
692
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
693
|
+
NK_PUBLIC void nk_sqeuclidean_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
694
|
+
/** @copydoc nk_angular_f64 */
|
|
695
|
+
NK_PUBLIC void nk_angular_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
696
|
+
/** @copydoc nk_euclidean_f64 */
|
|
697
|
+
NK_PUBLIC void nk_euclidean_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
698
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
699
|
+
NK_PUBLIC void nk_sqeuclidean_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
700
|
+
/** @copydoc nk_angular_f64 */
|
|
701
|
+
NK_PUBLIC void nk_angular_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
702
|
+
/** @copydoc nk_euclidean_f64 */
|
|
703
|
+
NK_PUBLIC void nk_euclidean_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
704
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
705
|
+
NK_PUBLIC void nk_sqeuclidean_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
706
|
+
/** @copydoc nk_angular_f64 */
|
|
707
|
+
NK_PUBLIC void nk_angular_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
708
|
+
/** @copydoc nk_euclidean_f64 */
|
|
709
|
+
NK_PUBLIC void nk_euclidean_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
710
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
711
|
+
NK_PUBLIC void nk_sqeuclidean_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
712
|
+
/** @copydoc nk_angular_f64 */
|
|
713
|
+
NK_PUBLIC void nk_angular_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result);
|
|
714
|
+
/** @copydoc nk_euclidean_f64 */
|
|
715
|
+
NK_PUBLIC void nk_euclidean_i4_rvv(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
716
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
717
|
+
NK_PUBLIC void nk_sqeuclidean_i4_rvv(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
718
|
+
/** @copydoc nk_angular_f64 */
|
|
719
|
+
NK_PUBLIC void nk_angular_i4_rvv(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
720
|
+
/** @copydoc nk_euclidean_f64 */
|
|
721
|
+
NK_PUBLIC void nk_euclidean_u4_rvv(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
722
|
+
/** @copydoc nk_sqeuclidean_f64 */
|
|
723
|
+
NK_PUBLIC void nk_sqeuclidean_u4_rvv(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
724
|
+
/** @copydoc nk_angular_f64 */
|
|
725
|
+
NK_PUBLIC void nk_angular_u4_rvv(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
726
|
+
#endif // NK_TARGET_RVV
|
|
727
|
+
|
|
728
|
+
#if NK_TARGET_RVVHALF
/* `nk_f16_t` kernels of the RVVHALF backend; declared only when `NK_TARGET_RVVHALF` is non-zero. */
/** @copydoc nk_euclidean_f64 */
NK_PUBLIC void nk_euclidean_f16_rvvhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_sqeuclidean_f64 */
NK_PUBLIC void nk_sqeuclidean_f16_rvvhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_angular_f64 */
NK_PUBLIC void nk_angular_f16_rvvhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
#endif // NK_TARGET_RVVHALF
|
|
736
|
+
|
|
737
|
+
#if NK_TARGET_RVVBF16
/* `nk_bf16_t` kernels of the RVVBF16 backend; declared only when `NK_TARGET_RVVBF16` is non-zero. */
/** @copydoc nk_euclidean_f64 */
NK_PUBLIC void nk_euclidean_bf16_rvvbf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_sqeuclidean_f64 */
NK_PUBLIC void nk_sqeuclidean_bf16_rvvbf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
/** @copydoc nk_angular_f64 */
NK_PUBLIC void nk_angular_bf16_rvvbf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
#endif // NK_TARGET_RVVBF16
|
|
745
|
+
|
|
746
|
+
/** @brief Returns the output dtype for L2 (Euclidean) distance. */
|
|
747
|
+
NK_INTERNAL nk_dtype_t nk_euclidean_output_dtype(nk_dtype_t dtype) {
|
|
748
|
+
switch (dtype) {
|
|
749
|
+
case nk_f64_k: return nk_f64_k;
|
|
750
|
+
case nk_f32_k: return nk_f64_k;
|
|
751
|
+
case nk_f16_k: return nk_f32_k;
|
|
752
|
+
case nk_bf16_k: return nk_f32_k;
|
|
753
|
+
case nk_e4m3_k: return nk_f32_k;
|
|
754
|
+
case nk_e5m2_k: return nk_f32_k;
|
|
755
|
+
case nk_e2m3_k: return nk_f32_k;
|
|
756
|
+
case nk_e3m2_k: return nk_f32_k;
|
|
757
|
+
case nk_i8_k: return nk_f32_k;
|
|
758
|
+
case nk_u8_k: return nk_f32_k;
|
|
759
|
+
case nk_i4_k: return nk_f32_k;
|
|
760
|
+
case nk_u4_k: return nk_f32_k;
|
|
761
|
+
default: return nk_dtype_unknown_k;
|
|
762
|
+
}
|
|
763
|
+
}
|
|
764
|
+
|
|
765
|
+
/** @brief Returns the output dtype for L2 squared distance. */
|
|
766
|
+
NK_INTERNAL nk_dtype_t nk_sqeuclidean_output_dtype(nk_dtype_t dtype) {
|
|
767
|
+
switch (dtype) {
|
|
768
|
+
case nk_f64_k: return nk_f64_k;
|
|
769
|
+
case nk_f32_k: return nk_f64_k;
|
|
770
|
+
case nk_f16_k: return nk_f32_k;
|
|
771
|
+
case nk_bf16_k: return nk_f32_k;
|
|
772
|
+
case nk_e4m3_k: return nk_f32_k;
|
|
773
|
+
case nk_e5m2_k: return nk_f32_k;
|
|
774
|
+
case nk_e2m3_k: return nk_f32_k;
|
|
775
|
+
case nk_e3m2_k: return nk_f32_k;
|
|
776
|
+
case nk_i8_k: return nk_u32_k;
|
|
777
|
+
case nk_u8_k: return nk_u32_k;
|
|
778
|
+
case nk_i4_k: return nk_u32_k;
|
|
779
|
+
case nk_u4_k: return nk_u32_k;
|
|
780
|
+
default: return nk_dtype_unknown_k;
|
|
781
|
+
}
|
|
782
|
+
}
|
|
783
|
+
|
|
784
|
+
/** @brief Returns the output dtype for angular/cosine distance. */
|
|
785
|
+
NK_INTERNAL nk_dtype_t nk_angular_output_dtype(nk_dtype_t dtype) {
|
|
786
|
+
switch (dtype) {
|
|
787
|
+
case nk_f64_k: return nk_f64_k;
|
|
788
|
+
case nk_f32_k: return nk_f64_k;
|
|
789
|
+
case nk_f16_k: return nk_f32_k;
|
|
790
|
+
case nk_bf16_k: return nk_f32_k;
|
|
791
|
+
case nk_e4m3_k: return nk_f32_k;
|
|
792
|
+
case nk_e5m2_k: return nk_f32_k;
|
|
793
|
+
case nk_e2m3_k: return nk_f32_k;
|
|
794
|
+
case nk_e3m2_k: return nk_f32_k;
|
|
795
|
+
case nk_i8_k: return nk_f32_k;
|
|
796
|
+
case nk_u8_k: return nk_f32_k;
|
|
797
|
+
case nk_i4_k: return nk_f32_k;
|
|
798
|
+
case nk_u4_k: return nk_f32_k;
|
|
799
|
+
default: return nk_dtype_unknown_k;
|
|
800
|
+
}
|
|
801
|
+
}
|
|
802
|
+
|
|
803
|
+
#if defined(__cplusplus)
|
|
804
|
+
} // extern "C"
|
|
805
|
+
#endif
|
|
806
|
+
|
|
807
|
+
#include "numkong/spatial/serial.h"
|
|
808
|
+
#include "numkong/spatial/neon.h"
|
|
809
|
+
#include "numkong/spatial/neonhalf.h"
|
|
810
|
+
#include "numkong/spatial/neonbfdot.h"
|
|
811
|
+
#include "numkong/spatial/neonsdot.h"
|
|
812
|
+
#include "numkong/spatial/sve.h"
|
|
813
|
+
#include "numkong/spatial/svehalf.h"
|
|
814
|
+
#include "numkong/spatial/svebfdot.h"
|
|
815
|
+
#include "numkong/spatial/haswell.h"
|
|
816
|
+
#include "numkong/spatial/skylake.h"
|
|
817
|
+
#include "numkong/spatial/genoa.h"
|
|
818
|
+
#include "numkong/spatial/sapphire.h"
|
|
819
|
+
#include "numkong/spatial/icelake.h"
|
|
820
|
+
#include "numkong/spatial/alder.h"
|
|
821
|
+
#include "numkong/spatial/sierra.h"
|
|
822
|
+
#include "numkong/spatial/rvv.h"
|
|
823
|
+
#include "numkong/spatial/rvvhalf.h"
|
|
824
|
+
#include "numkong/spatial/rvvbf16.h"
|
|
825
|
+
#include "numkong/spatial/v128relaxed.h"
|
|
826
|
+
|
|
827
|
+
#if defined(__cplusplus)
|
|
828
|
+
extern "C" {
|
|
829
|
+
#endif
|
|
830
|
+
|
|
831
|
+
#if !NK_DYNAMIC_DISPATCH
|
|
832
|
+
|
|
833
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_f64_t` inputs.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVV, SVE,
 *  NEON, SKYLAKE, HASWELL; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_euclidean_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
#if NK_TARGET_V128RELAXED
    nk_euclidean_f64_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_euclidean_f64_rvv(a, b, n, result);
#elif NK_TARGET_SVE
    nk_euclidean_f64_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_euclidean_f64_neon(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_euclidean_f64_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_euclidean_f64_haswell(a, b, n, result);
#else
    nk_euclidean_f64_serial(a, b, n, result);
#endif
}
|
|
850
|
+
|
|
851
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_f64_t` inputs.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVV, SVE,
 *  NEON, SKYLAKE, HASWELL; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
#if NK_TARGET_V128RELAXED
    nk_sqeuclidean_f64_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_sqeuclidean_f64_rvv(a, b, n, result);
#elif NK_TARGET_SVE
    nk_sqeuclidean_f64_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_sqeuclidean_f64_neon(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_sqeuclidean_f64_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_sqeuclidean_f64_haswell(a, b, n, result);
#else
    nk_sqeuclidean_f64_serial(a, b, n, result);
#endif
}
|
|
868
|
+
|
|
869
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_f64_t` inputs.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVV, SVE,
 *  NEON, SKYLAKE, HASWELL; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_angular_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
#if NK_TARGET_V128RELAXED
    nk_angular_f64_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_angular_f64_rvv(a, b, n, result);
#elif NK_TARGET_SVE
    nk_angular_f64_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_angular_f64_neon(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_angular_f64_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_angular_f64_haswell(a, b, n, result);
#else
    nk_angular_f64_serial(a, b, n, result);
#endif
}
|
|
886
|
+
|
|
887
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_f32_t` inputs, written as `nk_f64_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVV, SVE,
 *  NEON, SKYLAKE, HASWELL; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_euclidean_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
#if NK_TARGET_V128RELAXED
    nk_euclidean_f32_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_euclidean_f32_rvv(a, b, n, result);
#elif NK_TARGET_SVE
    nk_euclidean_f32_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_euclidean_f32_neon(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_euclidean_f32_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_euclidean_f32_haswell(a, b, n, result);
#else
    nk_euclidean_f32_serial(a, b, n, result);
#endif
}
|
|
904
|
+
|
|
905
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_f32_t` inputs, written as `nk_f64_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVV, SVE,
 *  NEON, SKYLAKE, HASWELL; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
#if NK_TARGET_V128RELAXED
    nk_sqeuclidean_f32_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_sqeuclidean_f32_rvv(a, b, n, result);
#elif NK_TARGET_SVE
    nk_sqeuclidean_f32_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_sqeuclidean_f32_neon(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_sqeuclidean_f32_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_sqeuclidean_f32_haswell(a, b, n, result);
#else
    nk_sqeuclidean_f32_serial(a, b, n, result);
#endif
}
|
|
922
|
+
|
|
923
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_f32_t` inputs, written as `nk_f64_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVV, SVE,
 *  NEON, SKYLAKE, HASWELL; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_angular_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
#if NK_TARGET_V128RELAXED
    nk_angular_f32_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_angular_f32_rvv(a, b, n, result);
#elif NK_TARGET_SVE
    nk_angular_f32_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_angular_f32_neon(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_angular_f32_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_angular_f32_haswell(a, b, n, result);
#else
    nk_angular_f32_serial(a, b, n, result);
#endif
}
|
|
940
|
+
|
|
941
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_f16_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVVHALF,
 *  RVV, SVEHALF, NEONHALF, SKYLAKE, HASWELL; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_euclidean_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_V128RELAXED
    nk_euclidean_f16_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVVHALF
    nk_euclidean_f16_rvvhalf(a, b, n, result);
#elif NK_TARGET_RVV
    nk_euclidean_f16_rvv(a, b, n, result);
#elif NK_TARGET_SVEHALF
    nk_euclidean_f16_svehalf(a, b, n, result);
#elif NK_TARGET_NEONHALF
    nk_euclidean_f16_neonhalf(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_euclidean_f16_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_euclidean_f16_haswell(a, b, n, result);
#else
    nk_euclidean_f16_serial(a, b, n, result);
#endif
}
|
|
960
|
+
|
|
961
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_f16_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVVHALF,
 *  RVV, SVEHALF, NEONHALF, SKYLAKE, HASWELL; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_V128RELAXED
    nk_sqeuclidean_f16_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVVHALF
    nk_sqeuclidean_f16_rvvhalf(a, b, n, result);
#elif NK_TARGET_RVV
    nk_sqeuclidean_f16_rvv(a, b, n, result);
#elif NK_TARGET_SVEHALF
    nk_sqeuclidean_f16_svehalf(a, b, n, result);
#elif NK_TARGET_NEONHALF
    nk_sqeuclidean_f16_neonhalf(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_sqeuclidean_f16_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_sqeuclidean_f16_haswell(a, b, n, result);
#else
    nk_sqeuclidean_f16_serial(a, b, n, result);
#endif
}
|
|
980
|
+
|
|
981
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_f16_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVVHALF,
 *  RVV, SVEHALF, NEONHALF, SKYLAKE, HASWELL; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_angular_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_V128RELAXED
    nk_angular_f16_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVVHALF
    nk_angular_f16_rvvhalf(a, b, n, result);
#elif NK_TARGET_RVV
    nk_angular_f16_rvv(a, b, n, result);
#elif NK_TARGET_SVEHALF
    nk_angular_f16_svehalf(a, b, n, result);
#elif NK_TARGET_NEONHALF
    nk_angular_f16_neonhalf(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_angular_f16_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_angular_f16_haswell(a, b, n, result);
#else
    nk_angular_f16_serial(a, b, n, result);
#endif
}
|
|
1000
|
+
|
|
1001
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_bf16_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVVBF16,
 *  RVV, SVEBFDOT, NEONBFDOT, GENOA, HASWELL; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_euclidean_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_V128RELAXED
    nk_euclidean_bf16_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVVBF16
    nk_euclidean_bf16_rvvbf16(a, b, n, result);
#elif NK_TARGET_RVV
    nk_euclidean_bf16_rvv(a, b, n, result);
#elif NK_TARGET_SVEBFDOT
    nk_euclidean_bf16_svebfdot(a, b, n, result);
#elif NK_TARGET_NEONBFDOT
    nk_euclidean_bf16_neonbfdot(a, b, n, result);
#elif NK_TARGET_GENOA
    nk_euclidean_bf16_genoa(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_euclidean_bf16_haswell(a, b, n, result);
#else
    nk_euclidean_bf16_serial(a, b, n, result);
#endif
}
|
|
1020
|
+
|
|
1021
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_bf16_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVVBF16,
 *  RVV, SVEBFDOT, NEONBFDOT, GENOA, HASWELL; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_V128RELAXED
    nk_sqeuclidean_bf16_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVVBF16
    nk_sqeuclidean_bf16_rvvbf16(a, b, n, result);
#elif NK_TARGET_RVV
    nk_sqeuclidean_bf16_rvv(a, b, n, result);
#elif NK_TARGET_SVEBFDOT
    nk_sqeuclidean_bf16_svebfdot(a, b, n, result);
#elif NK_TARGET_NEONBFDOT
    nk_sqeuclidean_bf16_neonbfdot(a, b, n, result);
#elif NK_TARGET_GENOA
    nk_sqeuclidean_bf16_genoa(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_sqeuclidean_bf16_haswell(a, b, n, result);
#else
    nk_sqeuclidean_bf16_serial(a, b, n, result);
#endif
}
|
|
1040
|
+
|
|
1041
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_bf16_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order V128RELAXED, RVVBF16,
 *  RVV, SVEBFDOT, NEONBFDOT, GENOA, HASWELL; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_angular_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_V128RELAXED
    nk_angular_bf16_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVVBF16
    nk_angular_bf16_rvvbf16(a, b, n, result);
#elif NK_TARGET_RVV
    nk_angular_bf16_rvv(a, b, n, result);
#elif NK_TARGET_SVEBFDOT
    nk_angular_bf16_svebfdot(a, b, n, result);
#elif NK_TARGET_NEONBFDOT
    nk_angular_bf16_neonbfdot(a, b, n, result);
#elif NK_TARGET_GENOA
    nk_angular_bf16_genoa(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_angular_bf16_haswell(a, b, n, result);
#else
    nk_angular_bf16_serial(a, b, n, result);
#endif
}
|
|
1060
|
+
|
|
1061
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_e4m3_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order SAPPHIRE, GENOA,
 *  SKYLAKE, RVV; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_euclidean_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_euclidean_e4m3_sapphire(a, b, n, result);
#elif NK_TARGET_GENOA
    nk_euclidean_e4m3_genoa(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_euclidean_e4m3_skylake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_euclidean_e4m3_rvv(a, b, n, result);
#else
    nk_euclidean_e4m3_serial(a, b, n, result);
#endif
}
|
|
1074
|
+
|
|
1075
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_e4m3_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order SAPPHIRE, GENOA,
 *  SKYLAKE, RVV; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_sqeuclidean_e4m3_sapphire(a, b, n, result);
#elif NK_TARGET_GENOA
    nk_sqeuclidean_e4m3_genoa(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_sqeuclidean_e4m3_skylake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_sqeuclidean_e4m3_rvv(a, b, n, result);
#else
    nk_sqeuclidean_e4m3_serial(a, b, n, result);
#endif
}
|
|
1088
|
+
|
|
1089
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_e4m3_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order GENOA, SKYLAKE, RVV;
 *  the serial kernel is used when none of them is enabled.
 *
 *  NOTE(review): unlike `nk_euclidean_e4m3` and `nk_sqeuclidean_e4m3`, there is no SAPPHIRE branch
 *  here - confirm that this is intentional.
 */
NK_PUBLIC void nk_angular_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_GENOA
    nk_angular_e4m3_genoa(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_angular_e4m3_skylake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_angular_e4m3_rvv(a, b, n, result);
#else
    nk_angular_e4m3_serial(a, b, n, result);
#endif
}
|
|
1100
|
+
|
|
1101
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_e5m2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order GENOA, SKYLAKE, RVV;
 *  the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_euclidean_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_GENOA
    nk_euclidean_e5m2_genoa(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_euclidean_e5m2_skylake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_euclidean_e5m2_rvv(a, b, n, result);
#else
    nk_euclidean_e5m2_serial(a, b, n, result);
#endif
}
|
|
1112
|
+
|
|
1113
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_e5m2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order GENOA, SKYLAKE, RVV;
 *  the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_GENOA
    nk_sqeuclidean_e5m2_genoa(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_sqeuclidean_e5m2_skylake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_sqeuclidean_e5m2_rvv(a, b, n, result);
#else
    nk_sqeuclidean_e5m2_serial(a, b, n, result);
#endif
}
|
|
1124
|
+
|
|
1125
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_e5m2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order GENOA, SKYLAKE, RVV;
 *  the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_angular_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_GENOA
    nk_angular_e5m2_genoa(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_angular_e5m2_skylake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_angular_e5m2_rvv(a, b, n, result);
#else
    nk_angular_e5m2_serial(a, b, n, result);
#endif
}
|
|
1136
|
+
|
|
1137
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_e2m3_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order SAPPHIRE, SKYLAKE,
 *  SIERRA, ALDER, HASWELL, NEON; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_euclidean_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_euclidean_e2m3_sapphire(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_euclidean_e2m3_skylake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_euclidean_e2m3_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_euclidean_e2m3_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_euclidean_e2m3_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_euclidean_e2m3_neon(a, b, n, result);
#else
    nk_euclidean_e2m3_serial(a, b, n, result);
#endif
}
|
|
1154
|
+
|
|
1155
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_e2m3_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order SAPPHIRE, SKYLAKE,
 *  SIERRA, ALDER, HASWELL, NEON; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_sqeuclidean_e2m3_sapphire(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_sqeuclidean_e2m3_skylake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_sqeuclidean_e2m3_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_sqeuclidean_e2m3_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_sqeuclidean_e2m3_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_sqeuclidean_e2m3_neon(a, b, n, result);
#else
    nk_sqeuclidean_e2m3_serial(a, b, n, result);
#endif
}
|
|
1172
|
+
|
|
1173
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_e2m3_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order SAPPHIRE, SKYLAKE,
 *  SIERRA, ALDER, HASWELL, NEON; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_angular_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_angular_e2m3_sapphire(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_angular_e2m3_skylake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_angular_e2m3_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_angular_e2m3_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_angular_e2m3_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_angular_e2m3_neon(a, b, n, result);
#else
    nk_angular_e2m3_serial(a, b, n, result);
#endif
}
|
|
1190
|
+
|
|
1191
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_e3m2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order SAPPHIRE, SKYLAKE,
 *  ALDER, HASWELL, NEON; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_euclidean_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_euclidean_e3m2_sapphire(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_euclidean_e3m2_skylake(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_euclidean_e3m2_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_euclidean_e3m2_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_euclidean_e3m2_neon(a, b, n, result);
#else
    nk_euclidean_e3m2_serial(a, b, n, result);
#endif
}
|
|
1206
|
+
|
|
1207
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_e3m2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order SAPPHIRE, SKYLAKE,
 *  ALDER, HASWELL, NEON; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_sqeuclidean_e3m2_sapphire(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_sqeuclidean_e3m2_skylake(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_sqeuclidean_e3m2_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_sqeuclidean_e3m2_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_sqeuclidean_e3m2_neon(a, b, n, result);
#else
    nk_sqeuclidean_e3m2_serial(a, b, n, result);
#endif
}
|
|
1222
|
+
|
|
1223
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_e3m2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order SAPPHIRE, SKYLAKE,
 *  ALDER, HASWELL, NEON; the serial kernel is used when none of them is enabled.
 */
NK_PUBLIC void nk_angular_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_angular_e3m2_sapphire(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_angular_e3m2_skylake(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_angular_e3m2_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_angular_e3m2_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_angular_e3m2_neon(a, b, n, result);
#else
    nk_angular_e3m2_serial(a, b, n, result);
#endif
}
|
|
1238
|
+
|
|
1239
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_i8_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order RVV, NEONSDOT, ICELAKE,
 *  SIERRA, ALDER, HASWELL, V128RELAXED; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_euclidean_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_RVV
    nk_euclidean_i8_rvv(a, b, n, result);
#elif NK_TARGET_NEONSDOT
    nk_euclidean_i8_neonsdot(a, b, n, result);
#elif NK_TARGET_ICELAKE
    nk_euclidean_i8_icelake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_euclidean_i8_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_euclidean_i8_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_euclidean_i8_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_euclidean_i8_v128relaxed(a, b, n, result);
#else
    nk_euclidean_i8_serial(a, b, n, result);
#endif
}
|
|
1258
|
+
|
|
1259
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_i8_t` inputs, written as `nk_u32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order RVV, NEONSDOT, ICELAKE,
 *  SIERRA, ALDER, HASWELL, V128RELAXED; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
#if NK_TARGET_RVV
    nk_sqeuclidean_i8_rvv(a, b, n, result);
#elif NK_TARGET_NEONSDOT
    nk_sqeuclidean_i8_neonsdot(a, b, n, result);
#elif NK_TARGET_ICELAKE
    nk_sqeuclidean_i8_icelake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_sqeuclidean_i8_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_sqeuclidean_i8_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_sqeuclidean_i8_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_sqeuclidean_i8_v128relaxed(a, b, n, result);
#else
    nk_sqeuclidean_i8_serial(a, b, n, result);
#endif
}
|
|
1278
|
+
|
|
1279
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_i8_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order RVV, NEONSDOT, ICELAKE,
 *  SIERRA, ALDER, HASWELL, V128RELAXED; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_angular_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_RVV
    nk_angular_i8_rvv(a, b, n, result);
#elif NK_TARGET_NEONSDOT
    nk_angular_i8_neonsdot(a, b, n, result);
#elif NK_TARGET_ICELAKE
    nk_angular_i8_icelake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_angular_i8_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_angular_i8_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_angular_i8_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_angular_i8_v128relaxed(a, b, n, result);
#else
    nk_angular_i8_serial(a, b, n, result);
#endif
}
|
|
1298
|
+
|
|
1299
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_u8_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order RVV, NEONSDOT, ICELAKE,
 *  SIERRA, ALDER, HASWELL, V128RELAXED; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_euclidean_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_RVV
    nk_euclidean_u8_rvv(a, b, n, result);
#elif NK_TARGET_NEONSDOT
    nk_euclidean_u8_neonsdot(a, b, n, result);
#elif NK_TARGET_ICELAKE
    nk_euclidean_u8_icelake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_euclidean_u8_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_euclidean_u8_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_euclidean_u8_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_euclidean_u8_v128relaxed(a, b, n, result);
#else
    nk_euclidean_u8_serial(a, b, n, result);
#endif
}
|
|
1318
|
+
|
|
1319
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_u8_t` inputs, written as `nk_u32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order RVV, NEONSDOT, ICELAKE,
 *  SIERRA, ALDER, HASWELL, V128RELAXED; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_sqeuclidean_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
#if NK_TARGET_RVV
    nk_sqeuclidean_u8_rvv(a, b, n, result);
#elif NK_TARGET_NEONSDOT
    nk_sqeuclidean_u8_neonsdot(a, b, n, result);
#elif NK_TARGET_ICELAKE
    nk_sqeuclidean_u8_icelake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_sqeuclidean_u8_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_sqeuclidean_u8_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_sqeuclidean_u8_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_sqeuclidean_u8_v128relaxed(a, b, n, result);
#else
    nk_sqeuclidean_u8_serial(a, b, n, result);
#endif
}
|
|
1338
|
+
|
|
1339
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_u8_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  first backend whose `NK_TARGET_*` macro is non-zero, tested in the order RVV, NEONSDOT, ICELAKE,
 *  SIERRA, ALDER, HASWELL, V128RELAXED; the serial kernel is used when none is enabled.
 */
NK_PUBLIC void nk_angular_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_RVV
    nk_angular_u8_rvv(a, b, n, result);
#elif NK_TARGET_NEONSDOT
    nk_angular_u8_neonsdot(a, b, n, result);
#elif NK_TARGET_ICELAKE
    nk_angular_u8_icelake(a, b, n, result);
#elif NK_TARGET_SIERRA
    nk_angular_u8_sierra(a, b, n, result);
#elif NK_TARGET_ALDER
    nk_angular_u8_alder(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_angular_u8_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_angular_u8_v128relaxed(a, b, n, result);
#else
    nk_angular_u8_serial(a, b, n, result);
#endif
}
|
|
1358
|
+
|
|
1359
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_i4x2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  ICELAKE kernel when `NK_TARGET_ICELAKE` is non-zero, otherwise to the RVV kernel when
 *  `NK_TARGET_RVV` is non-zero, otherwise to the serial kernel.
 */
NK_PUBLIC void nk_euclidean_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_ICELAKE
    nk_euclidean_i4_icelake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_euclidean_i4_rvv(a, b, n, result);
#else
    nk_euclidean_i4_serial(a, b, n, result);
#endif
}
|
|
1368
|
+
|
|
1369
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_i4x2_t` inputs, written as `nk_u32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  ICELAKE kernel when `NK_TARGET_ICELAKE` is non-zero, otherwise to the RVV kernel when
 *  `NK_TARGET_RVV` is non-zero, otherwise to the serial kernel.
 */
NK_PUBLIC void nk_sqeuclidean_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result) {
#if NK_TARGET_ICELAKE
    nk_sqeuclidean_i4_icelake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_sqeuclidean_i4_rvv(a, b, n, result);
#else
    nk_sqeuclidean_i4_serial(a, b, n, result);
#endif
}
|
|
1378
|
+
|
|
1379
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_i4x2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  ICELAKE kernel when `NK_TARGET_ICELAKE` is non-zero, otherwise to the RVV kernel when
 *  `NK_TARGET_RVV` is non-zero, otherwise to the serial kernel.
 */
NK_PUBLIC void nk_angular_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_ICELAKE
    nk_angular_i4_icelake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_angular_i4_rvv(a, b, n, result);
#else
    nk_angular_i4_serial(a, b, n, result);
#endif
}
|
|
1388
|
+
|
|
1389
|
+
/**
 *  @brief Statically dispatched L2 (Euclidean) distance for `nk_u4x2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  ICELAKE kernel when `NK_TARGET_ICELAKE` is non-zero, otherwise to the RVV kernel when
 *  `NK_TARGET_RVV` is non-zero, otherwise to the serial kernel.
 */
NK_PUBLIC void nk_euclidean_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_ICELAKE
    nk_euclidean_u4_icelake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_euclidean_u4_rvv(a, b, n, result);
#else
    nk_euclidean_u4_serial(a, b, n, result);
#endif
}
|
|
1398
|
+
|
|
1399
|
+
/**
 *  @brief Statically dispatched L2 squared distance for `nk_u4x2_t` inputs, written as `nk_u32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  ICELAKE kernel when `NK_TARGET_ICELAKE` is non-zero, otherwise to the RVV kernel when
 *  `NK_TARGET_RVV` is non-zero, otherwise to the serial kernel.
 */
NK_PUBLIC void nk_sqeuclidean_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
#if NK_TARGET_ICELAKE
    nk_sqeuclidean_u4_icelake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_sqeuclidean_u4_rvv(a, b, n, result);
#else
    nk_sqeuclidean_u4_serial(a, b, n, result);
#endif
}
|
|
1408
|
+
|
|
1409
|
+
/**
 *  @brief Statically dispatched angular/cosine distance for `nk_u4x2_t` inputs, written as `nk_f32_t`.
 *
 *  Compiled only when `NK_DYNAMIC_DISPATCH` is 0. Forwards all four arguments unchanged to the
 *  ICELAKE kernel when `NK_TARGET_ICELAKE` is non-zero, otherwise to the RVV kernel when
 *  `NK_TARGET_RVV` is non-zero, otherwise to the serial kernel.
 */
NK_PUBLIC void nk_angular_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_ICELAKE
    nk_angular_u4_icelake(a, b, n, result);
#elif NK_TARGET_RVV
    nk_angular_u4_rvv(a, b, n, result);
#else
    nk_angular_u4_serial(a, b, n, result);
#endif
}
|
|
1418
|
+
|
|
1419
|
+
#endif // !NK_DYNAMIC_DISPATCH
|
|
1420
|
+
|
|
1421
|
+
#if defined(__cplusplus)
|
|
1422
|
+
} // extern "C"
|
|
1423
|
+
#endif
|
|
1424
|
+
|
|
1425
|
+
#endif
|