numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1070 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Dot Products for Real and Complex Numbers.
|
|
3
|
+
* @file include/numkong/dot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 24, 2024
|
|
6
|
+
*
|
|
7
|
+
* Contains:
|
|
8
|
+
*
|
|
9
|
+
* - Dot Product for Real and Complex vectors
|
|
10
|
+
* - Conjugate Dot Product for Complex vectors
|
|
11
|
+
*
|
|
12
|
+
* For dtypes:
|
|
13
|
+
*
|
|
14
|
+
* - f64: 64-bit IEEE floating point numbers → 64-bit floats
|
|
15
|
+
* - f32: 32-bit IEEE floating point numbers → 64-bit floats
|
|
16
|
+
* - f16: 16-bit IEEE floating point numbers → 32-bit floats
|
|
17
|
+
* - bf16: 16-bit brain floating point numbers → 32-bit floats
|
|
18
|
+
* - e4m3: 8-bit e4m3 floating point numbers → 32-bit floats
|
|
19
|
+
* - e5m2: 8-bit e5m2 floating point numbers → 32-bit floats
|
|
20
|
+
* - e2m3: 8-bit e2m3 floating point numbers (MX) → 32-bit floats
|
|
21
|
+
* - e3m2: 8-bit e3m2 floating point numbers (MX) → 32-bit floats
|
|
22
|
+
* - i8: 8-bit signed integers → 32-bit signed integers
|
|
23
|
+
* - u8: 8-bit unsigned integers → 32-bit unsigned integers
|
|
24
|
+
* - i4: 4-bit signed integers (packed nibble pairs) → 32-bit signed integers
|
|
25
|
+
* - u4: 4-bit unsigned integers (packed nibble pairs) → 32-bit unsigned integers
|
|
26
|
+
* - u1: 1-bit binary (packed octets) → 32-bit unsigned integers
|
|
27
|
+
*
|
|
28
|
+
* Complex dot product variants:
|
|
29
|
+
*
|
|
30
|
+
* - f64c: 64-bit complex pairs → 64-bit complex
|
|
31
|
+
* - f32c: 32-bit complex pairs → 64-bit complex
|
|
32
|
+
* - f16c: 16-bit complex pairs → 32-bit complex
|
|
33
|
+
* - bf16c: 16-bit brain complex pairs → 32-bit complex
|
|
34
|
+
*
|
|
35
|
+
* For hardware architectures:
|
|
36
|
+
*
|
|
37
|
+
* - Arm: NEON, NEON+I8, NEON+F16, NEON+FHM, NEON+BF16, SVE, SVE+F16, SVE+BF16
|
|
38
|
+
* - x86: Haswell, Skylake, Ice Lake, Alder Lake, Genoa, Sapphire Rapids, Sierra Forest
|
|
39
|
+
* - RISC-V: RVV, RVV+BF16, RVV+HALF, RVV+BB
|
|
40
|
+
* - WASM: V128Relaxed
|
|
41
|
+
*
|
|
42
|
+
* @section numerical_stability Numerical Stability
|
|
43
|
+
*
|
|
44
|
+
* - f64: Dot2/Ogita-Rump-Oishi style compensated summation across serial and SIMD stateful paths.
|
|
45
|
+
* - f32: public outputs widen to f64/f64c. Arithmetic widens before the first lossy reduction step.
|
|
46
|
+
* - f16/bf16: Promoted to f32 accumulator.
|
|
47
|
+
* - e4m3/e5m2: Promoted to f32. On Sapphire, e2m3/e3m2 use f16 intermediate with periodic
|
|
48
|
+
* flush to f32 every 128 elements to avoid f16 overflow (max lane sum ~225 / ~3136).
|
|
49
|
+
* - i8: i32 accumulator. Max product (-128)² = 16,384. Can overflow once n ≥ 2^31/16,384 = 131,072 (≈ 131K).
|
|
50
|
+
* - u8: u32 accumulator. Max product 255² = 65,025. Overflows at n > 2^32/65,025 ≈ 66K.
|
|
51
|
+
* - i4: i32 accumulator. Max product 8² = 64. Safe for n ≤ ~33M.
|
|
52
|
+
* - u4: u32 accumulator. Max product 15² = 225. Safe for n ≤ ~19M.
|
|
53
|
+
* - u1: Popcount of AND into u32. Safe for n_bits < 2^32 (u32 holds at most 2^32 − 1).
|
|
54
|
+
* - Complex: Components accumulated independently; same guarantees as real counterpart.
|
|
55
|
+
*
|
|
56
|
+
* @section streaming_api Streaming API
|
|
57
|
+
*
|
|
58
|
+
* For compile-time dispatch and vector-at-a-time accumulation, we provide streaming helpers
|
|
59
|
+
* that accept two `nk_b512_vec_t` blocks and update a running sum for non-complex dot
|
|
60
|
+
* products. The `<count>` suffix reflects how many scalars of that type fit in a 512-bit block.
|
|
61
|
+
* The helpers are exposed per scalar type as:
|
|
62
|
+
*
|
|
63
|
+
* - nk_dot_<type>x<count>_state_<isa>_t
|
|
64
|
+
* - nk_dot_<type>x<count>_init_<isa>
|
|
65
|
+
* - nk_dot_<type>x<count>_update_<isa>
|
|
66
|
+
* - nk_dot_<type>x<count>_finalize_<isa>
|
|
67
|
+
*
|
|
68
|
+
* @section x86_instructions Relevant x86 Instructions
|
|
69
|
+
*
|
|
70
|
+
* Floating-point dot products use FMA (VFMADD231PS/PD) for sum += a[i]*b[i] accumulation.
|
|
71
|
+
* Integer i8 dot products use VPMADDUBSW (u8 × i8 → i16) + VPMADDWD (i16 × 1 → i32) on Haswell,
|
|
72
|
+
* or the newer VNNI instructions on Ice Lake+: VPDPBUSD for direct u8 × i8 → i32 and VPDPWSSD for i16 × i16 → i32.
|
|
73
|
+
* BF16 dot products (VDPBF16PS) are Genoa-only, accumulating bf16 pairs directly to f32.
|
|
74
|
+
* Genoa shows 40% lower integer multiply-add latency (3c vs 5c) than Ice Lake.
|
|
75
|
+
*
|
|
76
|
+
* Intrinsic Instruction Haswell Ice Genoa
|
|
77
|
+
* _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 5c @ p01 4c @ p01 4c @ p01
|
|
78
|
+
* _mm256_fmadd_pd VFMADD231PD (YMM, YMM, YMM) 5c @ p01 4c @ p01 4c @ p01
|
|
79
|
+
* _mm256_maddubs_epi16 VPMADDUBSW (YMM, YMM, YMM) 5c @ p0 5c @ p01 3c @ p01
|
|
80
|
+
* _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 5c @ p0 5c @ p01 3c @ p01
|
|
81
|
+
* _mm256_dpbusd_epi32 VPDPBUSD (YMM, YMM, YMM) N/A 5c @ p01 4c @ p01
|
|
82
|
+
* _mm512_dpwssd_epi32 VPDPWSSD (ZMM, ZMM, ZMM) N/A 5c @ p0 4c @ p01
|
|
83
|
+
* _mm512_dpbf16_ps VDPBF16PS (ZMM, ZMM, ZMM) N/A N/A 6c @ p01
|
|
84
|
+
*
|
|
85
|
+
* @section arm_neon_instructions Relevant ARM NEON Instructions
|
|
86
|
+
*
|
|
87
|
+
* NEON integer dot products use SDOT/UDOT (ARMv8.2 dotprod) for direct i8 × i8 → i32 or u8 × u8 → u32
|
|
88
|
+
* accumulation - 4x faster than the multiply-add sequence on older cores. BFDOT (ARMv8.6 bf16)
|
|
89
|
+
* provides native bf16 dot products on Graviton 3+. Complex dot products use LD2 for deinterleaved
|
|
90
|
+
* loads of real/imag pairs, though its L01+V throughput can bottleneck on memory-bound workloads.
|
|
91
|
+
*
|
|
92
|
+
* Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
|
|
93
|
+
* vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
|
|
94
|
+
* vfmaq_f64 FMLA.D (vec) 4c @ V0123 4c @ V0123 4c @ V0123
|
|
95
|
+
* vdotq_s32 SDOT (vec) 3c @ V0123 3c @ V0123 3c @ V0123
|
|
96
|
+
* vdotq_u32 UDOT (vec) 3c @ V0123 3c @ V0123 3c @ V0123
|
|
97
|
+
* vbfdotq_f32 BFDOT (vec) N/A 4c @ V0123 5c @ V0123
|
|
98
|
+
* vld2q_f32 LD2 (Q-form) 5c @ L01+V 8c @ L01+V 8c @ L01+V
|
|
99
|
+
*
|
|
100
|
+
* @section arm_sve_instructions Relevant ARM SVE Instructions
|
|
101
|
+
*
|
|
102
|
+
* SVE implementations use predicated FMA (svmla_f32_x) with WHILELT for tail masking, avoiding
|
|
103
|
+
* scalar cleanup loops. FADDV performs horizontal reduction; notably 45% faster on Graviton 4
|
|
104
|
+
* (6c) than Graviton 3 (11c). SVE complex dot products use svld2 for structure loads.
|
|
105
|
+
*
|
|
106
|
+
* Intrinsic Instruction Graviton 3 Graviton 4
|
|
107
|
+
* svmla_f32_x FMLA (pred) 4c @ V0123 4c @ V0123
|
|
108
|
+
* svmls_f32_x FMLS (pred) 4c @ V0123 4c @ V0123
|
|
109
|
+
* svwhilelt_b32 WHILELT 3c @ M0 3c @ M0
|
|
110
|
+
* svld2_f32 LD2 (SVE) 8c @ L01+V 8c @ L01+V
|
|
111
|
+
* svaddv_f32 FADDV 11c @ V0123 6c @ V0123
|
|
112
|
+
*
|
|
113
|
+
* @section complex_instructions Complex Number Optimizations
|
|
114
|
+
*
|
|
115
|
+
* Standard complex multiplication involves subtraction for the real part.
|
|
116
|
+
* Instead of using subtracting variants of FMA for every element, we accumulate real
|
|
117
|
+
* and imaginary products positively and apply a single bitwise XOR to flip the sign
|
|
118
|
+
* bits before the final horizontal reduction. This delayed application of the sign
|
|
119
|
+
* flip doubles the throughput on older x86 architectures like Haswell by maximizing
|
|
120
|
+
* FMA unit utilization and reducing execution dependency chains.
|
|
121
|
+
*
|
|
122
|
+
* @section references References
|
|
123
|
+
*
|
|
124
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
125
|
+
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
126
|
+
*
|
|
127
|
+
*/
|
|
128
|
+
#ifndef NK_DOT_H
|
|
129
|
+
#define NK_DOT_H
|
|
130
|
+
|
|
131
|
+
#include "numkong/types.h"
|
|
132
|
+
|
|
133
|
+
#if defined(__cplusplus)
|
|
134
|
+
extern "C" {
|
|
135
|
+
#endif
|
|
136
|
+
|
|
137
|
+
/**
|
|
138
|
+
* @brief Dot product computing the sum of elementwise products between two vectors.
|
|
139
|
+
*
|
|
140
|
+
* @param[in] a The first vector.
|
|
141
|
+
* @param[in] b The second vector.
|
|
142
|
+
* @param[in] n The number of elements in the vectors.
|
|
143
|
+
* @param[out] result The output dot product value.
|
|
144
|
+
*
|
|
145
|
+
* @note The output value can be negative.
|
|
146
|
+
* @note The output value is zero if and only if the two vectors are orthogonal.
|
|
147
|
+
* @note Defined for floating-point, integer, and binary data types.
|
|
148
|
+
*/
|
|
149
|
+
NK_DYNAMIC void nk_dot_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
150
|
+
/** @copydoc nk_dot_f32 */
|
|
151
|
+
NK_DYNAMIC void nk_dot_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
152
|
+
/** @copydoc nk_dot_f32 */
|
|
153
|
+
NK_DYNAMIC void nk_dot_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
154
|
+
/** @copydoc nk_dot_f32 */
|
|
155
|
+
NK_DYNAMIC void nk_dot_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
156
|
+
/** @copydoc nk_dot_f32 */
|
|
157
|
+
NK_DYNAMIC void nk_dot_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
158
|
+
/** @copydoc nk_dot_f32 */
|
|
159
|
+
NK_DYNAMIC void nk_dot_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
160
|
+
/** @copydoc nk_dot_f32 */
|
|
161
|
+
NK_DYNAMIC void nk_dot_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
|
|
162
|
+
/** @copydoc nk_dot_f32 */
|
|
163
|
+
NK_DYNAMIC void nk_dot_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
164
|
+
/** @copydoc nk_dot_f32 */
|
|
165
|
+
NK_DYNAMIC void nk_dot_u1(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
|
|
166
|
+
/** @copydoc nk_dot_f32 */
|
|
167
|
+
NK_DYNAMIC void nk_dot_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
168
|
+
/** @copydoc nk_dot_f32 */
|
|
169
|
+
NK_DYNAMIC void nk_dot_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
170
|
+
/** @copydoc nk_dot_f32 */
|
|
171
|
+
NK_DYNAMIC void nk_dot_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
172
|
+
/** @copydoc nk_dot_f32 */
|
|
173
|
+
NK_DYNAMIC void nk_dot_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
174
|
+
|
|
175
|
+
/**
|
|
176
|
+
* @brief Complex dot product computing the sum of elementwise products between two complex vectors.
|
|
177
|
+
*
|
|
178
|
+
* @param[in] a_pairs The first complex vector.
|
|
179
|
+
* @param[in] b_pairs The second complex vector.
|
|
180
|
+
* @param[in] count_pairs The number of complex pairs in the vectors.
|
|
181
|
+
* @param[out] result The output complex value as {real, imag}.
|
|
182
|
+
*/
|
|
183
|
+
NK_DYNAMIC void nk_dot_f32c(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
184
|
+
nk_f64c_t *result);
|
|
185
|
+
/** @copydoc nk_dot_f32c */
|
|
186
|
+
NK_DYNAMIC void nk_dot_f64c(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
|
|
187
|
+
nk_f64c_t *result);
|
|
188
|
+
/** @copydoc nk_dot_f32c */
|
|
189
|
+
NK_DYNAMIC void nk_dot_f16c(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
190
|
+
nk_f32c_t *result);
|
|
191
|
+
/** @copydoc nk_dot_f32c */
|
|
192
|
+
NK_DYNAMIC void nk_dot_bf16c(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
|
|
193
|
+
nk_f32c_t *result);
|
|
194
|
+
|
|
195
|
+
/**
|
|
196
|
+
* @brief Complex conjugate dot product between two complex vectors.
|
|
197
|
+
*
|
|
198
|
+
* @param[in] a_pairs The first complex vector.
|
|
199
|
+
* @param[in] b_pairs The second complex vector.
|
|
200
|
+
* @param[in] count_pairs The number of complex pairs in the vectors.
|
|
201
|
+
* @param[out] result The output complex value as {real, imag}.
|
|
202
|
+
*/
|
|
203
|
+
NK_DYNAMIC void nk_vdot_f32c(nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_size_t count_pairs,
|
|
204
|
+
nk_f64c_t *result);
|
|
205
|
+
/** @copydoc nk_vdot_f32c */
|
|
206
|
+
NK_DYNAMIC void nk_vdot_f64c(nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_size_t count_pairs,
|
|
207
|
+
nk_f64c_t *result);
|
|
208
|
+
/** @copydoc nk_vdot_f32c */
|
|
209
|
+
NK_DYNAMIC void nk_vdot_f16c(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
210
|
+
nk_f32c_t *result);
|
|
211
|
+
/** @copydoc nk_vdot_f32c */
|
|
212
|
+
NK_DYNAMIC void nk_vdot_bf16c(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
|
|
213
|
+
nk_f32c_t *result);
|
|
214
|
+
|
|
215
|
+
/** @copydoc nk_dot_f64 */
|
|
216
|
+
NK_PUBLIC void nk_dot_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
217
|
+
/** @copydoc nk_dot_f64c */
|
|
218
|
+
NK_PUBLIC void nk_dot_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
219
|
+
/** @copydoc nk_vdot_f64c */
|
|
220
|
+
NK_PUBLIC void nk_vdot_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
221
|
+
|
|
222
|
+
/** @copydoc nk_dot_f32 */
|
|
223
|
+
NK_PUBLIC void nk_dot_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
224
|
+
/** @copydoc nk_dot_f32c */
|
|
225
|
+
NK_PUBLIC void nk_dot_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
226
|
+
/** @copydoc nk_vdot_f32c */
|
|
227
|
+
NK_PUBLIC void nk_vdot_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
228
|
+
|
|
229
|
+
/** @copydoc nk_dot_f16 */
|
|
230
|
+
NK_PUBLIC void nk_dot_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
231
|
+
/** @copydoc nk_dot_f16c */
|
|
232
|
+
NK_PUBLIC void nk_dot_f16c_serial(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
233
|
+
/** @copydoc nk_vdot_f16c */
|
|
234
|
+
NK_PUBLIC void nk_vdot_f16c_serial(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
235
|
+
|
|
236
|
+
/** @copydoc nk_dot_bf16 */
|
|
237
|
+
NK_PUBLIC void nk_dot_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
238
|
+
/** @copydoc nk_dot_bf16c */
|
|
239
|
+
NK_PUBLIC void nk_dot_bf16c_serial(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
240
|
+
/** @copydoc nk_vdot_bf16c */
|
|
241
|
+
NK_PUBLIC void nk_vdot_bf16c_serial(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
242
|
+
|
|
243
|
+
/** @copydoc nk_dot_i8 */
|
|
244
|
+
NK_PUBLIC void nk_dot_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
245
|
+
/** @copydoc nk_dot_u8 */
|
|
246
|
+
NK_PUBLIC void nk_dot_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
247
|
+
/** @copydoc nk_dot_i4 */
|
|
248
|
+
NK_PUBLIC void nk_dot_i4_serial(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
|
|
249
|
+
/** @copydoc nk_dot_u4 */
|
|
250
|
+
NK_PUBLIC void nk_dot_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
251
|
+
/** @copydoc nk_dot_u1 */
|
|
252
|
+
NK_PUBLIC void nk_dot_u1_serial(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
|
|
253
|
+
|
|
254
|
+
/** @copydoc nk_dot_e4m3 */
|
|
255
|
+
NK_PUBLIC void nk_dot_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
256
|
+
/** @copydoc nk_dot_e5m2 */
|
|
257
|
+
NK_PUBLIC void nk_dot_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
258
|
+
/** @copydoc nk_dot_e2m3 */
|
|
259
|
+
NK_PUBLIC void nk_dot_e2m3_serial(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
260
|
+
/** @copydoc nk_dot_e3m2 */
|
|
261
|
+
NK_PUBLIC void nk_dot_e3m2_serial(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
262
|
+
|
|
263
|
+
#if NK_TARGET_NEON
|
|
264
|
+
/** @copydoc nk_dot_f32 */
|
|
265
|
+
NK_PUBLIC void nk_dot_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
266
|
+
/** @copydoc nk_dot_f32c */
|
|
267
|
+
NK_PUBLIC void nk_dot_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
268
|
+
/** @copydoc nk_vdot_f32c */
|
|
269
|
+
NK_PUBLIC void nk_vdot_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
270
|
+
|
|
271
|
+
/** @copydoc nk_dot_f64 */
|
|
272
|
+
NK_PUBLIC void nk_dot_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
273
|
+
/** @copydoc nk_dot_f64c */
|
|
274
|
+
NK_PUBLIC void nk_dot_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
275
|
+
/** @copydoc nk_vdot_f64c */
|
|
276
|
+
NK_PUBLIC void nk_vdot_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
277
|
+
|
|
278
|
+
/** @copydoc nk_dot_bf16 */
|
|
279
|
+
NK_PUBLIC void nk_dot_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
280
|
+
|
|
281
|
+
/** @copydoc nk_dot_e4m3 */
|
|
282
|
+
NK_PUBLIC void nk_dot_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
283
|
+
/** @copydoc nk_dot_e5m2 */
|
|
284
|
+
NK_PUBLIC void nk_dot_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
285
|
+
/** @copydoc nk_dot_e2m3 */
|
|
286
|
+
NK_PUBLIC void nk_dot_e2m3_neon(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
287
|
+
/** @copydoc nk_dot_e3m2 */
|
|
288
|
+
NK_PUBLIC void nk_dot_e3m2_neon(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
289
|
+
|
|
290
|
+
/** @copydoc nk_dot_u1 */
|
|
291
|
+
NK_PUBLIC void nk_dot_u1_neon(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
|
|
292
|
+
|
|
293
|
+
/** @copydoc nk_dot_f16 */
|
|
294
|
+
NK_PUBLIC void nk_dot_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
295
|
+
|
|
296
|
+
#endif // NK_TARGET_NEON
|
|
297
|
+
|
|
298
|
+
#if NK_TARGET_NEONHALF
|
|
299
|
+
/** @copydoc nk_dot_f16 */
|
|
300
|
+
NK_PUBLIC void nk_dot_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
301
|
+
/** @copydoc nk_dot_f16c */
|
|
302
|
+
NK_PUBLIC void nk_dot_f16c_neonhalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
303
|
+
/** @copydoc nk_vdot_f16c */
|
|
304
|
+
NK_PUBLIC void nk_vdot_f16c_neonhalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
305
|
+
#endif // NK_TARGET_NEONHALF
|
|
306
|
+
|
|
307
|
+
#if NK_TARGET_NEONFHM
|
|
308
|
+
/** @copydoc nk_dot_f16 */
|
|
309
|
+
NK_PUBLIC void nk_dot_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
310
|
+
/** @copydoc nk_dot_e4m3 */
|
|
311
|
+
NK_PUBLIC void nk_dot_e4m3_neonfhm(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
312
|
+
/** @copydoc nk_dot_e5m2 */
|
|
313
|
+
NK_PUBLIC void nk_dot_e5m2_neonfhm(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
314
|
+
/** @copydoc nk_dot_f16c */
|
|
315
|
+
NK_PUBLIC void nk_dot_f16c_neonfhm(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
316
|
+
/** @copydoc nk_vdot_f16c */
|
|
317
|
+
NK_PUBLIC void nk_vdot_f16c_neonfhm(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
318
|
+
#endif // NK_TARGET_NEONFHM
|
|
319
|
+
|
|
320
|
+
#if NK_TARGET_NEONSDOT
|
|
321
|
+
/** @copydoc nk_dot_i8 */
|
|
322
|
+
NK_PUBLIC void nk_dot_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
323
|
+
/** @copydoc nk_dot_u8 */
|
|
324
|
+
NK_PUBLIC void nk_dot_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
325
|
+
/** @copydoc nk_dot_i4 */
|
|
326
|
+
NK_PUBLIC void nk_dot_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
|
|
327
|
+
/** @copydoc nk_dot_u4 */
|
|
328
|
+
NK_PUBLIC void nk_dot_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
329
|
+
/** @copydoc nk_dot_e2m3 */
|
|
330
|
+
NK_PUBLIC void nk_dot_e2m3_neonsdot(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
331
|
+
/** @copydoc nk_dot_e3m2 */
|
|
332
|
+
NK_PUBLIC void nk_dot_e3m2_neonsdot(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
333
|
+
#endif // NK_TARGET_NEONSDOT
|
|
334
|
+
|
|
335
|
+
#if NK_TARGET_NEONBFDOT
|
|
336
|
+
/** @copydoc nk_dot_bf16 */
|
|
337
|
+
NK_PUBLIC void nk_dot_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
338
|
+
/** @copydoc nk_dot_e4m3 */
|
|
339
|
+
NK_PUBLIC void nk_dot_e4m3_neonbfdot(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
340
|
+
/** @copydoc nk_dot_e5m2 */
|
|
341
|
+
NK_PUBLIC void nk_dot_e5m2_neonbfdot(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
342
|
+
/** @copydoc nk_dot_bf16c */
|
|
343
|
+
NK_PUBLIC void nk_dot_bf16c_neonbfdot(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
344
|
+
/** @copydoc nk_vdot_bf16c */
|
|
345
|
+
NK_PUBLIC void nk_vdot_bf16c_neonbfdot(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
346
|
+
#endif // NK_TARGET_NEONBFDOT
|
|
347
|
+
|
|
348
|
+
#if NK_TARGET_SVEBFDOT
|
|
349
|
+
/** @copydoc nk_dot_bf16 */
|
|
350
|
+
NK_PUBLIC void nk_dot_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
351
|
+
#endif // NK_TARGET_SVEBFDOT
|
|
352
|
+
|
|
353
|
+
#if NK_TARGET_SVE
|
|
354
|
+
/** @copydoc nk_dot_f32 */
|
|
355
|
+
NK_PUBLIC void nk_dot_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
356
|
+
/** @copydoc nk_dot_f32c */
|
|
357
|
+
NK_PUBLIC void nk_dot_f32c_sve(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
358
|
+
/** @copydoc nk_vdot_f32c */
|
|
359
|
+
NK_PUBLIC void nk_vdot_f32c_sve(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
360
|
+
/** @copydoc nk_dot_f64 */
|
|
361
|
+
NK_PUBLIC void nk_dot_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
362
|
+
/** @copydoc nk_dot_f64c */
|
|
363
|
+
NK_PUBLIC void nk_dot_f64c_sve(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
364
|
+
/** @copydoc nk_vdot_f64c */
|
|
365
|
+
NK_PUBLIC void nk_vdot_f64c_sve(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
366
|
+
#endif // NK_TARGET_SVE
|
|
367
|
+
|
|
368
|
+
#if NK_TARGET_SVEHALF
|
|
369
|
+
/** @copydoc nk_dot_f16 */
|
|
370
|
+
NK_PUBLIC void nk_dot_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
371
|
+
/** @copydoc nk_dot_f16c */
|
|
372
|
+
NK_PUBLIC void nk_dot_f16c_svehalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
373
|
+
/** @copydoc nk_vdot_f16c */
|
|
374
|
+
NK_PUBLIC void nk_vdot_f16c_svehalf(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
375
|
+
#endif // NK_TARGET_SVEHALF
|
|
376
|
+
|
|
377
|
+
#if NK_TARGET_HASWELL
|
|
378
|
+
/** @copydoc nk_dot_f32 */
|
|
379
|
+
NK_PUBLIC void nk_dot_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
380
|
+
/** @copydoc nk_dot_f64 */
|
|
381
|
+
NK_PUBLIC void nk_dot_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
382
|
+
/** @copydoc nk_dot_f32c */
|
|
383
|
+
NK_PUBLIC void nk_dot_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
384
|
+
/** @copydoc nk_vdot_f32c */
|
|
385
|
+
NK_PUBLIC void nk_vdot_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
386
|
+
/** @copydoc nk_dot_f64c */
|
|
387
|
+
NK_PUBLIC void nk_dot_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
388
|
+
/** @copydoc nk_vdot_f64c */
|
|
389
|
+
NK_PUBLIC void nk_vdot_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
390
|
+
|
|
391
|
+
/** @copydoc nk_dot_f16 */
|
|
392
|
+
NK_PUBLIC void nk_dot_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
393
|
+
/** @copydoc nk_dot_f16c */
|
|
394
|
+
NK_PUBLIC void nk_dot_f16c_haswell(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
395
|
+
/** @copydoc nk_vdot_f16c */
|
|
396
|
+
NK_PUBLIC void nk_vdot_f16c_haswell(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
397
|
+
|
|
398
|
+
/** @copydoc nk_dot_bf16 */
|
|
399
|
+
NK_PUBLIC void nk_dot_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
400
|
+
/** @copydoc nk_dot_bf16c */
|
|
401
|
+
NK_PUBLIC void nk_dot_bf16c_haswell(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
402
|
+
/** @copydoc nk_vdot_bf16c */
|
|
403
|
+
NK_PUBLIC void nk_vdot_bf16c_haswell(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
404
|
+
|
|
405
|
+
/** @copydoc nk_dot_e4m3 */
|
|
406
|
+
NK_PUBLIC void nk_dot_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
407
|
+
/** @copydoc nk_dot_e5m2 */
|
|
408
|
+
NK_PUBLIC void nk_dot_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
409
|
+
/** @copydoc nk_dot_e2m3 */
|
|
410
|
+
NK_PUBLIC void nk_dot_e2m3_haswell(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
411
|
+
/** @copydoc nk_dot_e3m2 */
|
|
412
|
+
NK_PUBLIC void nk_dot_e3m2_haswell(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
413
|
+
|
|
414
|
+
/** @copydoc nk_dot_i8 */
|
|
415
|
+
NK_PUBLIC void nk_dot_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
416
|
+
/** @copydoc nk_dot_u8 */
|
|
417
|
+
NK_PUBLIC void nk_dot_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
418
|
+
/** @copydoc nk_dot_i4 */
|
|
419
|
+
NK_PUBLIC void nk_dot_i4_haswell(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
|
|
420
|
+
/** @copydoc nk_dot_u4 */
|
|
421
|
+
NK_PUBLIC void nk_dot_u4_haswell(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
422
|
+
/** @copydoc nk_dot_u1 */
|
|
423
|
+
NK_PUBLIC void nk_dot_u1_haswell(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
|
|
424
|
+
|
|
425
|
+
#endif // NK_TARGET_HASWELL
|
|
426
|
+
|
|
427
|
+
#if NK_TARGET_SKYLAKE
|
|
428
|
+
/** @copydoc nk_dot_f64 */
|
|
429
|
+
NK_PUBLIC void nk_dot_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
430
|
+
/** @copydoc nk_dot_f64c */
|
|
431
|
+
NK_PUBLIC void nk_dot_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
432
|
+
/** @copydoc nk_vdot_f64c */
|
|
433
|
+
NK_PUBLIC void nk_vdot_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
434
|
+
|
|
435
|
+
/** @copydoc nk_dot_f32 */
|
|
436
|
+
NK_PUBLIC void nk_dot_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
437
|
+
/** @copydoc nk_dot_f32c */
|
|
438
|
+
NK_PUBLIC void nk_dot_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
439
|
+
/** @copydoc nk_vdot_f32c */
|
|
440
|
+
NK_PUBLIC void nk_vdot_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
441
|
+
|
|
442
|
+
/** @copydoc nk_dot_f16 */
|
|
443
|
+
NK_PUBLIC void nk_dot_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
444
|
+
/** @copydoc nk_dot_bf16 */
|
|
445
|
+
NK_PUBLIC void nk_dot_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
446
|
+
|
|
447
|
+
/** @copydoc nk_dot_e4m3 */
|
|
448
|
+
NK_PUBLIC void nk_dot_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
449
|
+
/** @copydoc nk_dot_e5m2 */
|
|
450
|
+
NK_PUBLIC void nk_dot_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
451
|
+
/** @copydoc nk_dot_e2m3 */
|
|
452
|
+
NK_PUBLIC void nk_dot_e2m3_skylake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
453
|
+
/** @copydoc nk_dot_e3m2 */
|
|
454
|
+
NK_PUBLIC void nk_dot_e3m2_skylake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
455
|
+
|
|
456
|
+
/** @copydoc nk_dot_i8 */
|
|
457
|
+
NK_PUBLIC void nk_dot_i8_skylake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
458
|
+
/** @copydoc nk_dot_u8 */
|
|
459
|
+
NK_PUBLIC void nk_dot_u8_skylake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
460
|
+
#endif // NK_TARGET_SKYLAKE
|
|
461
|
+
|
|
462
|
+
#if NK_TARGET_ICELAKE
|
|
463
|
+
/** @copydoc nk_dot_i8 */
|
|
464
|
+
NK_PUBLIC void nk_dot_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
465
|
+
/** @copydoc nk_dot_u8 */
|
|
466
|
+
NK_PUBLIC void nk_dot_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
467
|
+
/** @copydoc nk_dot_i4 */
|
|
468
|
+
NK_PUBLIC void nk_dot_i4_icelake(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
|
|
469
|
+
/** @copydoc nk_dot_u4 */
|
|
470
|
+
NK_PUBLIC void nk_dot_u4_icelake(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
471
|
+
/** @copydoc nk_dot_e2m3 */
|
|
472
|
+
NK_PUBLIC void nk_dot_e2m3_icelake(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
473
|
+
/** @copydoc nk_dot_e3m2 */
|
|
474
|
+
NK_PUBLIC void nk_dot_e3m2_icelake(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
475
|
+
/** @copydoc nk_dot_u1 */
|
|
476
|
+
NK_PUBLIC void nk_dot_u1_icelake(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
|
|
477
|
+
#endif // NK_TARGET_ICELAKE
|
|
478
|
+
|
|
479
|
+
#if NK_TARGET_GENOA
|
|
480
|
+
/** @copydoc nk_dot_bf16 */
|
|
481
|
+
NK_PUBLIC void nk_dot_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
482
|
+
/** @copydoc nk_dot_bf16c */
|
|
483
|
+
NK_PUBLIC void nk_dot_bf16c_genoa(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
484
|
+
/** @copydoc nk_vdot_bf16c */
|
|
485
|
+
NK_PUBLIC void nk_vdot_bf16c_genoa(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
486
|
+
|
|
487
|
+
/** @copydoc nk_dot_e4m3 */
|
|
488
|
+
NK_PUBLIC void nk_dot_e4m3_genoa(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
489
|
+
/** @copydoc nk_dot_e5m2 */
|
|
490
|
+
NK_PUBLIC void nk_dot_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
491
|
+
#endif // NK_TARGET_GENOA
|
|
492
|
+
|
|
493
|
+
#if NK_TARGET_ALDER
|
|
494
|
+
/** @copydoc nk_dot_i8 */
|
|
495
|
+
NK_PUBLIC void nk_dot_i8_alder(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
496
|
+
/** @copydoc nk_dot_u8 */
|
|
497
|
+
NK_PUBLIC void nk_dot_u8_alder(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
498
|
+
/** @copydoc nk_dot_e2m3 */
|
|
499
|
+
NK_PUBLIC void nk_dot_e2m3_alder(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
500
|
+
#endif // NK_TARGET_ALDER
|
|
501
|
+
|
|
502
|
+
#if NK_TARGET_SIERRA
|
|
503
|
+
/** @copydoc nk_dot_i8 */
|
|
504
|
+
NK_PUBLIC void nk_dot_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
505
|
+
/** @copydoc nk_dot_u8 */
|
|
506
|
+
NK_PUBLIC void nk_dot_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
507
|
+
/** @copydoc nk_dot_e2m3 */
|
|
508
|
+
NK_PUBLIC void nk_dot_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
509
|
+
#endif // NK_TARGET_SIERRA
|
|
510
|
+
|
|
511
|
+
#if NK_TARGET_RVV
|
|
512
|
+
/** @copydoc nk_dot_f32 */
|
|
513
|
+
NK_PUBLIC void nk_dot_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
514
|
+
/** @copydoc nk_dot_f64 */
|
|
515
|
+
NK_PUBLIC void nk_dot_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
516
|
+
/** @copydoc nk_dot_f16 */
|
|
517
|
+
NK_PUBLIC void nk_dot_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
518
|
+
/** @copydoc nk_dot_bf16 */
|
|
519
|
+
NK_PUBLIC void nk_dot_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
520
|
+
/** @copydoc nk_dot_i8 */
|
|
521
|
+
NK_PUBLIC void nk_dot_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
522
|
+
/** @copydoc nk_dot_u8 */
|
|
523
|
+
NK_PUBLIC void nk_dot_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
524
|
+
/** @copydoc nk_dot_e4m3 */
|
|
525
|
+
NK_PUBLIC void nk_dot_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
526
|
+
/** @copydoc nk_dot_e5m2 */
|
|
527
|
+
NK_PUBLIC void nk_dot_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
528
|
+
/** @copydoc nk_dot_e2m3 */
|
|
529
|
+
NK_PUBLIC void nk_dot_e2m3_rvv(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
530
|
+
/** @copydoc nk_dot_e3m2 */
|
|
531
|
+
NK_PUBLIC void nk_dot_e3m2_rvv(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
532
|
+
/** @copydoc nk_dot_i4 */
|
|
533
|
+
NK_PUBLIC void nk_dot_i4_rvv(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
|
|
534
|
+
/** @copydoc nk_dot_u4 */
|
|
535
|
+
NK_PUBLIC void nk_dot_u4_rvv(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
536
|
+
/** @copydoc nk_dot_u1 */
|
|
537
|
+
NK_PUBLIC void nk_dot_u1_rvv(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
|
|
538
|
+
/** @copydoc nk_dot_f32c */
|
|
539
|
+
NK_PUBLIC void nk_dot_f32c_rvv(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
540
|
+
/** @copydoc nk_vdot_f32c */
|
|
541
|
+
NK_PUBLIC void nk_vdot_f32c_rvv(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
542
|
+
/** @copydoc nk_dot_f64c */
|
|
543
|
+
NK_PUBLIC void nk_dot_f64c_rvv(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
544
|
+
/** @copydoc nk_vdot_f64c */
|
|
545
|
+
NK_PUBLIC void nk_vdot_f64c_rvv(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
546
|
+
#endif // NK_TARGET_RVV
|
|
547
|
+
|
|
548
|
+
#if NK_TARGET_RVVHALF
|
|
549
|
+
/** @copydoc nk_dot_f16 */
|
|
550
|
+
NK_PUBLIC void nk_dot_f16_rvvhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
551
|
+
/** @copydoc nk_dot_e4m3 */
|
|
552
|
+
NK_PUBLIC void nk_dot_e4m3_rvvhalf(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
553
|
+
/** @copydoc nk_dot_e5m2 */
|
|
554
|
+
NK_PUBLIC void nk_dot_e5m2_rvvhalf(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
555
|
+
#endif // NK_TARGET_RVVHALF
|
|
556
|
+
|
|
557
|
+
#if NK_TARGET_RVVBF16
|
|
558
|
+
/** @copydoc nk_dot_bf16 */
|
|
559
|
+
NK_PUBLIC void nk_dot_bf16_rvvbf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
560
|
+
/** @copydoc nk_dot_e4m3 */
|
|
561
|
+
NK_PUBLIC void nk_dot_e4m3_rvvbf16(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
562
|
+
/** @copydoc nk_dot_e5m2 */
|
|
563
|
+
NK_PUBLIC void nk_dot_e5m2_rvvbf16(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
564
|
+
#endif // NK_TARGET_RVVBF16
|
|
565
|
+
|
|
566
|
+
#if NK_TARGET_RVVBB
|
|
567
|
+
/** @copydoc nk_dot_u1 */
|
|
568
|
+
NK_PUBLIC void nk_dot_u1_rvvbb(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
|
|
569
|
+
#endif // NK_TARGET_RVVBB
|
|
570
|
+
|
|
571
|
+
#if NK_TARGET_V128RELAXED
|
|
572
|
+
/** @copydoc nk_dot_f32 */
|
|
573
|
+
NK_PUBLIC void nk_dot_f32_v128relaxed(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result);
|
|
574
|
+
/** @copydoc nk_dot_f64 */
|
|
575
|
+
NK_PUBLIC void nk_dot_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
576
|
+
/** @copydoc nk_dot_f16 */
|
|
577
|
+
NK_PUBLIC void nk_dot_f16_v128relaxed(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
578
|
+
/** @copydoc nk_dot_bf16 */
|
|
579
|
+
NK_PUBLIC void nk_dot_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
|
|
580
|
+
/** @copydoc nk_dot_i8 */
|
|
581
|
+
NK_PUBLIC void nk_dot_i8_v128relaxed(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result);
|
|
582
|
+
/** @copydoc nk_dot_u8 */
|
|
583
|
+
NK_PUBLIC void nk_dot_u8_v128relaxed(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result);
|
|
584
|
+
/** @copydoc nk_dot_e2m3 */
|
|
585
|
+
NK_PUBLIC void nk_dot_e2m3_v128relaxed(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
586
|
+
/** @copydoc nk_dot_e3m2 */
|
|
587
|
+
NK_PUBLIC void nk_dot_e3m2_v128relaxed(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
588
|
+
/** @copydoc nk_dot_u1 */
|
|
589
|
+
NK_PUBLIC void nk_dot_u1_v128relaxed(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result);
|
|
590
|
+
/** @copydoc nk_dot_e4m3 */
|
|
591
|
+
NK_PUBLIC void nk_dot_e4m3_v128relaxed(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result);
|
|
592
|
+
/** @copydoc nk_dot_e5m2 */
|
|
593
|
+
NK_PUBLIC void nk_dot_e5m2_v128relaxed(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
|
|
594
|
+
/** @copydoc nk_dot_i4 */
|
|
595
|
+
NK_PUBLIC void nk_dot_i4_v128relaxed(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result);
|
|
596
|
+
/** @copydoc nk_dot_u4 */
|
|
597
|
+
NK_PUBLIC void nk_dot_u4_v128relaxed(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result);
|
|
598
|
+
/** @copydoc nk_dot_f32c */
|
|
599
|
+
NK_PUBLIC void nk_dot_f32c_v128relaxed(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
600
|
+
/** @copydoc nk_vdot_f32c */
|
|
601
|
+
NK_PUBLIC void nk_vdot_f32c_v128relaxed(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
602
|
+
/** @copydoc nk_dot_f64c */
|
|
603
|
+
NK_PUBLIC void nk_dot_f64c_v128relaxed(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
604
|
+
/** @copydoc nk_vdot_f64c */
|
|
605
|
+
NK_PUBLIC void nk_vdot_f64c_v128relaxed(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
606
|
+
#endif // NK_TARGET_V128RELAXED
|
|
607
|
+
|
|
608
|
+
/**
|
|
609
|
+
* @brief Returns the output dtype for dot products.
|
|
610
|
+
*/
|
|
611
|
+
NK_INTERNAL nk_dtype_t nk_dot_output_dtype(nk_dtype_t dtype) {
|
|
612
|
+
switch (dtype) {
|
|
613
|
+
case nk_f64_k: return nk_f64_k;
|
|
614
|
+
case nk_f32_k: return nk_f64_k;
|
|
615
|
+
case nk_f16_k: return nk_f32_k;
|
|
616
|
+
case nk_bf16_k: return nk_f32_k;
|
|
617
|
+
case nk_e4m3_k: return nk_f32_k;
|
|
618
|
+
case nk_e5m2_k: return nk_f32_k;
|
|
619
|
+
case nk_e2m3_k: return nk_f32_k;
|
|
620
|
+
case nk_e3m2_k: return nk_f32_k;
|
|
621
|
+
case nk_f64c_k: return nk_f64c_k;
|
|
622
|
+
case nk_f32c_k: return nk_f64c_k;
|
|
623
|
+
case nk_f16c_k: return nk_f32c_k;
|
|
624
|
+
case nk_bf16c_k: return nk_f32c_k;
|
|
625
|
+
case nk_i8_k: return nk_i32_k;
|
|
626
|
+
case nk_u8_k: return nk_u32_k;
|
|
627
|
+
case nk_i4_k: return nk_i32_k;
|
|
628
|
+
case nk_u4_k: return nk_u32_k;
|
|
629
|
+
case nk_u1_k: return nk_u32_k;
|
|
630
|
+
default: return nk_dtype_unknown_k;
|
|
631
|
+
}
|
|
632
|
+
}
|
|
633
|
+
|
|
634
|
+
#if defined(__cplusplus)
|
|
635
|
+
} // extern "C"
|
|
636
|
+
#endif
|
|
637
|
+
|
|
638
|
+
#include "numkong/dot/serial.h"
|
|
639
|
+
#include "numkong/dot/neon.h"
|
|
640
|
+
#include "numkong/dot/neonsdot.h"
|
|
641
|
+
#include "numkong/dot/neonhalf.h"
|
|
642
|
+
#include "numkong/dot/neonfhm.h"
|
|
643
|
+
#include "numkong/dot/neonbfdot.h"
|
|
644
|
+
#include "numkong/dot/sve.h"
|
|
645
|
+
#include "numkong/dot/svehalf.h"
|
|
646
|
+
#include "numkong/dot/svebfdot.h"
|
|
647
|
+
#include "numkong/dot/haswell.h"
|
|
648
|
+
#include "numkong/dot/skylake.h"
|
|
649
|
+
#include "numkong/dot/icelake.h"
|
|
650
|
+
#include "numkong/dot/genoa.h"
|
|
651
|
+
#include "numkong/dot/sapphire.h"
|
|
652
|
+
#include "numkong/dot/alder.h"
|
|
653
|
+
#include "numkong/dot/sierra.h"
|
|
654
|
+
#include "numkong/dot/rvv.h"
|
|
655
|
+
#include "numkong/dot/rvvbb.h"
|
|
656
|
+
#include "numkong/dot/rvvhalf.h"
|
|
657
|
+
#include "numkong/dot/rvvbf16.h"
|
|
658
|
+
#include "numkong/dot/v128relaxed.h"
|
|
659
|
+
|
|
660
|
+
#if defined(__cplusplus)
|
|
661
|
+
extern "C" {
|
|
662
|
+
#endif
|
|
663
|
+
|
|
664
|
+
#if !NK_DYNAMIC_DISPATCH
|
|
665
|
+
|
|
666
|
+
/**
 *  @brief Compile-time dispatcher for the `i8` dot product.
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i32_t *result) {
#if NK_TARGET_V128RELAXED
#define NK_DOT_I8_BACKEND_ nk_dot_i8_v128relaxed
#elif NK_TARGET_RVV
#define NK_DOT_I8_BACKEND_ nk_dot_i8_rvv
#elif NK_TARGET_NEONSDOT
#define NK_DOT_I8_BACKEND_ nk_dot_i8_neonsdot
#elif NK_TARGET_ICELAKE
#define NK_DOT_I8_BACKEND_ nk_dot_i8_icelake
#elif NK_TARGET_SKYLAKE
#define NK_DOT_I8_BACKEND_ nk_dot_i8_skylake
#elif NK_TARGET_SIERRA
#define NK_DOT_I8_BACKEND_ nk_dot_i8_sierra
#elif NK_TARGET_ALDER
#define NK_DOT_I8_BACKEND_ nk_dot_i8_alder
#elif NK_TARGET_HASWELL
#define NK_DOT_I8_BACKEND_ nk_dot_i8_haswell
#else
#define NK_DOT_I8_BACKEND_ nk_dot_i8_serial
#endif
    NK_DOT_I8_BACKEND_(a, b, n, result);
#undef NK_DOT_I8_BACKEND_
}
|
|
687
|
+
|
|
688
|
+
/**
 *  @brief Compile-time dispatcher for the `u8` dot product.
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
#if NK_TARGET_V128RELAXED
#define NK_DOT_U8_BACKEND_ nk_dot_u8_v128relaxed
#elif NK_TARGET_RVV
#define NK_DOT_U8_BACKEND_ nk_dot_u8_rvv
#elif NK_TARGET_NEONSDOT
#define NK_DOT_U8_BACKEND_ nk_dot_u8_neonsdot
#elif NK_TARGET_ICELAKE
#define NK_DOT_U8_BACKEND_ nk_dot_u8_icelake
#elif NK_TARGET_SKYLAKE
#define NK_DOT_U8_BACKEND_ nk_dot_u8_skylake
#elif NK_TARGET_SIERRA
#define NK_DOT_U8_BACKEND_ nk_dot_u8_sierra
#elif NK_TARGET_ALDER
#define NK_DOT_U8_BACKEND_ nk_dot_u8_alder
#elif NK_TARGET_HASWELL
#define NK_DOT_U8_BACKEND_ nk_dot_u8_haswell
#else
#define NK_DOT_U8_BACKEND_ nk_dot_u8_serial
#endif
    NK_DOT_U8_BACKEND_(a, b, n, result);
#undef NK_DOT_U8_BACKEND_
}
|
|
709
|
+
|
|
710
|
+
/**
 *  @brief Compile-time dispatcher for the packed `i4` dot product.
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_i4(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_i32_t *result) {
#if NK_TARGET_ICELAKE
#define NK_DOT_I4_BACKEND_ nk_dot_i4_icelake
#elif NK_TARGET_NEONSDOT
#define NK_DOT_I4_BACKEND_ nk_dot_i4_neonsdot
#elif NK_TARGET_RVV
#define NK_DOT_I4_BACKEND_ nk_dot_i4_rvv
#elif NK_TARGET_HASWELL
#define NK_DOT_I4_BACKEND_ nk_dot_i4_haswell
#elif NK_TARGET_V128RELAXED
#define NK_DOT_I4_BACKEND_ nk_dot_i4_v128relaxed
#else
#define NK_DOT_I4_BACKEND_ nk_dot_i4_serial
#endif
    NK_DOT_I4_BACKEND_(a, b, n, result);
#undef NK_DOT_I4_BACKEND_
}
|
|
725
|
+
|
|
726
|
+
/**
 *  @brief Compile-time dispatcher for the packed `u4` dot product.
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_u4(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
#if NK_TARGET_ICELAKE
#define NK_DOT_U4_BACKEND_ nk_dot_u4_icelake
#elif NK_TARGET_NEONSDOT
#define NK_DOT_U4_BACKEND_ nk_dot_u4_neonsdot
#elif NK_TARGET_RVV
#define NK_DOT_U4_BACKEND_ nk_dot_u4_rvv
#elif NK_TARGET_HASWELL
#define NK_DOT_U4_BACKEND_ nk_dot_u4_haswell
#elif NK_TARGET_V128RELAXED
#define NK_DOT_U4_BACKEND_ nk_dot_u4_v128relaxed
#else
#define NK_DOT_U4_BACKEND_ nk_dot_u4_serial
#endif
    NK_DOT_U4_BACKEND_(a, b, n, result);
#undef NK_DOT_U4_BACKEND_
}
|
|
741
|
+
|
|
742
|
+
/**
 *  @brief Compile-time dispatcher for the bit-packed `u1` dot product.
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_u1(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
#if NK_TARGET_ICELAKE
#define NK_DOT_U1_BACKEND_ nk_dot_u1_icelake
#elif NK_TARGET_HASWELL
#define NK_DOT_U1_BACKEND_ nk_dot_u1_haswell
#elif NK_TARGET_V128RELAXED
#define NK_DOT_U1_BACKEND_ nk_dot_u1_v128relaxed
#elif NK_TARGET_RVVBB
#define NK_DOT_U1_BACKEND_ nk_dot_u1_rvvbb
#elif NK_TARGET_RVV
#define NK_DOT_U1_BACKEND_ nk_dot_u1_rvv
#elif NK_TARGET_NEON
#define NK_DOT_U1_BACKEND_ nk_dot_u1_neon
#else
#define NK_DOT_U1_BACKEND_ nk_dot_u1_serial
#endif
    NK_DOT_U1_BACKEND_(a, b, n_bits, result);
#undef NK_DOT_U1_BACKEND_
}
|
|
759
|
+
|
|
760
|
+
/**
 *  @brief Compile-time dispatcher for the `f16` dot product (result written as `nk_f32_t`).
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_V128RELAXED
#define NK_DOT_F16_BACKEND_ nk_dot_f16_v128relaxed
#elif NK_TARGET_RVVHALF
#define NK_DOT_F16_BACKEND_ nk_dot_f16_rvvhalf
#elif NK_TARGET_RVV
#define NK_DOT_F16_BACKEND_ nk_dot_f16_rvv
#elif NK_TARGET_SVEHALF
#define NK_DOT_F16_BACKEND_ nk_dot_f16_svehalf
#elif NK_TARGET_NEONFHM
#define NK_DOT_F16_BACKEND_ nk_dot_f16_neonfhm
#elif NK_TARGET_NEONHALF
#define NK_DOT_F16_BACKEND_ nk_dot_f16_neonhalf
#elif NK_TARGET_NEON
#define NK_DOT_F16_BACKEND_ nk_dot_f16_neon
#elif NK_TARGET_SKYLAKE
#define NK_DOT_F16_BACKEND_ nk_dot_f16_skylake
#elif NK_TARGET_HASWELL
#define NK_DOT_F16_BACKEND_ nk_dot_f16_haswell
#else
#define NK_DOT_F16_BACKEND_ nk_dot_f16_serial
#endif
    NK_DOT_F16_BACKEND_(a, b, n, result);
#undef NK_DOT_F16_BACKEND_
}
|
|
783
|
+
|
|
784
|
+
/**
 *  @brief Compile-time dispatcher for the `bf16` dot product (result written as `nk_f32_t`).
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_V128RELAXED
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_v128relaxed
#elif NK_TARGET_GENOA
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_genoa
#elif NK_TARGET_RVVBF16
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_rvvbf16
#elif NK_TARGET_RVV
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_rvv
#elif NK_TARGET_SKYLAKE
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_skylake
#elif NK_TARGET_HASWELL
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_haswell
#elif NK_TARGET_SVEBFDOT
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_svebfdot
#elif NK_TARGET_NEONBFDOT
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_neonbfdot
#elif NK_TARGET_NEON
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_neon
#else
#define NK_DOT_BF16_BACKEND_ nk_dot_bf16_serial
#endif
    NK_DOT_BF16_BACKEND_(a, b, n, result);
#undef NK_DOT_BF16_BACKEND_
}
|
|
807
|
+
|
|
808
|
+
/**
 *  @brief Compile-time dispatcher for the `e4m3` dot product (result written as `nk_f32_t`).
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_GENOA
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_genoa
#elif NK_TARGET_NEONBFDOT
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_neonbfdot
#elif NK_TARGET_NEONFHM
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_neonfhm
#elif NK_TARGET_RVVHALF
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_rvvhalf
#elif NK_TARGET_RVVBF16
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_rvvbf16
#elif NK_TARGET_RVV
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_v128relaxed
#elif NK_TARGET_SKYLAKE
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_skylake
#elif NK_TARGET_HASWELL
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_haswell
#elif NK_TARGET_NEON
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_neon
#else
#define NK_DOT_E4M3_BACKEND_ nk_dot_e4m3_serial
#endif
    NK_DOT_E4M3_BACKEND_(a, b, n, result);
#undef NK_DOT_E4M3_BACKEND_
}
|
|
833
|
+
|
|
834
|
+
/**
 *  @brief Compile-time dispatcher for the `e5m2` dot product (result written as `nk_f32_t`).
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_GENOA
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_genoa
#elif NK_TARGET_NEONBFDOT
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_neonbfdot
#elif NK_TARGET_NEONFHM
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_neonfhm
#elif NK_TARGET_RVVHALF
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_rvvhalf
#elif NK_TARGET_RVVBF16
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_rvvbf16
#elif NK_TARGET_RVV
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_v128relaxed
#elif NK_TARGET_SKYLAKE
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_skylake
#elif NK_TARGET_HASWELL
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_haswell
#elif NK_TARGET_NEON
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_neon
#else
#define NK_DOT_E5M2_BACKEND_ nk_dot_e5m2_serial
#endif
    NK_DOT_E5M2_BACKEND_(a, b, n, result);
#undef NK_DOT_E5M2_BACKEND_
}
|
|
859
|
+
|
|
860
|
+
/**
 *  @brief Compile-time dispatcher for the `e2m3` dot product (result written as `nk_f32_t`).
 *
 *  Forwards all arguments unchanged to the first backend below whose `NK_TARGET_*` flag is
 *  non-zero; the order of the chain is the priority order. Falls back to the serial kernel.
 */
NK_PUBLIC void nk_dot_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_ICELAKE
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_icelake
#elif NK_TARGET_SKYLAKE
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_skylake
#elif NK_TARGET_SIERRA
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_sierra
#elif NK_TARGET_ALDER
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_alder
#elif NK_TARGET_RVV
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_rvv
#elif NK_TARGET_HASWELL
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_haswell
#elif NK_TARGET_NEONSDOT
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_neonsdot
#elif NK_TARGET_NEON
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_neon
#elif NK_TARGET_V128RELAXED
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_v128relaxed
#else
#define NK_DOT_E2M3_BACKEND_ nk_dot_e2m3_serial
#endif
    NK_DOT_E2M3_BACKEND_(a, b, n, result);
#undef NK_DOT_E2M3_BACKEND_
}
|
|
883
|
+
|
|
884
|
+
/**
 *  @brief Compile-time dispatcher for the `e3m2` `dot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: icelake, neonsdot, v128relaxed, rvv, skylake, haswell, neon, serial.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; presumably receives the `nk_f32_t` output.
 */
NK_PUBLIC void nk_dot_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
#if NK_TARGET_ICELAKE
    nk_dot_e3m2_icelake(a, b, n, result);
#elif NK_TARGET_NEONSDOT
    nk_dot_e3m2_neonsdot(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_dot_e3m2_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_dot_e3m2_rvv(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_dot_e3m2_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_dot_e3m2_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_dot_e3m2_neon(a, b, n, result);
#else
    nk_dot_e3m2_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
903
|
+
|
|
904
|
+
/**
 *  @brief Compile-time dispatcher for the `f32` `dot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: v128relaxed, rvv, sve, neon, skylake, haswell, serial.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; note the output type is `nk_f64_t`, wider than the inputs.
 */
NK_PUBLIC void nk_dot_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
#if NK_TARGET_V128RELAXED
    nk_dot_f32_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_dot_f32_rvv(a, b, n, result);
#elif NK_TARGET_SVE
    nk_dot_f32_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_dot_f32_neon(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_dot_f32_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_dot_f32_haswell(a, b, n, result);
#else
    nk_dot_f32_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
921
|
+
|
|
922
|
+
/**
 *  @brief Compile-time dispatcher for the `f64` `dot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: v128relaxed, rvv, sve, neon, skylake, haswell, serial.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; presumably receives the `nk_f64_t` output.
 */
NK_PUBLIC void nk_dot_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
#if NK_TARGET_V128RELAXED
    nk_dot_f64_v128relaxed(a, b, n, result);
#elif NK_TARGET_RVV
    nk_dot_f64_rvv(a, b, n, result);
#elif NK_TARGET_SVE
    nk_dot_f64_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_dot_f64_neon(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_dot_f64_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_dot_f64_haswell(a, b, n, result);
#else
    nk_dot_f64_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
939
|
+
|
|
940
|
+
/**
 *  @brief Compile-time dispatcher for the `f16c` `dot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: svehalf, neonfhm, neonhalf, haswell, serial.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; presumably receives the `nk_f32c_t` output.
 */
NK_PUBLIC void nk_dot_f16c(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result) {
#if NK_TARGET_SVEHALF
    nk_dot_f16c_svehalf(a, b, n, result);
#elif NK_TARGET_NEONFHM
    nk_dot_f16c_neonfhm(a, b, n, result);
#elif NK_TARGET_NEONHALF
    nk_dot_f16c_neonhalf(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_dot_f16c_haswell(a, b, n, result);
#else
    nk_dot_f16c_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
953
|
+
|
|
954
|
+
/**
 *  @brief Compile-time dispatcher for the `bf16c` `dot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: genoa, neonbfdot, haswell, serial.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; presumably receives the `nk_f32c_t` output.
 */
NK_PUBLIC void nk_dot_bf16c(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result) {
#if NK_TARGET_GENOA
    nk_dot_bf16c_genoa(a, b, n, result);
#elif NK_TARGET_NEONBFDOT
    nk_dot_bf16c_neonbfdot(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_dot_bf16c_haswell(a, b, n, result);
#else
    nk_dot_bf16c_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
965
|
+
|
|
966
|
+
/**
 *  @brief Compile-time dispatcher for the `f32c` `dot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: sve, neon, rvv, skylake, haswell, v128relaxed, serial.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; note the output type is `nk_f64c_t`, not `nk_f32c_t`.
 */
NK_PUBLIC void nk_dot_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result) {
#if NK_TARGET_SVE
    nk_dot_f32c_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_dot_f32c_neon(a, b, n, result);
#elif NK_TARGET_RVV
    nk_dot_f32c_rvv(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_dot_f32c_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_dot_f32c_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_dot_f32c_v128relaxed(a, b, n, result);
#else
    nk_dot_f32c_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
983
|
+
|
|
984
|
+
/**
 *  @brief Compile-time dispatcher for the `f64c` `dot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: sve, neon, rvv, skylake, haswell, v128relaxed, serial.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; presumably receives the `nk_f64c_t` output.
 */
NK_PUBLIC void nk_dot_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result) {
#if NK_TARGET_SVE
    nk_dot_f64c_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_dot_f64c_neon(a, b, n, result);
#elif NK_TARGET_RVV
    nk_dot_f64c_rvv(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_dot_f64c_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_dot_f64c_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_dot_f64c_v128relaxed(a, b, n, result);
#else
    nk_dot_f64c_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
1001
|
+
|
|
1002
|
+
/**
 *  @brief Compile-time dispatcher for the `f16c` `vdot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: svehalf, neonfhm, neonhalf, haswell, serial.
 *  NOTE(review): how `vdot` differs from `dot` is not visible here (presumably a
 *  conjugating variant) - confirm against the backend kernels.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; presumably receives the `nk_f32c_t` output.
 */
NK_PUBLIC void nk_vdot_f16c(nk_f16c_t const *a, nk_f16c_t const *b, nk_size_t n, nk_f32c_t *result) {
#if NK_TARGET_SVEHALF
    nk_vdot_f16c_svehalf(a, b, n, result);
#elif NK_TARGET_NEONFHM
    nk_vdot_f16c_neonfhm(a, b, n, result);
#elif NK_TARGET_NEONHALF
    nk_vdot_f16c_neonhalf(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_vdot_f16c_haswell(a, b, n, result);
#else
    nk_vdot_f16c_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
1015
|
+
|
|
1016
|
+
/**
 *  @brief Compile-time dispatcher for the `bf16c` `vdot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: genoa, neonbfdot, haswell, serial.
 *  NOTE(review): how `vdot` differs from `dot` is not visible here (presumably a
 *  conjugating variant) - confirm against the backend kernels.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; presumably receives the `nk_f32c_t` output.
 */
NK_PUBLIC void nk_vdot_bf16c(nk_bf16c_t const *a, nk_bf16c_t const *b, nk_size_t n, nk_f32c_t *result) {
#if NK_TARGET_GENOA
    nk_vdot_bf16c_genoa(a, b, n, result);
#elif NK_TARGET_NEONBFDOT
    nk_vdot_bf16c_neonbfdot(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_vdot_bf16c_haswell(a, b, n, result);
#else
    nk_vdot_bf16c_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
1027
|
+
|
|
1028
|
+
/**
 *  @brief Compile-time dispatcher for the `f32c` `vdot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: sve, neon, rvv, skylake, haswell, v128relaxed, serial.
 *  NOTE(review): how `vdot` differs from `dot` is not visible here (presumably a
 *  conjugating variant) - confirm against the backend kernels.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; note the output type is `nk_f64c_t`, not `nk_f32c_t`.
 */
NK_PUBLIC void nk_vdot_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f64c_t *result) {
#if NK_TARGET_SVE
    nk_vdot_f32c_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_vdot_f32c_neon(a, b, n, result);
#elif NK_TARGET_RVV
    nk_vdot_f32c_rvv(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_vdot_f32c_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_vdot_f32c_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_vdot_f32c_v128relaxed(a, b, n, result);
#else
    nk_vdot_f32c_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
1045
|
+
|
|
1046
|
+
/**
 *  @brief Compile-time dispatcher for the `f64c` `vdot` kernel.
 *
 *  Forwards its four arguments, unchanged, to exactly one backend: the first entry of the
 *  preprocessor chain below whose `NK_TARGET_*` macro evaluates to non-zero, or the
 *  `_serial` variant when none of them does.
 *  Priority: sve, neon, rvv, skylake, haswell, v128relaxed, serial.
 *  NOTE(review): how `vdot` differs from `dot` is not visible here (presumably a
 *  conjugating variant) - confirm against the backend kernels.
 *
 *  @param[in] a First input array, handed to the backend as-is.
 *  @param[in] b Second input array, handed to the backend as-is.
 *  @param[in] n Handed to the backend as-is; presumably the element count - confirm against the kernels.
 *  @param[out] result Handed to the backend as-is; presumably receives the `nk_f64c_t` output.
 */
NK_PUBLIC void nk_vdot_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result) {
#if NK_TARGET_SVE
    nk_vdot_f64c_sve(a, b, n, result);
#elif NK_TARGET_NEON
    nk_vdot_f64c_neon(a, b, n, result);
#elif NK_TARGET_RVV
    nk_vdot_f64c_rvv(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_vdot_f64c_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_vdot_f64c_haswell(a, b, n, result);
#elif NK_TARGET_V128RELAXED
    nk_vdot_f64c_v128relaxed(a, b, n, result);
#else
    nk_vdot_f64c_serial(a, b, n, result); // no enabled target matched
#endif
}
|
|
1063
|
+
|
|
1064
|
+
#endif // !NK_DYNAMIC_DISPATCH
|
|
1065
|
+
|
|
1066
|
+
#if defined(__cplusplus)
|
|
1067
|
+
} // extern "C"
|
|
1068
|
+
#endif
|
|
1069
|
+
|
|
1070
|
+
#endif
|