numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,2804 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Batched Dot Products.
|
|
3
|
+
* @file include/numkong/dots.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date September 14, 2024
|
|
6
|
+
*
|
|
7
|
+
* Implements batch dot-product kernels computing C[m × n] = A[m × k] × B[n × k]ᵀ
|
|
8
|
+
* with row-major A and arbitrary B, optimized for ML inference and similarity workloads.
|
|
9
|
+
*
|
|
10
|
+
* Primary Use Cases (1-to-N focus):
|
|
11
|
+
*
|
|
12
|
+
* - k-NN search: ‖a-b‖² = ‖a‖² + ‖b‖² - 2(a × b)
|
|
13
|
+
* - Cosine similarity: (a × b) / (‖a‖ × ‖b‖)
|
|
14
|
+
* - Sparse attention patterns
|
|
15
|
+
* - Embedding similarity matrices
|
|
16
|
+
* - k-means clustering, DBSCAN, hierarchical clustering
|
|
17
|
+
*
|
|
18
|
+
* It implements several operations:
|
|
19
|
+
*
|
|
20
|
+
* - "dots_packed" - computing dot-products where the B matrix is pre-packed into optimal form
|
|
21
|
+
* - "dots_packed_size" - estimating the memory requirements for an external `malloc`
|
|
22
|
+
* - "dots_pack" - performing the pre-processing (packing) of the B matrix
|
|
23
|
+
* - "dots_compact" - optional helpers to normalize or downcast into original precision
|
|
24
|
+
* - "dots_symmetric" - for A × Aᵀ Gram matrix multiplication
|
|
25
|
+
*
|
|
26
|
+
* If the original "dots_packed" is analogous to "GEMM" (General Matrix Multiplication) in BLAS,
|
|
27
|
+
* the "dots_symmetric" is similar to the "SYRK" (the Symmetric rank-k update of a matrix).
|
|
28
|
+
*
|
|
29
|
+
* For dtypes:
|
|
30
|
+
*
|
|
31
|
+
* - f64: 64-bit IEEE floating point numbers → 64-bit floats
|
|
32
|
+
* - f32: 32-bit IEEE floating point numbers → 64-bit floats
|
|
33
|
+
* - f16: 16-bit IEEE floating point numbers → 32-bit floats
|
|
34
|
+
* - bf16: 16-bit brain floating point numbers → 32-bit floats
|
|
35
|
+
* - e4m3: 8-bit e4m3 floating point numbers → 32-bit floats
|
|
36
|
+
* - e5m2: 8-bit e5m2 floating point numbers → 32-bit floats
|
|
37
|
+
* - e2m3: 8-bit e2m3 floating point numbers (MX) → 32-bit floats
|
|
38
|
+
* - e3m2: 8-bit e3m2 floating point numbers (MX) → 32-bit floats
|
|
39
|
+
* - i8: 8-bit signed integers → 32-bit signed integers
|
|
40
|
+
* - u8: 8-bit unsigned integers → 32-bit unsigned integers
|
|
41
|
+
* - i4: 4-bit signed integers (packed pairs) → 32-bit signed integers
|
|
42
|
+
* - u4: 4-bit unsigned integers (packed pairs) → 32-bit unsigned integers
|
|
43
|
+
* - u1: 1-bit binary (packed octets) → 32-bit unsigned integers
|
|
44
|
+
*
|
|
45
|
+
* For hardware architectures:
|
|
46
|
+
*
|
|
47
|
+
* - Arm: NEON, NEON+HALF, NEON+FHM, NEON+BF16, NEON+SDOT, SVE, SME, SME+F64, SME+BI32
|
|
48
|
+
* - x86: Haswell, Skylake, Ice Lake, Genoa, Sapphire Rapids (AMX), Sierra Forest
|
|
49
|
+
* - RISC-V: RVV
|
|
50
|
+
*
|
|
51
|
+
* @section numerical_stability Numerical Stability
|
|
52
|
+
*
|
|
53
|
+
* - f64: Dot2 (Ogita-Rump-Oishi) on the accurate backends, otherwise native f64 FMA accumulation.
|
|
54
|
+
* - f32: public outputs widen to f64. Packed and symmetric kernels keep payloads narrow but widen accumulation.
|
|
55
|
+
* - bf16/f16: f32 accumulation. VDPBF16PS on Genoa does bf16×bf16→f32 natively.
|
|
56
|
+
* - e2m3/e3m2: f16 intermediate with flush to f32 every 128 elements (Sapphire).
|
|
57
|
+
* - i8: i32 accumulation. AMX TDPBSSD gives i8×i8→i32 tiles. Overflows at k > ~131K.
|
|
58
|
+
* - u1: Popcount, exact.
|
|
59
|
+
*
|
|
60
|
+
* @section memory_layout Memory Layout and Transpose Semantics
|
|
61
|
+
*
|
|
62
|
+
* All matrices use row-major storage. Column-major is NOT supported.
|
|
63
|
+
* The kernel computes C = A × Bᵀ where:
|
|
64
|
+
*
|
|
65
|
+
* - A is (m × k): m rows, k columns, stride = a_stride bytes between rows
|
|
66
|
+
* - B is (n × k): n rows, k columns, stride = b_stride bytes between rows
|
|
67
|
+
* - C is (m × n): m rows, n columns, stride = c_stride bytes between rows
|
|
68
|
+
*
|
|
69
|
+
* This means C[i,j] = dot(row i of A, row j of B) = Σₗ A[i,l] × B[j,l].
|
|
70
|
+
*
|
|
71
|
+
* All strides are in bytes.
|
|
72
|
+
*
|
|
73
|
+
* To compute standard A × B (where B is k × n), pass Bᵀ to the packing function:
|
|
74
|
+
*
|
|
75
|
+
* @code{.c}
|
|
76
|
+
* // Standard matmul: C[m × n] = A[m × k] × B[k × n]
|
|
77
|
+
* // B is stored row-major as k rows of n elements
|
|
78
|
+
* // Treat it as Bᵀ: n rows of k elements with stride = sizeof(element)
|
|
79
|
+
* nk_dots_pack_bf16(b, width, depth, sizeof(nk_bf16_t), b_packed);
|
|
80
|
+
* nk_dots_packed_bf16(a, b_packed, c, height, width, depth, a_stride, c_stride);
|
|
81
|
+
* // Result: C = A × (Bᵀ)ᵀ = A × B
|
|
82
|
+
* @endcode
|
|
83
|
+
*
|
|
84
|
+
* @section two_phase_api Two-Phase API for Static Weights
|
|
85
|
+
*
|
|
86
|
+
* Matrix multiplication hardware (AMX, SME) requires specific data layouts that differ
|
|
87
|
+
* from standard row-major ordering. Since one matrix (typically weights in neural networks)
|
|
88
|
+
* is often static, we provide a two-phase API: pack once, multiply many times.
|
|
89
|
+
*
|
|
90
|
+
* @code{.c}
|
|
91
|
+
* // Similarity search: C[m × n] = queries[m × k] × database[n × k]ᵀ
|
|
92
|
+
* // Both matrices stored row-major, each row is one vector of dimension k
|
|
93
|
+
* nk_size_t packed_bytes = nk_dots_packed_size_bf16(width, depth);
|
|
94
|
+
* void *b_packed = malloc(packed_bytes);
|
|
95
|
+
* nk_dots_pack_bf16(database, width, depth, depth * sizeof(nk_bf16_t), b_packed);
|
|
96
|
+
* nk_dots_packed_bf16(queries, b_packed, c, height, width, depth, ...);
|
|
97
|
+
* // Result: C[i,j] = dot(query i, database vector j)
|
|
98
|
+
* @endcode
|
|
99
|
+
*
|
|
100
|
+
* The packed format is opaque and backend-specific. AMX expects (16 × 32) tiles with interleaved
|
|
101
|
+
* pairs, while NEON/SVE use arrangements optimized for their vector lengths.
|
|
102
|
+
*
|
|
103
|
+
* @section why_int8 Why INT8 and Not UINT8?
|
|
104
|
+
*
|
|
105
|
+
* Unsigned 8-bit integers were considered but deprioritized. The industry has converged on
|
|
106
|
+
* signed INT8 as the standard for quantized inference:
|
|
107
|
+
*
|
|
108
|
+
* Framework Default Notes
|
|
109
|
+
* PyTorch qint8 New X86 backend uses INT8 via oneDNN
|
|
110
|
+
* TensorFlow Lite int8 Actively removing UINT8 support
|
|
111
|
+
* ONNX Runtime S8S8 "Should be the first choice"
|
|
112
|
+
* TensorRT INT8 Symmetric [-128,127], no UINT8 option
|
|
113
|
+
* ARM CMSIS-NN int8 Follows TFLite INT8 spec exactly
|
|
114
|
+
*
|
|
115
|
+
* @section why_no_scaling Why No Alpha/Beta Scaling?
|
|
116
|
+
*
|
|
117
|
+
* BLAS-style `C = α × A × B + β × C` scaling was considered but omitted. While useful for scientific
|
|
118
|
+
* computing (iterative solvers, matrix factorizations), it's rarely used in ML inference where
|
|
119
|
+
* frameworks handle such operations via graph fusion. More importantly, on chips with separate
|
|
120
|
+
* physical registers for vector and matrix operations (like AMX), moving scalars between register
|
|
121
|
+
* files adds transfer latency that negates any benefit.
|
|
122
|
+
*
|
|
123
|
+
* @section why_no_pad Why Not Pad N Dimension to Eliminate Edge Handling?
|
|
124
|
+
*
|
|
125
|
+
* Padding N to a tile-aligned boundary (multiple of 16) during packing was considered to eliminate
|
|
126
|
+
* the separate AVX-512 edge kernel for N remainder rows. While this sounds simpler ("pure AMX"),
|
|
127
|
+
* it actually increases code size by ~125 lines because:
|
|
128
|
+
*
|
|
129
|
+
* - The AVX-512 edge fallback is compact (~40 lines) and handles both full-M × N-edge and
|
|
130
|
+
* M-edge × N-edge cases through a single reusable function
|
|
131
|
+
* - Replacing it with "AMX + masked stores" requires verbose tile handling code duplicated
|
|
132
|
+
* across all 4 multiply functions (aligned/misaligned × BF16/I8)
|
|
133
|
+
* - Each function needs a new "trailing N tile for full M blocks" section (~50 lines each)
|
|
134
|
+
*
|
|
135
|
+
* The current hybrid layout (AMX for full tiles, AVX-512 for edges) is more maintainable despite
|
|
136
|
+
* being conceptually less uniform. Memory overhead of the edge region is negligible (<2% worst case).
|
|
137
|
+
*
|
|
138
|
+
* @section x86_instructions Relevant x86 Instructions
|
|
139
|
+
*
|
|
140
|
+
* Low-precision matmul relies on VPMADD* (AVX2), VNNI dot-products, and BF16 dot-products
|
|
141
|
+
* on AVX-512. Zen4 improves throughput by dual-issuing many integer ops on FP ports.
|
|
142
|
+
*
|
|
143
|
+
* Intrinsic Instruction Haswell Genoa
|
|
144
|
+
* _mm256_maddubs_epi16 VPMADDUBSW (YMM, YMM, YMM) 5c @ p0 3c @ p01
|
|
145
|
+
* _mm256_madd_epi16 VPMADDWD (YMM, YMM, YMM) 5c @ p0 3c @ p01
|
|
146
|
+
* _mm256_dpbusd_epi32 VPDPBUSD (YMM, K, YMM, YMM) n/a 4c @ p01
|
|
147
|
+
* _mm256_dpwssds_epi32 VPDPWSSDS (YMM, K, YMM, YMM) n/a 4c @ p01
|
|
148
|
+
* _mm256_dpbf16_ps VDPBF16PS (YMM, YMM, YMM) n/a 6c @ p01
|
|
149
|
+
*
|
|
150
|
+
* AMX tile ops (TDPBF16PS/TDPBUSD/TDPBSSD) are not covered by the uops.info 2022 dataset.
|
|
151
|
+
*
|
|
152
|
+
* @section references References
|
|
153
|
+
*
|
|
154
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
155
|
+
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
156
|
+
* - uops.info: https://uops.info/
|
|
157
|
+
* - Matrix Multiplication in 40 lines: https://en.algorithmica.org/hpc/algorithms/matmul/
|
|
158
|
+
* - LLaMA CPU optimization: https://justine.lol/matmul/
|
|
159
|
+
* - SME outer-product notes: https://github.com/tzakharko/m4-sme-exploration
|
|
160
|
+
*
|
|
161
|
+
*/
|
|
162
|
+
#ifndef NK_DOTS_H
|
|
163
|
+
#define NK_DOTS_H
|
|
164
|
+
|
|
165
|
+
#include "numkong/types.h"
|
|
166
|
+
|
|
167
|
+
#if defined(__cplusplus)
|
|
168
|
+
extern "C" {
|
|
169
|
+
#endif
|
|
170
|
+
|
|
171
|
+
/**
|
|
172
|
+
* @brief Returns the packed buffer size, in bytes, required for the second multiplier matrix (B).
|
|
173
|
+
* @param[in] width The number of rows in B (output columns).
|
|
174
|
+
* @param[in] depth The number of columns in B.
|
|
175
|
+
* @note The packed layout is backend-specific and must be produced by the matching pack function.
|
|
176
|
+
*/
|
|
177
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_bf16(nk_size_t width, nk_size_t depth);
|
|
178
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
179
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth);
|
|
180
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
181
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_e4m3(nk_size_t width, nk_size_t depth);
|
|
182
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
183
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_e5m2(nk_size_t width, nk_size_t depth);
|
|
184
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
185
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_e2m3(nk_size_t width, nk_size_t depth);
|
|
186
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
187
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_e3m2(nk_size_t width, nk_size_t depth);
|
|
188
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
189
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_f32(nk_size_t width, nk_size_t depth);
|
|
190
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
191
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_f64(nk_size_t width, nk_size_t depth);
|
|
192
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
193
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_i8(nk_size_t width, nk_size_t depth);
|
|
194
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
195
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_u8(nk_size_t width, nk_size_t depth);
|
|
196
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
197
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_i4(nk_size_t width, nk_size_t depth);
|
|
198
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
199
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_u4(nk_size_t width, nk_size_t depth);
|
|
200
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
201
|
+
NK_DYNAMIC nk_size_t nk_dots_packed_size_u1(nk_size_t width, nk_size_t depth);
|
|
202
|
+
|
|
203
|
+
/**
|
|
204
|
+
* @brief Packs the second multiplier (B) matrix into a backend-specific layout.
|
|
205
|
+
* @param[in] b The input B matrix in row-major order.
|
|
206
|
+
* @param[in] width The number of rows in B (output columns).
|
|
207
|
+
* @param[in] depth The number of columns in B.
|
|
208
|
+
* @param[in] b_stride The row stride in bytes for B.
|
|
209
|
+
* @param[out] b_packed The output packed buffer, sized in bytes by the matching packed-size function (e.g. nk_dots_packed_size_bf16).
|
|
210
|
+
*/
|
|
211
|
+
NK_DYNAMIC void nk_dots_pack_bf16(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
212
|
+
void *b_packed);
|
|
213
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
214
|
+
NK_DYNAMIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
215
|
+
void *b_packed);
|
|
216
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
217
|
+
NK_DYNAMIC void nk_dots_pack_e4m3(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
218
|
+
void *b_packed);
|
|
219
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
220
|
+
NK_DYNAMIC void nk_dots_pack_e5m2(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
221
|
+
void *b_packed);
|
|
222
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
223
|
+
NK_DYNAMIC void nk_dots_pack_e2m3(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
224
|
+
void *b_packed);
|
|
225
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
226
|
+
NK_DYNAMIC void nk_dots_pack_e3m2(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
227
|
+
void *b_packed);
|
|
228
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
229
|
+
NK_DYNAMIC void nk_dots_pack_f32(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
230
|
+
void *b_packed);
|
|
231
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
232
|
+
NK_DYNAMIC void nk_dots_pack_f64(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
233
|
+
void *b_packed);
|
|
234
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
235
|
+
NK_DYNAMIC void nk_dots_pack_i8(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride, void *b_packed);
|
|
236
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
237
|
+
NK_DYNAMIC void nk_dots_pack_u8(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride, void *b_packed);
|
|
238
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
239
|
+
NK_DYNAMIC void nk_dots_pack_i4(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
240
|
+
void *b_packed);
|
|
241
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
242
|
+
NK_DYNAMIC void nk_dots_pack_u4(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
243
|
+
void *b_packed);
|
|
244
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
245
|
+
NK_DYNAMIC void nk_dots_pack_u1(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
246
|
+
void *b_packed);
|
|
247
|
+
|
|
248
|
+
/**
|
|
249
|
+
* @brief Computes C = A × Bᵀ using a packed second multiplier matrix (B), accumulating into C.
|
|
250
|
+
* @param[in] a The input A matrix in row-major order.
|
|
251
|
+
* @param[in] b_packed The packed B matrix produced by the matching `nk_dots_pack_*` function.
|
|
252
|
+
* @param[out] c The output C matrix in row-major order.
|
|
253
|
+
* @param[in] height The number of rows in A.
|
|
254
|
+
* @param[in] width The number of rows in B (output columns).
|
|
255
|
+
* @param[in] depth The shared inner dimension.
|
|
256
|
+
* @param[in] a_stride The row stride in bytes for A.
|
|
257
|
+
* @param[in] c_stride The row stride in bytes for C.
|
|
258
|
+
*/
|
|
259
|
+
NK_DYNAMIC void nk_dots_packed_bf16(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
260
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
261
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
262
|
+
NK_DYNAMIC void nk_dots_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
263
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
264
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
265
|
+
NK_DYNAMIC void nk_dots_packed_e4m3(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
266
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
267
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
268
|
+
NK_DYNAMIC void nk_dots_packed_e5m2(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
269
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
270
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
271
|
+
NK_DYNAMIC void nk_dots_packed_e2m3(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
272
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
273
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
274
|
+
NK_DYNAMIC void nk_dots_packed_e3m2(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
275
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
276
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
277
|
+
NK_DYNAMIC void nk_dots_packed_f32(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
278
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
279
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
280
|
+
NK_DYNAMIC void nk_dots_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
281
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
282
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
283
|
+
NK_DYNAMIC void nk_dots_packed_i8(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
284
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
285
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
286
|
+
NK_DYNAMIC void nk_dots_packed_u8(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
287
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
288
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
289
|
+
NK_DYNAMIC void nk_dots_packed_i4(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
290
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
291
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
292
|
+
NK_DYNAMIC void nk_dots_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
293
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
294
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
295
|
+
NK_DYNAMIC void nk_dots_packed_u1(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
296
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
297
|
+
|
|
298
|
+
/**
|
|
299
|
+
* @brief Computes the symmetric Gram matrix C = A × Aᵀ.
|
|
300
|
+
* @param[in] vectors Input matrix of row vectors in row-major order.
|
|
301
|
+
* @param[in] n_vectors Number of vectors (rows) in the input matrix.
|
|
302
|
+
* @param[in] depth Dimension of each vector (columns).
|
|
303
|
+
* @param[in] stride Row stride in bytes for the input matrix.
|
|
304
|
+
* @param[out] result Output symmetric matrix (n_vectors × n_vectors).
|
|
305
|
+
* @param[in] result_stride Row stride in bytes for the result matrix.
|
|
306
|
+
* @param[in] row_start Starting row offset of results to compute (needed for parallelism).
|
|
307
|
+
* @param[in] row_count Number of rows of results to compute (needed for parallelism).
|
|
308
|
+
*/
|
|
309
|
+
NK_DYNAMIC void nk_dots_symmetric_bf16(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
310
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
311
|
+
nk_size_t row_count);
|
|
312
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
313
|
+
NK_DYNAMIC void nk_dots_symmetric_f16(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
314
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
315
|
+
nk_size_t row_count);
|
|
316
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
317
|
+
NK_DYNAMIC void nk_dots_symmetric_e4m3(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
318
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
319
|
+
nk_size_t row_count);
|
|
320
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
321
|
+
NK_DYNAMIC void nk_dots_symmetric_e5m2(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
322
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
323
|
+
nk_size_t row_count);
|
|
324
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
325
|
+
NK_DYNAMIC void nk_dots_symmetric_e2m3(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
326
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
327
|
+
nk_size_t row_count);
|
|
328
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
329
|
+
NK_DYNAMIC void nk_dots_symmetric_e3m2(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
330
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
331
|
+
nk_size_t row_count);
|
|
332
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
333
|
+
NK_DYNAMIC void nk_dots_symmetric_f32(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
334
|
+
nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
335
|
+
nk_size_t row_count);
|
|
336
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
337
|
+
NK_DYNAMIC void nk_dots_symmetric_f64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
338
|
+
nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
339
|
+
nk_size_t row_count);
|
|
340
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
341
|
+
NK_DYNAMIC void nk_dots_symmetric_i8(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
342
|
+
nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
343
|
+
nk_size_t row_count);
|
|
344
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
345
|
+
NK_DYNAMIC void nk_dots_symmetric_u8(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
346
|
+
nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
347
|
+
nk_size_t row_count);
|
|
348
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
349
|
+
NK_DYNAMIC void nk_dots_symmetric_i4(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
350
|
+
nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
351
|
+
nk_size_t row_count);
|
|
352
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
353
|
+
NK_DYNAMIC void nk_dots_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
354
|
+
nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
355
|
+
nk_size_t row_count);
|
|
356
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
357
|
+
NK_DYNAMIC void nk_dots_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
358
|
+
nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
359
|
+
nk_size_t row_count);
|
|
360
|
+
|
|
361
|
+
/** @copydoc nk_dots_packed_size_f32 */
|
|
362
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_serial(nk_size_t width, nk_size_t depth);
|
|
363
|
+
/** @copydoc nk_dots_pack_f32 */
|
|
364
|
+
NK_PUBLIC void nk_dots_pack_f32_serial(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
365
|
+
void *b_packed);
|
|
366
|
+
/** @copydoc nk_dots_packed_f32 */
|
|
367
|
+
NK_PUBLIC void nk_dots_packed_f32_serial(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
368
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
369
|
+
/** @copydoc nk_dots_symmetric_f32 */
|
|
370
|
+
NK_PUBLIC void nk_dots_symmetric_f32_serial(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
371
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
372
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
373
|
+
|
|
374
|
+
/** @copydoc nk_dots_packed_size_f64 */
|
|
375
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_serial(nk_size_t width, nk_size_t depth);
|
|
376
|
+
/** @copydoc nk_dots_pack_f64 */
|
|
377
|
+
NK_PUBLIC void nk_dots_pack_f64_serial(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
378
|
+
void *b_packed);
|
|
379
|
+
/** @copydoc nk_dots_packed_f64 */
|
|
380
|
+
NK_PUBLIC void nk_dots_packed_f64_serial(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
381
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
382
|
+
/** @copydoc nk_dots_symmetric_f64 */
|
|
383
|
+
NK_PUBLIC void nk_dots_symmetric_f64_serial(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
384
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
385
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
386
|
+
|
|
387
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
388
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_serial(nk_size_t width, nk_size_t depth);
|
|
389
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
390
|
+
NK_PUBLIC void nk_dots_pack_f16_serial(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
391
|
+
void *b_packed);
|
|
392
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
393
|
+
NK_PUBLIC void nk_dots_packed_f16_serial(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
394
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
395
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
396
|
+
NK_PUBLIC void nk_dots_symmetric_f16_serial(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
397
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
398
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
399
|
+
|
|
400
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
401
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_serial(nk_size_t width, nk_size_t depth);
|
|
402
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
403
|
+
NK_PUBLIC void nk_dots_pack_bf16_serial(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
404
|
+
void *b_packed);
|
|
405
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
406
|
+
NK_PUBLIC void nk_dots_packed_bf16_serial(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
407
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
408
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
409
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_serial(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
410
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
411
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
412
|
+
|
|
413
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
414
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_serial(nk_size_t width, nk_size_t depth);
|
|
415
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
416
|
+
NK_PUBLIC void nk_dots_pack_i8_serial(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
417
|
+
void *b_packed);
|
|
418
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
419
|
+
NK_PUBLIC void nk_dots_packed_i8_serial(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
420
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
421
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
422
|
+
NK_PUBLIC void nk_dots_symmetric_i8_serial(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
423
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
424
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
425
|
+
|
|
426
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
427
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_serial(nk_size_t width, nk_size_t depth);
|
|
428
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
429
|
+
NK_PUBLIC void nk_dots_pack_u8_serial(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
430
|
+
void *b_packed);
|
|
431
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
432
|
+
NK_PUBLIC void nk_dots_packed_u8_serial(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
433
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
434
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
435
|
+
NK_PUBLIC void nk_dots_symmetric_u8_serial(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
436
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
437
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
438
|
+
|
|
439
|
+
/** @copydoc nk_dots_packed_size_u4 */
|
|
440
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u4_serial(nk_size_t width, nk_size_t depth);
|
|
441
|
+
/** @copydoc nk_dots_pack_u4 */
|
|
442
|
+
NK_PUBLIC void nk_dots_pack_u4_serial(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
443
|
+
void *b_packed);
|
|
444
|
+
/** @copydoc nk_dots_packed_u4 */
|
|
445
|
+
NK_PUBLIC void nk_dots_packed_u4_serial(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
446
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
447
|
+
/** @copydoc nk_dots_symmetric_u4 */
|
|
448
|
+
NK_PUBLIC void nk_dots_symmetric_u4_serial(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
449
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
450
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
451
|
+
|
|
452
|
+
/** @copydoc nk_dots_packed_size_u1 */
|
|
453
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u1_serial(nk_size_t width, nk_size_t depth);
|
|
454
|
+
/** @copydoc nk_dots_pack_u1 */
|
|
455
|
+
NK_PUBLIC void nk_dots_pack_u1_serial(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
456
|
+
void *b_packed);
|
|
457
|
+
/** @copydoc nk_dots_packed_u1 */
|
|
458
|
+
NK_PUBLIC void nk_dots_packed_u1_serial(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
459
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
460
|
+
/** @copydoc nk_dots_symmetric_u1 */
|
|
461
|
+
NK_PUBLIC void nk_dots_symmetric_u1_serial(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
462
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
463
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
464
|
+
|
|
465
|
+
/** @copydoc nk_dots_packed_size_i4 */
|
|
466
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i4_serial(nk_size_t width, nk_size_t depth);
|
|
467
|
+
/** @copydoc nk_dots_pack_i4 */
|
|
468
|
+
NK_PUBLIC void nk_dots_pack_i4_serial(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
469
|
+
void *b_packed);
|
|
470
|
+
/** @copydoc nk_dots_packed_i4 */
|
|
471
|
+
NK_PUBLIC void nk_dots_packed_i4_serial(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
472
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
473
|
+
/** @copydoc nk_dots_symmetric_i4 */
|
|
474
|
+
NK_PUBLIC void nk_dots_symmetric_i4_serial(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
475
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
476
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
477
|
+
/** @copydoc nk_dots_symmetric_e4m3 */
|
|
478
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_serial(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
479
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
480
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
481
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
482
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_serial(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
483
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
484
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
485
|
+
/** @copydoc nk_dots_symmetric_e2m3 */
|
|
486
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_serial(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
487
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
488
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
489
|
+
/** @copydoc nk_dots_symmetric_e3m2 */
|
|
490
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_serial(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
491
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
492
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
493
|
+
/** @copydoc nk_dots_packed_size_e2m3 */
|
|
494
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_serial(nk_size_t width, nk_size_t depth);
|
|
495
|
+
/** @copydoc nk_dots_pack_e2m3 */
|
|
496
|
+
NK_PUBLIC void nk_dots_pack_e2m3_serial(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
497
|
+
void *b_packed);
|
|
498
|
+
/** @copydoc nk_dots_packed_e2m3 */
|
|
499
|
+
NK_PUBLIC void nk_dots_packed_e2m3_serial(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
500
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
501
|
+
/** @copydoc nk_dots_packed_size_e3m2 */
|
|
502
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_serial(nk_size_t width, nk_size_t depth);
|
|
503
|
+
/** @copydoc nk_dots_pack_e3m2 */
|
|
504
|
+
NK_PUBLIC void nk_dots_pack_e3m2_serial(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
505
|
+
void *b_packed);
|
|
506
|
+
/** @copydoc nk_dots_packed_e3m2 */
|
|
507
|
+
NK_PUBLIC void nk_dots_packed_e3m2_serial(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
508
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
509
|
+
|
|
510
|
+
/* Genoa backends using AVX-512 with BF16 extensions.
|
|
511
|
+
* These use VDPBF16PS for BF16 dot products.
|
|
512
|
+
* Packing interleaves elements for SIMD broadcast patterns.
|
|
513
|
+
*/
|
|
514
|
+
#if NK_TARGET_GENOA
|
|
515
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
516
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_genoa(nk_size_t width, nk_size_t depth);
|
|
517
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
518
|
+
NK_PUBLIC void nk_dots_pack_bf16_genoa(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
519
|
+
void *b_packed);
|
|
520
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
521
|
+
NK_PUBLIC void nk_dots_packed_bf16_genoa(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
522
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
523
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
524
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_genoa(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
525
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
526
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
527
|
+
|
|
528
|
+
/** @copydoc nk_dots_packed_size_e4m3 */
|
|
529
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_genoa(nk_size_t width, nk_size_t depth);
|
|
530
|
+
/** @copydoc nk_dots_pack_e4m3 */
|
|
531
|
+
NK_PUBLIC void nk_dots_pack_e4m3_genoa(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
532
|
+
void *b_packed);
|
|
533
|
+
/** @copydoc nk_dots_packed_e4m3 */
|
|
534
|
+
NK_PUBLIC void nk_dots_packed_e4m3_genoa(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
535
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
536
|
+
/** @copydoc nk_dots_packed_size_e5m2 */
|
|
537
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_genoa(nk_size_t width, nk_size_t depth);
|
|
538
|
+
/** @copydoc nk_dots_pack_e5m2 */
|
|
539
|
+
NK_PUBLIC void nk_dots_pack_e5m2_genoa(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
540
|
+
void *b_packed);
|
|
541
|
+
/** @copydoc nk_dots_packed_e5m2 */
|
|
542
|
+
NK_PUBLIC void nk_dots_packed_e5m2_genoa(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
543
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
544
|
+
/** @copydoc nk_dots_symmetric_e4m3 */
|
|
545
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_genoa(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
546
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
547
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
548
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
549
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_genoa(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
550
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
551
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
552
|
+
#endif // NK_TARGET_GENOA
|
|
553
|
+
|
|
554
|
+
/* Sapphire Rapids backends using Intel AMX (Advanced Matrix Extensions).
|
|
555
|
+
* AMX provides 8 tile registers (TMM0-TMM7), each holding up to 1KB of data.
|
|
556
|
+
* Tiles are configured as 16 rows × 64 bytes, enabling (16 × 32) BF16 or (16 × 64) INT8 tiles.
|
|
557
|
+
* Packing arranges data into AMX-native tile layout with pair interleaving for TDPBF16PS.
|
|
558
|
+
*/
|
|
559
|
+
#if NK_TARGET_SAPPHIREAMX
|
|
560
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
561
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sapphireamx(nk_size_t width, nk_size_t depth);
|
|
562
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
563
|
+
NK_PUBLIC void nk_dots_pack_bf16_sapphireamx(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
564
|
+
void *b_packed);
|
|
565
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
566
|
+
NK_PUBLIC void nk_dots_packed_bf16_sapphireamx(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
567
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
568
|
+
nk_size_t c_stride);
|
|
569
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
570
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
571
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
572
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
573
|
+
|
|
574
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
575
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sapphireamx(nk_size_t width, nk_size_t depth);
|
|
576
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
577
|
+
NK_PUBLIC void nk_dots_pack_i8_sapphireamx(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
578
|
+
void *b_packed);
|
|
579
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
580
|
+
NK_PUBLIC void nk_dots_packed_i8_sapphireamx(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
581
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
582
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
583
|
+
NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
584
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
585
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
586
|
+
|
|
587
|
+
/** @copydoc nk_dots_packed_size_e4m3 */
|
|
588
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sapphireamx(nk_size_t width, nk_size_t depth);
|
|
589
|
+
/** @copydoc nk_dots_pack_e4m3 */
|
|
590
|
+
NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
591
|
+
void *b_packed);
|
|
592
|
+
/** @copydoc nk_dots_packed_e4m3 */
|
|
593
|
+
NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
594
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
595
|
+
nk_size_t c_stride);
|
|
596
|
+
|
|
597
|
+
/** @copydoc nk_dots_symmetric_e4m3 */
|
|
598
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
599
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
600
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
601
|
+
|
|
602
|
+
/** @copydoc nk_dots_packed_size_e5m2 */
|
|
603
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sapphireamx(nk_size_t width, nk_size_t depth);
|
|
604
|
+
/** @copydoc nk_dots_pack_e5m2 */
|
|
605
|
+
NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
606
|
+
void *b_packed);
|
|
607
|
+
/** @copydoc nk_dots_packed_e5m2 */
|
|
608
|
+
NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
609
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
610
|
+
nk_size_t c_stride);
|
|
611
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
612
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
613
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
614
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
615
|
+
/** @copydoc nk_dots_packed_size_e2m3 */
|
|
616
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sapphireamx(nk_size_t width, nk_size_t depth);
|
|
617
|
+
/** @copydoc nk_dots_pack_e2m3 */
|
|
618
|
+
NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
619
|
+
void *b_packed);
|
|
620
|
+
/** @copydoc nk_dots_packed_e2m3 */
|
|
621
|
+
NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
622
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
623
|
+
nk_size_t c_stride);
|
|
624
|
+
/** @copydoc nk_dots_symmetric_e2m3 */
|
|
625
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
626
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
627
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
628
|
+
|
|
629
|
+
/** @copydoc nk_dots_packed_size_e3m2 */
|
|
630
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sapphireamx(nk_size_t width, nk_size_t depth);
|
|
631
|
+
/** @copydoc nk_dots_pack_e3m2 */
|
|
632
|
+
NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
633
|
+
void *b_packed);
|
|
634
|
+
/** @copydoc nk_dots_packed_e3m2 */
|
|
635
|
+
NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
636
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
637
|
+
nk_size_t c_stride);
|
|
638
|
+
/** @copydoc nk_dots_symmetric_e3m2 */
|
|
639
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
640
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
641
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
642
|
+
|
|
643
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
644
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sapphireamx(nk_size_t width, nk_size_t depth);
|
|
645
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
646
|
+
NK_PUBLIC void nk_dots_pack_u8_sapphireamx(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
647
|
+
void *b_packed);
|
|
648
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
649
|
+
NK_PUBLIC void nk_dots_packed_u8_sapphireamx(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
650
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
651
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
652
|
+
NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
653
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
654
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
655
|
+
#endif // NK_TARGET_SAPPHIREAMX
|
|
656
|
+
|
|
657
|
+
/* ARM SME backends using Scalable Matrix Extension.
|
|
658
|
+
* SME provides ZA tile registers for outer product operations.
|
|
659
|
+
* F16/BF16/I8/U8/E4M3 use ZA32 tiles; F32/F64 use ZA64 tiles (FEAT_SME_F64F64).
|
|
660
|
+
*/
|
|
661
|
+
#if NK_TARGET_SME
|
|
662
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
663
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_sme(nk_size_t width, nk_size_t depth);
|
|
664
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
665
|
+
NK_PUBLIC void nk_dots_pack_f16_sme(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
666
|
+
void *b_packed);
|
|
667
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
668
|
+
NK_PUBLIC void nk_dots_packed_f16_sme(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
669
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
670
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
671
|
+
NK_PUBLIC void nk_dots_symmetric_f16_sme(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
672
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
673
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
674
|
+
|
|
675
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
676
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sme(nk_size_t width, nk_size_t depth);
|
|
677
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
678
|
+
NK_PUBLIC void nk_dots_pack_bf16_sme(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
679
|
+
void *b_packed);
|
|
680
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
681
|
+
NK_PUBLIC void nk_dots_packed_bf16_sme(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
682
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
683
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
684
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_sme(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
685
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
686
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
687
|
+
|
|
688
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
689
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sme(nk_size_t width, nk_size_t depth);
|
|
690
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
691
|
+
NK_PUBLIC void nk_dots_pack_i8_sme(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
692
|
+
void *b_packed);
|
|
693
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
694
|
+
NK_PUBLIC void nk_dots_packed_i8_sme(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
695
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
696
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
697
|
+
NK_PUBLIC void nk_dots_symmetric_i8_sme(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
698
|
+
nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
699
|
+
nk_size_t row_count);
|
|
700
|
+
|
|
701
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
702
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sme(nk_size_t width, nk_size_t depth);
|
|
703
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
704
|
+
NK_PUBLIC void nk_dots_pack_u8_sme(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
705
|
+
void *b_packed);
|
|
706
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
707
|
+
NK_PUBLIC void nk_dots_packed_u8_sme(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
708
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
709
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
710
|
+
NK_PUBLIC void nk_dots_symmetric_u8_sme(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
711
|
+
nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
712
|
+
nk_size_t row_count);
|
|
713
|
+
|
|
714
|
+
/** @copydoc nk_dots_packed_size_e4m3 */
|
|
715
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sme(nk_size_t width, nk_size_t depth);
|
|
716
|
+
/** @copydoc nk_dots_pack_e4m3 */
|
|
717
|
+
NK_PUBLIC void nk_dots_pack_e4m3_sme(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
718
|
+
void *b_packed);
|
|
719
|
+
/** @copydoc nk_dots_packed_e4m3 */
|
|
720
|
+
NK_PUBLIC void nk_dots_packed_e4m3_sme(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
721
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
722
|
+
/** @copydoc nk_dots_symmetric_e4m3 */
|
|
723
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_sme(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
724
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
725
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
726
|
+
|
|
727
|
+
/** @copydoc nk_dots_packed_size_e5m2 */
|
|
728
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sme(nk_size_t width, nk_size_t depth);
|
|
729
|
+
/** @copydoc nk_dots_pack_e5m2 */
|
|
730
|
+
NK_PUBLIC void nk_dots_pack_e5m2_sme(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
731
|
+
void *b_packed);
|
|
732
|
+
/** @copydoc nk_dots_packed_e5m2 */
|
|
733
|
+
NK_PUBLIC void nk_dots_packed_e5m2_sme(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
734
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
735
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
736
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_sme(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
737
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
738
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
739
|
+
|
|
740
|
+
/** @copydoc nk_dots_packed_size_u4 */
|
|
741
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u4_sme(nk_size_t width, nk_size_t depth);
|
|
742
|
+
/** @copydoc nk_dots_pack_u4 */
|
|
743
|
+
NK_PUBLIC void nk_dots_pack_u4_sme(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
744
|
+
void *b_packed);
|
|
745
|
+
/** @copydoc nk_dots_packed_u4 */
|
|
746
|
+
NK_PUBLIC void nk_dots_packed_u4_sme(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
747
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
748
|
+
/** @copydoc nk_dots_symmetric_u4 */
|
|
749
|
+
NK_PUBLIC void nk_dots_symmetric_u4_sme(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
750
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
751
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
752
|
+
|
|
753
|
+
/** @copydoc nk_dots_packed_size_i4 */
|
|
754
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i4_sme(nk_size_t width, nk_size_t depth);
|
|
755
|
+
/** @copydoc nk_dots_pack_i4 */
|
|
756
|
+
NK_PUBLIC void nk_dots_pack_i4_sme(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
757
|
+
void *b_packed);
|
|
758
|
+
/** @copydoc nk_dots_packed_i4 */
|
|
759
|
+
NK_PUBLIC void nk_dots_packed_i4_sme(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
760
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
761
|
+
/** @copydoc nk_dots_symmetric_i4 */
|
|
762
|
+
NK_PUBLIC void nk_dots_symmetric_i4_sme(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
763
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
764
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
765
|
+
|
|
766
|
+
/** @copydoc nk_dots_packed_size_e2m3 */
|
|
767
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sme(nk_size_t width, nk_size_t depth);
|
|
768
|
+
/** @copydoc nk_dots_pack_e2m3 */
|
|
769
|
+
NK_PUBLIC void nk_dots_pack_e2m3_sme(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
770
|
+
void *b_packed);
|
|
771
|
+
/** @copydoc nk_dots_packed_e2m3 */
|
|
772
|
+
NK_PUBLIC void nk_dots_packed_e2m3_sme(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
773
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
774
|
+
/** @copydoc nk_dots_symmetric_e2m3 */
|
|
775
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_sme(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
776
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
777
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
778
|
+
|
|
779
|
+
/** @copydoc nk_dots_packed_size_e3m2 */
|
|
780
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sme(nk_size_t width, nk_size_t depth);
|
|
781
|
+
/** @copydoc nk_dots_pack_e3m2 */
|
|
782
|
+
NK_PUBLIC void nk_dots_pack_e3m2_sme(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
783
|
+
void *b_packed);
|
|
784
|
+
/** @copydoc nk_dots_packed_e3m2 */
|
|
785
|
+
NK_PUBLIC void nk_dots_packed_e3m2_sme(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
786
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
787
|
+
/** @copydoc nk_dots_symmetric_e3m2 */
|
|
788
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_sme(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
789
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
790
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
791
|
+
#endif // NK_TARGET_SME
|
|
792
|
+
|
|
793
|
+
/* ARM SME with integer-accumulating binary outer products.
|
|
794
|
+
* Used for packed 1-bit dot products backed by ZA32.
|
|
795
|
+
*/
|
|
796
|
+
#if NK_TARGET_SMEBI32
|
|
797
|
+
/** @copydoc nk_dots_packed_size_u1 */
|
|
798
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u1_smebi32(nk_size_t width, nk_size_t depth);
|
|
799
|
+
/** @copydoc nk_dots_pack_u1 */
|
|
800
|
+
NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
801
|
+
void *b_packed);
|
|
802
|
+
/** @copydoc nk_dots_packed_u1 */
|
|
803
|
+
NK_PUBLIC void nk_dots_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
804
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
805
|
+
/** @copydoc nk_dots_symmetric_u1 */
|
|
806
|
+
NK_PUBLIC void nk_dots_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
807
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
808
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
809
|
+
#endif // NK_TARGET_SMEBI32
|
|
810
|
+
|
|
811
|
+
/* ARM SME with FEAT_SME_F64F64 (F32/F64 with F64 accumulators).
|
|
812
|
+
* Requires Apple M4 or equivalent with F64 outer product support.
|
|
813
|
+
*/
|
|
814
|
+
#if NK_TARGET_SMEF64
|
|
815
|
+
/** @copydoc nk_dots_packed_size_f32 */
|
|
816
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_smef64(nk_size_t width, nk_size_t depth);
|
|
817
|
+
/** @copydoc nk_dots_pack_f32 */
|
|
818
|
+
NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
819
|
+
void *b_packed);
|
|
820
|
+
/** @copydoc nk_dots_packed_f32 */
|
|
821
|
+
NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
822
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
823
|
+
/** @copydoc nk_dots_symmetric_f32 */
|
|
824
|
+
NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
825
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
826
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
827
|
+
|
|
828
|
+
/** @copydoc nk_dots_packed_size_f64 */
|
|
829
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t width, nk_size_t depth);
|
|
830
|
+
/** @copydoc nk_dots_pack_f64 */
|
|
831
|
+
NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
832
|
+
void *b_packed);
|
|
833
|
+
/** @copydoc nk_dots_packed_f64 */
|
|
834
|
+
NK_PUBLIC void nk_dots_packed_f64_smef64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
835
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
836
|
+
/** @copydoc nk_dots_symmetric_f64 */
|
|
837
|
+
NK_PUBLIC void nk_dots_symmetric_f64_smef64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
838
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
839
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
840
|
+
#endif // NK_TARGET_SMEF64
|
|
841
|
+
|
|
842
|
+
/* Haswell backends using AVX2 (Intel Core 4th gen).
|
|
843
|
+
* Supports F32/F64 via FMA, F16/BF16/FP8 via software emulation, I8/U8 via VPMADDUBSW+VPADDD.
|
|
844
|
+
*/
|
|
845
|
+
#if NK_TARGET_HASWELL
|
|
846
|
+
/** @copydoc nk_dots_packed_size_f32 */
|
|
847
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_haswell(nk_size_t width, nk_size_t depth);
|
|
848
|
+
/** @copydoc nk_dots_pack_f32 */
|
|
849
|
+
NK_PUBLIC void nk_dots_pack_f32_haswell(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
850
|
+
void *b_packed);
|
|
851
|
+
/** @copydoc nk_dots_packed_f32 */
|
|
852
|
+
NK_PUBLIC void nk_dots_packed_f32_haswell(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
853
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
854
|
+
/** @copydoc nk_dots_symmetric_f32 */
|
|
855
|
+
NK_PUBLIC void nk_dots_symmetric_f32_haswell(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
856
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
857
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
858
|
+
/** @copydoc nk_dots_packed_size_f64 */
|
|
859
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_haswell(nk_size_t width, nk_size_t depth);
|
|
860
|
+
/** @copydoc nk_dots_pack_f64 */
|
|
861
|
+
NK_PUBLIC void nk_dots_pack_f64_haswell(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
862
|
+
void *b_packed);
|
|
863
|
+
/** @copydoc nk_dots_packed_f64 */
|
|
864
|
+
NK_PUBLIC void nk_dots_packed_f64_haswell(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
865
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
866
|
+
/** @copydoc nk_dots_symmetric_f64 */
|
|
867
|
+
NK_PUBLIC void nk_dots_symmetric_f64_haswell(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
868
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
869
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
870
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
871
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_haswell(nk_size_t width, nk_size_t depth);
|
|
872
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
873
|
+
NK_PUBLIC void nk_dots_pack_f16_haswell(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
874
|
+
void *b_packed);
|
|
875
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
876
|
+
NK_PUBLIC void nk_dots_packed_f16_haswell(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
877
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
878
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
879
|
+
NK_PUBLIC void nk_dots_symmetric_f16_haswell(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
880
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
881
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
882
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
883
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_haswell(nk_size_t width, nk_size_t depth);
|
|
884
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
885
|
+
NK_PUBLIC void nk_dots_pack_bf16_haswell(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
886
|
+
void *b_packed);
|
|
887
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
888
|
+
NK_PUBLIC void nk_dots_packed_bf16_haswell(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
889
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
890
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
891
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_haswell(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
892
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
893
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
894
|
+
/** @copydoc nk_dots_packed_size_e4m3 */
|
|
895
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_haswell(nk_size_t width, nk_size_t depth);
|
|
896
|
+
/** @copydoc nk_dots_pack_e4m3 */
|
|
897
|
+
NK_PUBLIC void nk_dots_pack_e4m3_haswell(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
898
|
+
void *b_packed);
|
|
899
|
+
/** @copydoc nk_dots_packed_e4m3 */
|
|
900
|
+
NK_PUBLIC void nk_dots_packed_e4m3_haswell(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
901
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
902
|
+
/** @copydoc nk_dots_symmetric_e4m3 */
|
|
903
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_haswell(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
904
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
905
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
906
|
+
/** @copydoc nk_dots_packed_size_e5m2 */
|
|
907
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_haswell(nk_size_t width, nk_size_t depth);
|
|
908
|
+
/** @copydoc nk_dots_pack_e5m2 */
|
|
909
|
+
NK_PUBLIC void nk_dots_pack_e5m2_haswell(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
910
|
+
void *b_packed);
|
|
911
|
+
/** @copydoc nk_dots_packed_e5m2 */
|
|
912
|
+
NK_PUBLIC void nk_dots_packed_e5m2_haswell(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
913
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
914
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
915
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_haswell(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
916
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
917
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
918
|
+
/** @copydoc nk_dots_packed_size_e2m3 */
|
|
919
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_haswell(nk_size_t width, nk_size_t depth);
|
|
920
|
+
/** @copydoc nk_dots_pack_e2m3 */
|
|
921
|
+
NK_PUBLIC void nk_dots_pack_e2m3_haswell(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
922
|
+
void *b_packed);
|
|
923
|
+
/** @copydoc nk_dots_packed_e2m3 */
|
|
924
|
+
NK_PUBLIC void nk_dots_packed_e2m3_haswell(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
925
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
926
|
+
/** @copydoc nk_dots_symmetric_e2m3 */
|
|
927
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_haswell(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
928
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
929
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
930
|
+
/** @copydoc nk_dots_packed_size_e3m2 */
|
|
931
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_haswell(nk_size_t width, nk_size_t depth);
|
|
932
|
+
/** @copydoc nk_dots_pack_e3m2 */
|
|
933
|
+
NK_PUBLIC void nk_dots_pack_e3m2_haswell(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
934
|
+
void *b_packed);
|
|
935
|
+
/** @copydoc nk_dots_packed_e3m2 */
|
|
936
|
+
NK_PUBLIC void nk_dots_packed_e3m2_haswell(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
937
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
938
|
+
/** @copydoc nk_dots_symmetric_e3m2 */
|
|
939
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_haswell(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
940
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
941
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
942
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
943
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_haswell(nk_size_t width, nk_size_t depth);
|
|
944
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
945
|
+
NK_PUBLIC void nk_dots_pack_i8_haswell(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
946
|
+
void *b_packed);
|
|
947
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
948
|
+
NK_PUBLIC void nk_dots_packed_i8_haswell(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
949
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
950
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
951
|
+
NK_PUBLIC void nk_dots_symmetric_i8_haswell(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
952
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
953
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
954
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
955
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_haswell(nk_size_t width, nk_size_t depth);
|
|
956
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
957
|
+
NK_PUBLIC void nk_dots_pack_u8_haswell(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
958
|
+
void *b_packed);
|
|
959
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
960
|
+
NK_PUBLIC void nk_dots_packed_u8_haswell(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
961
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
962
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
963
|
+
NK_PUBLIC void nk_dots_symmetric_u8_haswell(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
964
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
965
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
966
|
+
/** @copydoc nk_dots_packed_size_u1 */
|
|
967
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u1_haswell(nk_size_t width, nk_size_t depth);
|
|
968
|
+
/** @copydoc nk_dots_pack_u1 */
|
|
969
|
+
NK_PUBLIC void nk_dots_pack_u1_haswell(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
970
|
+
void *b_packed);
|
|
971
|
+
/** @copydoc nk_dots_packed_u1 */
|
|
972
|
+
NK_PUBLIC void nk_dots_packed_u1_haswell(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
973
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
974
|
+
/** @copydoc nk_dots_symmetric_u1 */
|
|
975
|
+
NK_PUBLIC void nk_dots_symmetric_u1_haswell(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
976
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
977
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
978
|
+
/** @copydoc nk_dots_packed_size_i4 */
|
|
979
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i4_haswell(nk_size_t width, nk_size_t depth);
|
|
980
|
+
/** @copydoc nk_dots_pack_i4 */
|
|
981
|
+
NK_PUBLIC void nk_dots_pack_i4_haswell(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
982
|
+
void *b_packed);
|
|
983
|
+
/** @copydoc nk_dots_packed_i4 */
|
|
984
|
+
NK_PUBLIC void nk_dots_packed_i4_haswell(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
985
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
986
|
+
/** @copydoc nk_dots_symmetric_i4 */
|
|
987
|
+
NK_PUBLIC void nk_dots_symmetric_i4_haswell(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
988
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
989
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
990
|
+
/** @copydoc nk_dots_packed_size_u4 */
|
|
991
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u4_haswell(nk_size_t width, nk_size_t depth);
|
|
992
|
+
/** @copydoc nk_dots_pack_u4 */
|
|
993
|
+
NK_PUBLIC void nk_dots_pack_u4_haswell(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
994
|
+
void *b_packed);
|
|
995
|
+
/** @copydoc nk_dots_packed_u4 */
|
|
996
|
+
NK_PUBLIC void nk_dots_packed_u4_haswell(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
997
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
998
|
+
/** @copydoc nk_dots_symmetric_u4 */
|
|
999
|
+
NK_PUBLIC void nk_dots_symmetric_u4_haswell(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1000
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1001
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1002
|
+
#endif // NK_TARGET_HASWELL
|
|
1003
|
+
|
|
1004
|
+
/* Skylake backends using AVX-512 (Intel Skylake-X / Skylake-SP and later AVX-512 capable CPUs).
|
|
1005
|
+
* Provides 512-bit vectors (16× f32, 8× f64), supporting F32/F64/F16/BF16/FP8 with FMA.
|
|
1006
|
+
*/
|
|
1007
|
+
#if NK_TARGET_SKYLAKE
|
|
1008
|
+
/** @copydoc nk_dots_packed_size_f64 */
|
|
1009
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_skylake(nk_size_t width, nk_size_t depth);
|
|
1010
|
+
/** @copydoc nk_dots_pack_f64 */
|
|
1011
|
+
NK_PUBLIC void nk_dots_pack_f64_skylake(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1012
|
+
void *b_packed);
|
|
1013
|
+
/** @copydoc nk_dots_packed_f64 */
|
|
1014
|
+
NK_PUBLIC void nk_dots_packed_f64_skylake(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
1015
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1016
|
+
/** @copydoc nk_dots_symmetric_f64 */
|
|
1017
|
+
NK_PUBLIC void nk_dots_symmetric_f64_skylake(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1018
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
1019
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1020
|
+
/** @copydoc nk_dots_packed_size_f32 */
|
|
1021
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_skylake(nk_size_t width, nk_size_t depth);
|
|
1022
|
+
/** @copydoc nk_dots_pack_f32 */
|
|
1023
|
+
NK_PUBLIC void nk_dots_pack_f32_skylake(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1024
|
+
void *b_packed);
|
|
1025
|
+
/** @copydoc nk_dots_packed_f32 */
|
|
1026
|
+
NK_PUBLIC void nk_dots_packed_f32_skylake(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
1027
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1028
|
+
/** @copydoc nk_dots_symmetric_f32 */
|
|
1029
|
+
NK_PUBLIC void nk_dots_symmetric_f32_skylake(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1030
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
1031
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1032
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
1033
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_skylake(nk_size_t width, nk_size_t depth);
|
|
1034
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
1035
|
+
NK_PUBLIC void nk_dots_pack_bf16_skylake(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1036
|
+
void *b_packed);
|
|
1037
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
1038
|
+
NK_PUBLIC void nk_dots_packed_bf16_skylake(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1039
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1040
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
1041
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_skylake(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1042
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1043
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1044
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
1045
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_skylake(nk_size_t width, nk_size_t depth);
|
|
1046
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
1047
|
+
NK_PUBLIC void nk_dots_pack_f16_skylake(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1048
|
+
void *b_packed);
|
|
1049
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
1050
|
+
NK_PUBLIC void nk_dots_packed_f16_skylake(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1051
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1052
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
1053
|
+
NK_PUBLIC void nk_dots_symmetric_f16_skylake(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1054
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1055
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1056
|
+
/** @copydoc nk_dots_packed_size_e4m3 */
|
|
1057
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_skylake(nk_size_t width, nk_size_t depth);
|
|
1058
|
+
/** @copydoc nk_dots_pack_e4m3 */
|
|
1059
|
+
NK_PUBLIC void nk_dots_pack_e4m3_skylake(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1060
|
+
void *b_packed);
|
|
1061
|
+
/** @copydoc nk_dots_packed_e4m3 */
|
|
1062
|
+
NK_PUBLIC void nk_dots_packed_e4m3_skylake(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1063
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1064
|
+
/** @copydoc nk_dots_symmetric_e4m3 */
|
|
1065
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_skylake(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1066
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1067
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1068
|
+
/** @copydoc nk_dots_packed_size_e5m2 */
|
|
1069
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_skylake(nk_size_t width, nk_size_t depth);
|
|
1070
|
+
/** @copydoc nk_dots_pack_e5m2 */
|
|
1071
|
+
NK_PUBLIC void nk_dots_pack_e5m2_skylake(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1072
|
+
void *b_packed);
|
|
1073
|
+
/** @copydoc nk_dots_packed_e5m2 */
|
|
1074
|
+
NK_PUBLIC void nk_dots_packed_e5m2_skylake(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1075
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1076
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
1077
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_skylake(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1078
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1079
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1080
|
+
/** @copydoc nk_dots_packed_size_e2m3 */
|
|
1081
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_skylake(nk_size_t width, nk_size_t depth);
|
|
1082
|
+
/** @copydoc nk_dots_pack_e2m3 */
|
|
1083
|
+
NK_PUBLIC void nk_dots_pack_e2m3_skylake(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1084
|
+
void *b_packed);
|
|
1085
|
+
/** @copydoc nk_dots_packed_e2m3 */
|
|
1086
|
+
NK_PUBLIC void nk_dots_packed_e2m3_skylake(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1087
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1088
|
+
/** @copydoc nk_dots_symmetric_e2m3 */
|
|
1089
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_skylake(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1090
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1091
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1092
|
+
/** @copydoc nk_dots_packed_size_e3m2 */
|
|
1093
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_skylake(nk_size_t width, nk_size_t depth);
|
|
1094
|
+
/** @copydoc nk_dots_pack_e3m2 */
|
|
1095
|
+
NK_PUBLIC void nk_dots_pack_e3m2_skylake(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1096
|
+
void *b_packed);
|
|
1097
|
+
/** @copydoc nk_dots_packed_e3m2 */
|
|
1098
|
+
NK_PUBLIC void nk_dots_packed_e3m2_skylake(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1099
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1100
|
+
/** @copydoc nk_dots_symmetric_e3m2 */
|
|
1101
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_skylake(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1102
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1103
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1104
|
+
#endif // NK_TARGET_SKYLAKE
|
|
1105
|
+
|
|
1106
|
+
/* Ice Lake backends using AVX-512 with VNNI (Vector Neural Network Instructions).
|
|
1107
|
+
* Adds VPDPBUSD for I8/U8 and VPDPWSSD for I4/U4, enabling efficient integer dot products.
|
|
1108
|
+
*/
|
|
1109
|
+
#if NK_TARGET_ICELAKE
|
|
1110
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
1111
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_icelake(nk_size_t width, nk_size_t depth);
|
|
1112
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
1113
|
+
NK_PUBLIC void nk_dots_pack_i8_icelake(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1114
|
+
void *b_packed);
|
|
1115
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
1116
|
+
NK_PUBLIC void nk_dots_packed_i8_icelake(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
1117
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1118
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
1119
|
+
NK_PUBLIC void nk_dots_symmetric_i8_icelake(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1120
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
1121
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1122
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
1123
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_icelake(nk_size_t width, nk_size_t depth);
|
|
1124
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
1125
|
+
NK_PUBLIC void nk_dots_pack_u8_icelake(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1126
|
+
void *b_packed);
|
|
1127
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
1128
|
+
NK_PUBLIC void nk_dots_packed_u8_icelake(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1129
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1130
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
1131
|
+
NK_PUBLIC void nk_dots_symmetric_u8_icelake(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1132
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1133
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1134
|
+
/** @copydoc nk_dots_packed_size_i4 */
|
|
1135
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i4_icelake(nk_size_t width, nk_size_t depth);
|
|
1136
|
+
/** @copydoc nk_dots_pack_i4 */
|
|
1137
|
+
NK_PUBLIC void nk_dots_pack_i4_icelake(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1138
|
+
void *b_packed);
|
|
1139
|
+
/** @copydoc nk_dots_packed_i4 */
|
|
1140
|
+
NK_PUBLIC void nk_dots_packed_i4_icelake(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
1141
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1142
|
+
/** @copydoc nk_dots_symmetric_i4 */
|
|
1143
|
+
NK_PUBLIC void nk_dots_symmetric_i4_icelake(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1144
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
1145
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1146
|
+
/** @copydoc nk_dots_packed_size_u4 */
|
|
1147
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u4_icelake(nk_size_t width, nk_size_t depth);
|
|
1148
|
+
/** @copydoc nk_dots_pack_u4 */
|
|
1149
|
+
NK_PUBLIC void nk_dots_pack_u4_icelake(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1150
|
+
void *b_packed);
|
|
1151
|
+
/** @copydoc nk_dots_packed_u4 */
|
|
1152
|
+
NK_PUBLIC void nk_dots_packed_u4_icelake(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1153
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1154
|
+
/** @copydoc nk_dots_symmetric_u4 */
|
|
1155
|
+
NK_PUBLIC void nk_dots_symmetric_u4_icelake(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1156
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1157
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1158
|
+
/** @copydoc nk_dots_packed_size_u1 */
|
|
1159
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u1_icelake(nk_size_t width, nk_size_t depth);
|
|
1160
|
+
/** @copydoc nk_dots_pack_u1 */
|
|
1161
|
+
NK_PUBLIC void nk_dots_pack_u1_icelake(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1162
|
+
void *b_packed);
|
|
1163
|
+
/** @copydoc nk_dots_packed_u1 */
|
|
1164
|
+
NK_PUBLIC void nk_dots_packed_u1_icelake(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1165
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1166
|
+
/** @copydoc nk_dots_symmetric_u1 */
|
|
1167
|
+
NK_PUBLIC void nk_dots_symmetric_u1_icelake(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1168
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1169
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1170
|
+
#endif // NK_TARGET_ICELAKE
|
|
1171
|
+
|
|
1172
|
+
/* Alder backends using AMX with TDPB[SU]SD / TDPBF16PS.
|
|
1173
|
+
* Optimized for I8/U8 via AMX integer tiles, E2M3 via AMX BF16 tiles.
|
|
1174
|
+
*/
|
|
1175
|
+
#if NK_TARGET_ALDER
|
|
1176
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
1177
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_alder(nk_size_t width, nk_size_t depth);
|
|
1178
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
1179
|
+
NK_PUBLIC void nk_dots_pack_i8_alder(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1180
|
+
void *b_packed);
|
|
1181
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
1182
|
+
NK_PUBLIC void nk_dots_packed_i8_alder(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
1183
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1184
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
1185
|
+
NK_PUBLIC void nk_dots_symmetric_i8_alder(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1186
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
1187
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1188
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
1189
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_alder(nk_size_t width, nk_size_t depth);
|
|
1190
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
1191
|
+
NK_PUBLIC void nk_dots_pack_u8_alder(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1192
|
+
void *b_packed);
|
|
1193
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
1194
|
+
NK_PUBLIC void nk_dots_packed_u8_alder(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1195
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1196
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
1197
|
+
NK_PUBLIC void nk_dots_symmetric_u8_alder(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1198
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1199
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1200
|
+
/** @copydoc nk_dots_packed_size_e2m3 */
|
|
1201
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_alder(nk_size_t width, nk_size_t depth);
|
|
1202
|
+
/** @copydoc nk_dots_pack_e2m3 */
|
|
1203
|
+
NK_PUBLIC void nk_dots_pack_e2m3_alder(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1204
|
+
void *b_packed);
|
|
1205
|
+
/** @copydoc nk_dots_packed_e2m3 */
|
|
1206
|
+
NK_PUBLIC void nk_dots_packed_e2m3_alder(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1207
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1208
|
+
/** @copydoc nk_dots_symmetric_e2m3 */
|
|
1209
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_alder(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1210
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1211
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1212
|
+
#endif // NK_TARGET_ALDER
|
|
1213
|
+
|
|
1214
|
+
/* Sierra backends using AVX10.2 with VMPSADBW.
|
|
1215
|
+
* Optimized for I8/U8 via VMPSADBW (multiple packed sums of absolute differences).
|
|
1216
|
+
*/
|
|
1217
|
+
#if NK_TARGET_SIERRA
|
|
1218
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
1219
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sierra(nk_size_t width, nk_size_t depth);
|
|
1220
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
1221
|
+
NK_PUBLIC void nk_dots_pack_i8_sierra(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1222
|
+
void *b_packed);
|
|
1223
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
1224
|
+
NK_PUBLIC void nk_dots_packed_i8_sierra(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
1225
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1226
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
1227
|
+
NK_PUBLIC void nk_dots_symmetric_i8_sierra(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1228
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
1229
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1230
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
1231
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sierra(nk_size_t width, nk_size_t depth);
|
|
1232
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
1233
|
+
NK_PUBLIC void nk_dots_pack_u8_sierra(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1234
|
+
void *b_packed);
|
|
1235
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
1236
|
+
NK_PUBLIC void nk_dots_packed_u8_sierra(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1237
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1238
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
1239
|
+
NK_PUBLIC void nk_dots_symmetric_u8_sierra(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1240
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1241
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1242
|
+
/** @copydoc nk_dots_packed_size_e2m3 */
|
|
1243
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sierra(nk_size_t width, nk_size_t depth);
|
|
1244
|
+
/** @copydoc nk_dots_pack_e2m3 */
|
|
1245
|
+
NK_PUBLIC void nk_dots_pack_e2m3_sierra(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1246
|
+
void *b_packed);
|
|
1247
|
+
/** @copydoc nk_dots_packed_e2m3 */
|
|
1248
|
+
NK_PUBLIC void nk_dots_packed_e2m3_sierra(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1249
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1250
|
+
/** @copydoc nk_dots_symmetric_e2m3 */
|
|
1251
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_sierra(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1252
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1253
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1254
|
+
#endif // NK_TARGET_SIERRA
|
|
1255
|
+
|
|
1256
|
+
/* WASM Relaxed SIMD backends using wasm_i32x4_relaxed_dot_i8x16_i7x16_add.
|
|
1257
|
+
* Covers I8/U8/E2M3 (depth_simd_dimensions=16), BF16/F32 (4), F64 (2).
|
|
1258
|
+
*/
|
|
1259
|
+
#if NK_TARGET_V128RELAXED
|
|
1260
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
1261
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1262
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
1263
|
+
NK_PUBLIC void nk_dots_pack_i8_v128relaxed(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1264
|
+
void *b_packed);
|
|
1265
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
1266
|
+
NK_PUBLIC void nk_dots_packed_i8_v128relaxed(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
1267
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1268
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
1269
|
+
NK_PUBLIC void nk_dots_symmetric_i8_v128relaxed(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1270
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
1271
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1272
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
1273
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1274
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
1275
|
+
NK_PUBLIC void nk_dots_pack_u8_v128relaxed(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1276
|
+
void *b_packed);
|
|
1277
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
1278
|
+
NK_PUBLIC void nk_dots_packed_u8_v128relaxed(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1279
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1280
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
1281
|
+
NK_PUBLIC void nk_dots_symmetric_u8_v128relaxed(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1282
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1283
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1284
|
+
/** @copydoc nk_dots_packed_size_e2m3 */
|
|
1285
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1286
|
+
/** @copydoc nk_dots_pack_e2m3 */
|
|
1287
|
+
NK_PUBLIC void nk_dots_pack_e2m3_v128relaxed(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1288
|
+
void *b_packed);
|
|
1289
|
+
/** @copydoc nk_dots_packed_e2m3 */
|
|
1290
|
+
NK_PUBLIC void nk_dots_packed_e2m3_v128relaxed(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1291
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
1292
|
+
nk_size_t c_stride);
|
|
1293
|
+
/** @copydoc nk_dots_symmetric_e2m3 */
|
|
1294
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_v128relaxed(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1295
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1296
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1297
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
1298
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1299
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
1300
|
+
NK_PUBLIC void nk_dots_pack_bf16_v128relaxed(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1301
|
+
void *b_packed);
|
|
1302
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
1303
|
+
NK_PUBLIC void nk_dots_packed_bf16_v128relaxed(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1304
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
1305
|
+
nk_size_t c_stride);
|
|
1306
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
1307
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_v128relaxed(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1308
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1309
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1310
|
+
/** @copydoc nk_dots_packed_size_f32 */
|
|
1311
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1312
|
+
/** @copydoc nk_dots_pack_f32 */
|
|
1313
|
+
NK_PUBLIC void nk_dots_pack_f32_v128relaxed(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1314
|
+
void *b_packed);
|
|
1315
|
+
/** @copydoc nk_dots_packed_f32 */
|
|
1316
|
+
NK_PUBLIC void nk_dots_packed_f32_v128relaxed(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
1317
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1318
|
+
/** @copydoc nk_dots_symmetric_f32 */
|
|
1319
|
+
NK_PUBLIC void nk_dots_symmetric_f32_v128relaxed(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1320
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
1321
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1322
|
+
/** @copydoc nk_dots_packed_size_f64 */
|
|
1323
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1324
|
+
/** @copydoc nk_dots_pack_f64 */
|
|
1325
|
+
NK_PUBLIC void nk_dots_pack_f64_v128relaxed(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1326
|
+
void *b_packed);
|
|
1327
|
+
/** @copydoc nk_dots_packed_f64 */
|
|
1328
|
+
NK_PUBLIC void nk_dots_packed_f64_v128relaxed(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
1329
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1330
|
+
/** @copydoc nk_dots_symmetric_f64 */
|
|
1331
|
+
NK_PUBLIC void nk_dots_symmetric_f64_v128relaxed(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1332
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
1333
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1334
|
+
/** @copydoc nk_dots_packed_size_e4m3 */
|
|
1335
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1336
|
+
/** @copydoc nk_dots_pack_e4m3 */
|
|
1337
|
+
NK_PUBLIC void nk_dots_pack_e4m3_v128relaxed(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1338
|
+
void *b_packed);
|
|
1339
|
+
/** @copydoc nk_dots_packed_e4m3 */
|
|
1340
|
+
NK_PUBLIC void nk_dots_packed_e4m3_v128relaxed(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1341
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
1342
|
+
nk_size_t c_stride);
|
|
1343
|
+
/** @copydoc nk_dots_symmetric_e4m3 */
|
|
1344
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_v128relaxed(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1345
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1346
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1347
|
+
/** @copydoc nk_dots_packed_size_e5m2 */
|
|
1348
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1349
|
+
/** @copydoc nk_dots_pack_e5m2 */
|
|
1350
|
+
NK_PUBLIC void nk_dots_pack_e5m2_v128relaxed(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1351
|
+
void *b_packed);
|
|
1352
|
+
/** @copydoc nk_dots_packed_e5m2 */
|
|
1353
|
+
NK_PUBLIC void nk_dots_packed_e5m2_v128relaxed(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1354
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride,
|
|
1355
|
+
nk_size_t c_stride);
|
|
1356
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
1357
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_v128relaxed(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1358
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1359
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1360
|
+
/** @copydoc nk_dots_packed_size_u4 */
|
|
1361
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u4_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1362
|
+
/** @copydoc nk_dots_pack_u4 */
|
|
1363
|
+
NK_PUBLIC void nk_dots_pack_u4_v128relaxed(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1364
|
+
void *b_packed);
|
|
1365
|
+
/** @copydoc nk_dots_packed_u4 */
|
|
1366
|
+
NK_PUBLIC void nk_dots_packed_u4_v128relaxed(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1367
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1368
|
+
/** @copydoc nk_dots_symmetric_u4 */
|
|
1369
|
+
NK_PUBLIC void nk_dots_symmetric_u4_v128relaxed(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1370
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1371
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1372
|
+
/** @copydoc nk_dots_packed_size_i4 */
|
|
1373
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i4_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1374
|
+
/** @copydoc nk_dots_pack_i4 */
|
|
1375
|
+
NK_PUBLIC void nk_dots_pack_i4_v128relaxed(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1376
|
+
void *b_packed);
|
|
1377
|
+
/** @copydoc nk_dots_packed_i4 */
|
|
1378
|
+
NK_PUBLIC void nk_dots_packed_i4_v128relaxed(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
1379
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1380
|
+
/** @copydoc nk_dots_symmetric_i4 */
|
|
1381
|
+
NK_PUBLIC void nk_dots_symmetric_i4_v128relaxed(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1382
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
1383
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1384
|
+
/** @copydoc nk_dots_packed_size_u1 */
|
|
1385
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u1_v128relaxed(nk_size_t width, nk_size_t depth);
|
|
1386
|
+
/** @copydoc nk_dots_pack_u1 */
|
|
1387
|
+
NK_PUBLIC void nk_dots_pack_u1_v128relaxed(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1388
|
+
void *b_packed);
|
|
1389
|
+
/** @copydoc nk_dots_packed_u1 */
|
|
1390
|
+
NK_PUBLIC void nk_dots_packed_u1_v128relaxed(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1391
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1392
|
+
/** @copydoc nk_dots_symmetric_u1 */
|
|
1393
|
+
NK_PUBLIC void nk_dots_symmetric_u1_v128relaxed(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1394
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1395
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1396
|
+
#endif // NK_TARGET_V128RELAXED
|
|
1397
|
+
|
|
1398
|
+
/* ARM NEON backends (base NEON with F32/F64 support).
|
|
1399
|
+
* Uses FMLA for F32 dots, FMLA (scalar) for F64.
|
|
1400
|
+
*/
|
|
1401
|
+
#if NK_TARGET_NEON
|
|
1402
|
+
/** @copydoc nk_dots_packed_size_f32 */
|
|
1403
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_neon(nk_size_t width, nk_size_t depth);
|
|
1404
|
+
/** @copydoc nk_dots_pack_f32 */
|
|
1405
|
+
NK_PUBLIC void nk_dots_pack_f32_neon(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1406
|
+
void *b_packed);
|
|
1407
|
+
/** @copydoc nk_dots_packed_f32 */
|
|
1408
|
+
NK_PUBLIC void nk_dots_packed_f32_neon(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
1409
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1410
|
+
/** @copydoc nk_dots_symmetric_f32 */
|
|
1411
|
+
NK_PUBLIC void nk_dots_symmetric_f32_neon(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1412
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
1413
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1414
|
+
/** @copydoc nk_dots_packed_size_f64 */
|
|
1415
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_neon(nk_size_t width, nk_size_t depth);
|
|
1416
|
+
/** @copydoc nk_dots_pack_f64 */
|
|
1417
|
+
NK_PUBLIC void nk_dots_pack_f64_neon(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1418
|
+
void *b_packed);
|
|
1419
|
+
/** @copydoc nk_dots_packed_f64 */
|
|
1420
|
+
NK_PUBLIC void nk_dots_packed_f64_neon(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
|
|
1421
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1422
|
+
/** @copydoc nk_dots_symmetric_f64 */
|
|
1423
|
+
NK_PUBLIC void nk_dots_symmetric_f64_neon(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1424
|
+
nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
|
|
1425
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1426
|
+
/** @copydoc nk_dots_packed_size_u1 */
|
|
1427
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u1_neon(nk_size_t width, nk_size_t depth);
|
|
1428
|
+
/** @copydoc nk_dots_pack_u1 */
|
|
1429
|
+
NK_PUBLIC void nk_dots_pack_u1_neon(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1430
|
+
void *b_packed);
|
|
1431
|
+
/** @copydoc nk_dots_packed_u1 */
|
|
1432
|
+
NK_PUBLIC void nk_dots_packed_u1_neon(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1433
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1434
|
+
/** @copydoc nk_dots_symmetric_u1 */
|
|
1435
|
+
NK_PUBLIC void nk_dots_symmetric_u1_neon(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1436
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1437
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1438
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
1439
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_neon(nk_size_t width, nk_size_t depth);
|
|
1440
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
1441
|
+
NK_PUBLIC void nk_dots_pack_f16_neon(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1442
|
+
void *b_packed);
|
|
1443
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
1444
|
+
NK_PUBLIC void nk_dots_packed_f16_neon(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1445
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1446
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
1447
|
+
NK_PUBLIC void nk_dots_symmetric_f16_neon(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1448
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1449
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1450
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
1451
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_neon(nk_size_t width, nk_size_t depth);
|
|
1452
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
1453
|
+
NK_PUBLIC void nk_dots_pack_bf16_neon(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1454
|
+
void *b_packed);
|
|
1455
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
1456
|
+
NK_PUBLIC void nk_dots_packed_bf16_neon(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1457
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1458
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
1459
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_neon(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1460
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1461
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1462
|
+
#endif // NK_TARGET_NEON
|
|
1463
|
+
|
|
1464
|
+
/* ARM NEON with F16 arithmetic (ARMv8.2-A FP16).
|
|
1465
|
+
* Provides native F16 FMLA for half-precision dot products.
|
|
1466
|
+
*/
|
|
1467
|
+
#if NK_TARGET_NEONHALF
|
|
1468
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
1469
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_neonhalf(nk_size_t width, nk_size_t depth);
|
|
1470
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
1471
|
+
NK_PUBLIC void nk_dots_pack_f16_neonhalf(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1472
|
+
void *b_packed);
|
|
1473
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
1474
|
+
NK_PUBLIC void nk_dots_packed_f16_neonhalf(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1475
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1476
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
1477
|
+
NK_PUBLIC void nk_dots_symmetric_f16_neonhalf(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1478
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1479
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1480
|
+
#endif // NK_TARGET_NEONHALF
|
|
1481
|
+
|
|
1482
|
+
/* ARM NEON with BF16 dot product (ARMv8.6-A BF16).
|
|
1483
|
+
* Uses BFDOT/BFMMLA for efficient BF16 matrix operations.
|
|
1484
|
+
*/
|
|
1485
|
+
#if NK_TARGET_NEONBFDOT
|
|
1486
|
+
/** @copydoc nk_dots_packed_size_bf16 */
|
|
1487
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_neonbfdot(nk_size_t width, nk_size_t depth);
|
|
1488
|
+
/** @copydoc nk_dots_pack_bf16 */
|
|
1489
|
+
NK_PUBLIC void nk_dots_pack_bf16_neonbfdot(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1490
|
+
void *b_packed);
|
|
1491
|
+
/** @copydoc nk_dots_packed_bf16 */
|
|
1492
|
+
NK_PUBLIC void nk_dots_packed_bf16_neonbfdot(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1493
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1494
|
+
/** @copydoc nk_dots_symmetric_bf16 */
|
|
1495
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_neonbfdot(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1496
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1497
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1498
|
+
#endif // NK_TARGET_NEONBFDOT
|
|
1499
|
+
|
|
1500
|
+
/* ARM NEON with signed/unsigned dot product (ARMv8.2-A DotProd).
|
|
1501
|
+
* Provides SDOT/UDOT for I8/U8 vector dot products.
|
|
1502
|
+
*/
|
|
1503
|
+
#if NK_TARGET_NEONSDOT
|
|
1504
|
+
/** @copydoc nk_dots_packed_size_i8 */
|
|
1505
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_neonsdot(nk_size_t width, nk_size_t depth);
|
|
1506
|
+
/** @copydoc nk_dots_pack_i8 */
|
|
1507
|
+
NK_PUBLIC void nk_dots_pack_i8_neonsdot(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1508
|
+
void *b_packed);
|
|
1509
|
+
/** @copydoc nk_dots_packed_i8 */
|
|
1510
|
+
NK_PUBLIC void nk_dots_packed_i8_neonsdot(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
|
|
1511
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1512
|
+
/** @copydoc nk_dots_symmetric_i8 */
|
|
1513
|
+
NK_PUBLIC void nk_dots_symmetric_i8_neonsdot(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1514
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride,
|
|
1515
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1516
|
+
/** @copydoc nk_dots_packed_size_u8 */
|
|
1517
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_neonsdot(nk_size_t width, nk_size_t depth);
|
|
1518
|
+
/** @copydoc nk_dots_pack_u8 */
|
|
1519
|
+
NK_PUBLIC void nk_dots_pack_u8_neonsdot(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1520
|
+
void *b_packed);
|
|
1521
|
+
/** @copydoc nk_dots_packed_u8 */
|
|
1522
|
+
NK_PUBLIC void nk_dots_packed_u8_neonsdot(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
|
|
1523
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1524
|
+
/** @copydoc nk_dots_symmetric_u8 */
|
|
1525
|
+
NK_PUBLIC void nk_dots_symmetric_u8_neonsdot(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1526
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride,
|
|
1527
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1528
|
+
#endif // NK_TARGET_NEONSDOT
|
|
1529
|
+
|
|
1530
|
+
/* ARM NEON with FP16 FML (fused multiply-long, ARMv8.2-A FP16FML).
|
|
1531
|
+
* Uses FMLAL/FMLSL for F16 and custom FP8 (E4M3/E5M2) operations.
|
|
1532
|
+
*/
|
|
1533
|
+
#if NK_TARGET_NEONFHM
|
|
1534
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
1535
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_neonfhm(nk_size_t width, nk_size_t depth);
|
|
1536
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
1537
|
+
NK_PUBLIC void nk_dots_pack_f16_neonfhm(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1538
|
+
void *b_packed);
|
|
1539
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
1540
|
+
NK_PUBLIC void nk_dots_packed_f16_neonfhm(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1541
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1542
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
1543
|
+
NK_PUBLIC void nk_dots_symmetric_f16_neonfhm(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1544
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1545
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1546
|
+
/** @copydoc nk_dots_packed_size_e4m3 */
|
|
1547
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_neonfhm(nk_size_t width, nk_size_t depth);
|
|
1548
|
+
/** @copydoc nk_dots_pack_e4m3 */
|
|
1549
|
+
NK_PUBLIC void nk_dots_pack_e4m3_neonfhm(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1550
|
+
void *b_packed);
|
|
1551
|
+
/** @copydoc nk_dots_packed_e4m3 */
|
|
1552
|
+
NK_PUBLIC void nk_dots_packed_e4m3_neonfhm(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1553
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1554
|
+
/** @copydoc nk_dots_symmetric_e4m3 */
|
|
1555
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_neonfhm(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1556
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1557
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1558
|
+
/** @copydoc nk_dots_packed_size_e5m2 */
|
|
1559
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_neonfhm(nk_size_t width, nk_size_t depth);
|
|
1560
|
+
/** @copydoc nk_dots_pack_e5m2 */
|
|
1561
|
+
NK_PUBLIC void nk_dots_pack_e5m2_neonfhm(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
1562
|
+
void *b_packed);
|
|
1563
|
+
/** @copydoc nk_dots_packed_e5m2 */
|
|
1564
|
+
NK_PUBLIC void nk_dots_packed_e5m2_neonfhm(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
1565
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
1566
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
1567
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_neonfhm(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
|
|
1568
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
1569
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
1570
|
+
#endif // NK_TARGET_NEONFHM
|
|
1571
|
+
|
|
1572
|
+
#if NK_TARGET_RVV
// Prototypes of the `_rvv` backends. Every element type gets the same four entry points:
// `packed_size` (buffer size query), `pack`, `packed` and `symmetric`.
// Output element types, as declared below: f32 and f64 inputs write `nk_f64_t`, i8 writes
// `nk_i32_t`, u8 writes `nk_u32_t`, and all remaining types write `nk_f32_t`.

/** @copydoc nk_dots_packed_size_e2m3 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_e2m3 */
NK_PUBLIC void nk_dots_pack_e2m3_rvv(
    nk_e2m3_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_e2m3 */
NK_PUBLIC void nk_dots_packed_e2m3_rvv(
    nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_e2m3 */
NK_PUBLIC void nk_dots_symmetric_e2m3_rvv(
    nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_f32_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_e3m2 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_e3m2 */
NK_PUBLIC void nk_dots_pack_e3m2_rvv(
    nk_e3m2_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_e3m2 */
NK_PUBLIC void nk_dots_packed_e3m2_rvv(
    nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_e3m2 */
NK_PUBLIC void nk_dots_symmetric_e3m2_rvv(
    nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_f32_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_f32 */
NK_PUBLIC nk_size_t nk_dots_packed_size_f32_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_f32 */
NK_PUBLIC void nk_dots_pack_f32_rvv(
    nk_f32_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_f32 */
NK_PUBLIC void nk_dots_packed_f32_rvv(
    nk_f32_t const *a, void const *b_packed, nk_f64_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_f32 */
NK_PUBLIC void nk_dots_symmetric_f32_rvv(
    nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_f64_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_f64 */
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_f64 */
NK_PUBLIC void nk_dots_pack_f64_rvv(
    nk_f64_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_f64 */
NK_PUBLIC void nk_dots_packed_f64_rvv(
    nk_f64_t const *a, void const *b_packed, nk_f64_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_f64 */
NK_PUBLIC void nk_dots_symmetric_f64_rvv(
    nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_f64_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_bf16 */
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_bf16 */
NK_PUBLIC void nk_dots_pack_bf16_rvv(
    nk_bf16_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_bf16 */
NK_PUBLIC void nk_dots_packed_bf16_rvv(
    nk_bf16_t const *a, void const *b_packed, nk_f32_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_bf16 */
NK_PUBLIC void nk_dots_symmetric_bf16_rvv(
    nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_f32_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_f16 */
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_f16 */
NK_PUBLIC void nk_dots_pack_f16_rvv(
    nk_f16_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_f16 */
NK_PUBLIC void nk_dots_packed_f16_rvv(
    nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_f16 */
NK_PUBLIC void nk_dots_symmetric_f16_rvv(
    nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_f32_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_i8 */
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_i8 */
NK_PUBLIC void nk_dots_pack_i8_rvv(
    nk_i8_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_i8 */
NK_PUBLIC void nk_dots_packed_i8_rvv(
    nk_i8_t const *a, void const *b_packed, nk_i32_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_i8 */
NK_PUBLIC void nk_dots_symmetric_i8_rvv(
    nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_i32_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_u8 */
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_u8 */
NK_PUBLIC void nk_dots_pack_u8_rvv(
    nk_u8_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_u8 */
NK_PUBLIC void nk_dots_packed_u8_rvv(
    nk_u8_t const *a, void const *b_packed, nk_u32_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_u8 */
NK_PUBLIC void nk_dots_symmetric_u8_rvv(
    nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_u32_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_e4m3 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_e4m3 */
NK_PUBLIC void nk_dots_pack_e4m3_rvv(
    nk_e4m3_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_e4m3 */
NK_PUBLIC void nk_dots_packed_e4m3_rvv(
    nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_e4m3 */
NK_PUBLIC void nk_dots_symmetric_e4m3_rvv(
    nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_f32_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
/** @copydoc nk_dots_packed_size_e5m2 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_rvv(nk_size_t width, nk_size_t depth);
/** @copydoc nk_dots_pack_e5m2 */
NK_PUBLIC void nk_dots_pack_e5m2_rvv(
    nk_e5m2_t const *b, nk_size_t width, nk_size_t depth,
    nk_size_t b_stride, void *b_packed);
/** @copydoc nk_dots_packed_e5m2 */
NK_PUBLIC void nk_dots_packed_e5m2_rvv(
    nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
    nk_size_t height, nk_size_t width, nk_size_t depth,
    nk_size_t a_stride, nk_size_t c_stride);
/** @copydoc nk_dots_symmetric_e5m2 */
NK_PUBLIC void nk_dots_symmetric_e5m2_rvv(
    nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
    nk_f32_t *result, nk_size_t result_stride,
    nk_size_t row_start, nk_size_t row_count);
#endif // NK_TARGET_RVV
|
|
1694
|
+
|
|
1695
|
+
#if defined(__cplusplus)
|
|
1696
|
+
} // extern "C"
|
|
1697
|
+
#endif
|
|
1698
|
+
|
|
1699
|
+
#include "numkong/dots/serial.h"
|
|
1700
|
+
#include "numkong/dots/haswell.h"
|
|
1701
|
+
#include "numkong/dots/skylake.h"
|
|
1702
|
+
#include "numkong/dots/icelake.h"
|
|
1703
|
+
#include "numkong/dots/alder.h"
|
|
1704
|
+
#include "numkong/dots/sierra.h"
|
|
1705
|
+
#include "numkong/dots/genoa.h"
|
|
1706
|
+
#include "numkong/dots/sapphireamx.h"
|
|
1707
|
+
#include "numkong/dots/neon.h"
|
|
1708
|
+
#include "numkong/dots/neonsdot.h"
|
|
1709
|
+
#include "numkong/dots/neonhalf.h"
|
|
1710
|
+
#include "numkong/dots/neonfhm.h"
|
|
1711
|
+
#include "numkong/dots/neonbfdot.h"
|
|
1712
|
+
#include "numkong/dots/sme.h"
|
|
1713
|
+
#include "numkong/dots/smef64.h"
|
|
1714
|
+
#include "numkong/dots/smebi32.h"
|
|
1715
|
+
#include "numkong/dots/rvv.h"
|
|
1716
|
+
#include "numkong/dots/v128relaxed.h"
|
|
1717
|
+
|
|
1718
|
+
#if defined(__cplusplus)
|
|
1719
|
+
extern "C" {
|
|
1720
|
+
#endif
|
|
1721
|
+
|
|
1722
|
+
#if !NK_DYNAMIC_DISPATCH
|
|
1723
|
+
|
|
1724
|
+
// Static (compile-time) dispatch for `nk_dots_packed_size_f32`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. `width` and `depth` are forwarded unchanged.
NK_PUBLIC nk_size_t nk_dots_packed_size_f32(nk_size_t width, nk_size_t depth) {
#if NK_TARGET_SMEF64
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f32_smef64
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f32_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f32_haswell
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f32_neon
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f32_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f32_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f32_serial
#endif
    return NK_DOTS_BACKEND_(width, depth);
#undef NK_DOTS_BACKEND_
}
|
|
1741
|
+
|
|
1742
|
+
// Static (compile-time) dispatch for `nk_dots_pack_f32`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_pack_f32(nk_f32_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
                                void *b_packed) {
#if NK_TARGET_SMEF64
#define NK_DOTS_BACKEND_ nk_dots_pack_f32_smef64
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_pack_f32_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_pack_f32_haswell
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_pack_f32_neon
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_pack_f32_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_pack_f32_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_pack_f32_serial
#endif
    NK_DOTS_BACKEND_(b, width, depth, b_stride, b_packed);
#undef NK_DOTS_BACKEND_
}
|
|
1760
|
+
|
|
1761
|
+
// Static (compile-time) dispatch for `nk_dots_packed_f32`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_packed_f32(nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
                                  nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SMEF64
#define NK_DOTS_BACKEND_ nk_dots_packed_f32_smef64
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_f32_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_f32_haswell
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_packed_f32_neon
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_f32_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_f32_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_f32_serial
#endif
    NK_DOTS_BACKEND_(a, b_packed, c, height, width, depth, a_stride, c_stride);
#undef NK_DOTS_BACKEND_
}
|
|
1779
|
+
|
|
1780
|
+
// Static (compile-time) dispatch for `nk_dots_packed_size_f64`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. `width` and `depth` are forwarded unchanged.
NK_PUBLIC nk_size_t nk_dots_packed_size_f64(nk_size_t width, nk_size_t depth) {
#if NK_TARGET_SMEF64
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f64_smef64
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f64_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f64_haswell
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f64_neon
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f64_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f64_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f64_serial
#endif
    return NK_DOTS_BACKEND_(width, depth);
#undef NK_DOTS_BACKEND_
}
|
|
1797
|
+
|
|
1798
|
+
// Static (compile-time) dispatch for `nk_dots_pack_f64`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_pack_f64(nk_f64_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
                                void *b_packed) {
#if NK_TARGET_SMEF64
#define NK_DOTS_BACKEND_ nk_dots_pack_f64_smef64
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_pack_f64_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_pack_f64_haswell
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_pack_f64_neon
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_pack_f64_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_pack_f64_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_pack_f64_serial
#endif
    NK_DOTS_BACKEND_(b, width, depth, b_stride, b_packed);
#undef NK_DOTS_BACKEND_
}
|
|
1816
|
+
|
|
1817
|
+
// Static (compile-time) dispatch for `nk_dots_packed_f64`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t height,
                                  nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SMEF64
#define NK_DOTS_BACKEND_ nk_dots_packed_f64_smef64
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_f64_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_f64_haswell
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_packed_f64_neon
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_f64_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_f64_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_f64_serial
#endif
    NK_DOTS_BACKEND_(a, b_packed, c, height, width, depth, a_stride, c_stride);
#undef NK_DOTS_BACKEND_
}
|
|
1835
|
+
|
|
1836
|
+
// Static (compile-time) dispatch for `nk_dots_packed_size_f16`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. `width` and `depth` are forwarded unchanged.
NK_PUBLIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f16_sme
#elif NK_TARGET_NEONFHM
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f16_neonfhm
#elif NK_TARGET_NEONHALF
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f16_neonhalf
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f16_neon
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f16_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f16_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f16_rvv
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_size_f16_serial
#endif
    return NK_DOTS_BACKEND_(width, depth);
#undef NK_DOTS_BACKEND_
}
|
|
1855
|
+
|
|
1856
|
+
// Static (compile-time) dispatch for `nk_dots_pack_f16`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
                                void *b_packed) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_pack_f16_sme
#elif NK_TARGET_NEONFHM
#define NK_DOTS_BACKEND_ nk_dots_pack_f16_neonfhm
#elif NK_TARGET_NEONHALF
#define NK_DOTS_BACKEND_ nk_dots_pack_f16_neonhalf
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_pack_f16_neon
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_pack_f16_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_pack_f16_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_pack_f16_rvv
#else
#define NK_DOTS_BACKEND_ nk_dots_pack_f16_serial
#endif
    NK_DOTS_BACKEND_(b, width, depth, b_stride, b_packed);
#undef NK_DOTS_BACKEND_
}
|
|
1876
|
+
|
|
1877
|
+
// Static (compile-time) dispatch for `nk_dots_packed_f16`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
                                  nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_f16_sme
#elif NK_TARGET_NEONFHM
#define NK_DOTS_BACKEND_ nk_dots_packed_f16_neonfhm
#elif NK_TARGET_NEONHALF
#define NK_DOTS_BACKEND_ nk_dots_packed_f16_neonhalf
#elif NK_TARGET_NEON
#define NK_DOTS_BACKEND_ nk_dots_packed_f16_neon
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_f16_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_f16_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_f16_rvv
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_f16_serial
#endif
    NK_DOTS_BACKEND_(a, b_packed, c, height, width, depth, a_stride, c_stride);
#undef NK_DOTS_BACKEND_
}
|
|
1897
|
+
|
|
1898
|
+
// Static (compile-time) dispatch for `nk_dots_packed_size_bf16`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. `width` and `depth` are forwarded unchanged.
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16(nk_size_t width, nk_size_t depth) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_sapphireamx
#elif NK_TARGET_NEONBFDOT
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_neonbfdot
#elif NK_TARGET_GENOA
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_genoa
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_size_bf16_serial
#endif
    return NK_DOTS_BACKEND_(width, depth);
#undef NK_DOTS_BACKEND_
}
|
|
1919
|
+
|
|
1920
|
+
// Static (compile-time) dispatch for `nk_dots_pack_bf16`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_pack_bf16(nk_bf16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
                                 void *b_packed) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_sapphireamx
#elif NK_TARGET_NEONBFDOT
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_neonbfdot
#elif NK_TARGET_GENOA
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_genoa
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_pack_bf16_serial
#endif
    NK_DOTS_BACKEND_(b, width, depth, b_stride, b_packed);
#undef NK_DOTS_BACKEND_
}
|
|
1942
|
+
|
|
1943
|
+
// Static (compile-time) dispatch for `nk_dots_packed_bf16`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_packed_bf16(nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
                                   nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_sapphireamx
#elif NK_TARGET_NEONBFDOT
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_neonbfdot
#elif NK_TARGET_GENOA
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_genoa
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_bf16_serial
#endif
    NK_DOTS_BACKEND_(a, b_packed, c, height, width, depth, a_stride, c_stride);
#undef NK_DOTS_BACKEND_
}
|
|
1965
|
+
|
|
1966
|
+
// Static (compile-time) dispatch for `nk_dots_packed_size_i8`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. `width` and `depth` are forwarded unchanged.
NK_PUBLIC nk_size_t nk_dots_packed_size_i8(nk_size_t width, nk_size_t depth) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_sapphireamx
#elif NK_TARGET_NEONSDOT
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_neonsdot
#elif NK_TARGET_ICELAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_icelake
#elif NK_TARGET_SIERRA
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_sierra
#elif NK_TARGET_ALDER
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_alder
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_size_i8_serial
#endif
    return NK_DOTS_BACKEND_(width, depth);
#undef NK_DOTS_BACKEND_
}
|
|
1989
|
+
|
|
1990
|
+
// Static (compile-time) dispatch for `nk_dots_pack_i8`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_pack_i8(nk_i8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
                               void *b_packed) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_sapphireamx
#elif NK_TARGET_NEONSDOT
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_neonsdot
#elif NK_TARGET_ICELAKE
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_icelake
#elif NK_TARGET_SIERRA
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_sierra
#elif NK_TARGET_ALDER
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_alder
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_pack_i8_serial
#endif
    NK_DOTS_BACKEND_(b, width, depth, b_stride, b_packed);
#undef NK_DOTS_BACKEND_
}
|
|
2013
|
+
|
|
2014
|
+
// Static (compile-time) dispatch for `nk_dots_packed_i8`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_packed_i8(nk_i8_t const *a, void const *b_packed, nk_i32_t *c, nk_size_t height,
                                 nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_sapphireamx
#elif NK_TARGET_NEONSDOT
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_neonsdot
#elif NK_TARGET_ICELAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_icelake
#elif NK_TARGET_SIERRA
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_sierra
#elif NK_TARGET_ALDER
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_alder
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_i8_serial
#endif
    NK_DOTS_BACKEND_(a, b_packed, c, height, width, depth, a_stride, c_stride);
#undef NK_DOTS_BACKEND_
}
|
|
2038
|
+
|
|
2039
|
+
// Static (compile-time) dispatch for `nk_dots_packed_size_u8`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. `width` and `depth` are forwarded unchanged.
NK_PUBLIC nk_size_t nk_dots_packed_size_u8(nk_size_t width, nk_size_t depth) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_sapphireamx
#elif NK_TARGET_NEONSDOT
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_neonsdot
#elif NK_TARGET_ICELAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_icelake
#elif NK_TARGET_SIERRA
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_sierra
#elif NK_TARGET_ALDER
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_alder
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_size_u8_serial
#endif
    return NK_DOTS_BACKEND_(width, depth);
#undef NK_DOTS_BACKEND_
}
|
|
2062
|
+
|
|
2063
|
+
// Static (compile-time) dispatch for `nk_dots_pack_u8`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_pack_u8(nk_u8_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
                               void *b_packed) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_sapphireamx
#elif NK_TARGET_NEONSDOT
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_neonsdot
#elif NK_TARGET_ICELAKE
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_icelake
#elif NK_TARGET_SIERRA
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_sierra
#elif NK_TARGET_ALDER
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_alder
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_pack_u8_serial
#endif
    NK_DOTS_BACKEND_(b, width, depth, b_stride, b_packed);
#undef NK_DOTS_BACKEND_
}
|
|
2086
|
+
|
|
2087
|
+
// Static (compile-time) dispatch for `nk_dots_packed_u8`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_packed_u8(nk_u8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t height,
                                 nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_sapphireamx
#elif NK_TARGET_NEONSDOT
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_neonsdot
#elif NK_TARGET_ICELAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_icelake
#elif NK_TARGET_SIERRA
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_sierra
#elif NK_TARGET_ALDER
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_alder
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_u8_serial
#endif
    NK_DOTS_BACKEND_(a, b_packed, c, height, width, depth, a_stride, c_stride);
#undef NK_DOTS_BACKEND_
}
|
|
2111
|
+
|
|
2112
|
+
// Static (compile-time) dispatch for `nk_dots_packed_size_e4m3`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. `width` and `depth` are forwarded unchanged.
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3(nk_size_t width, nk_size_t depth) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_sapphireamx
#elif NK_TARGET_NEONFHM
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_neonfhm
#elif NK_TARGET_GENOA
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_genoa
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_packed_size_e4m3_serial
#endif
    return NK_DOTS_BACKEND_(width, depth);
#undef NK_DOTS_BACKEND_
}
|
|
2133
|
+
|
|
2134
|
+
// Static (compile-time) dispatch for `nk_dots_pack_e4m3`.
// The first target macro that evaluates true picks the backend; with none enabled the
// `_serial` variant is used. All arguments are forwarded unchanged.
NK_PUBLIC void nk_dots_pack_e4m3(nk_e4m3_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
                                 void *b_packed) {
#if NK_TARGET_SME
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_sme
#elif NK_TARGET_SAPPHIREAMX
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_sapphireamx
#elif NK_TARGET_NEONFHM
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_neonfhm
#elif NK_TARGET_GENOA
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_genoa
#elif NK_TARGET_SKYLAKE
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_skylake
#elif NK_TARGET_HASWELL
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_haswell
#elif NK_TARGET_RVV
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_rvv
#elif NK_TARGET_V128RELAXED
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_v128relaxed
#else
#define NK_DOTS_BACKEND_ nk_dots_pack_e4m3_serial
#endif
    NK_DOTS_BACKEND_(b, width, depth, b_stride, b_packed);
#undef NK_DOTS_BACKEND_
}
|
|
2156
|
+
|
|
2157
|
+
/**
 *  @brief Compile-time dispatcher for the e4m3 `packed` kernel, writing `nk_f32_t` results into `c`.
 *
 *  Forwards all eight arguments unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, SAPPHIREAMX, NEONFHM, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED;
 *  the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_e4m3` and `nk_dots_pack_e4m3`.
 */
NK_PUBLIC void nk_dots_packed_e4m3(nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c,
                                   nk_size_t height, nk_size_t width, nk_size_t depth,
                                   nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
    nk_dots_packed_e4m3_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SAPPHIREAMX
    nk_dots_packed_e4m3_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_NEONFHM
    nk_dots_packed_e4m3_neonfhm(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_GENOA
    nk_dots_packed_e4m3_genoa(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SKYLAKE
    nk_dots_packed_e4m3_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_HASWELL
    nk_dots_packed_e4m3_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_RVV
    nk_dots_packed_e4m3_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_V128RELAXED
    nk_dots_packed_e4m3_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
#else
    nk_dots_packed_e4m3_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
#endif
}
|
|
2179
|
+
|
|
2180
|
+
/**
 *  @brief Compile-time dispatcher for the e5m2 packed-size query.
 *
 *  Passes `width` and `depth` unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, SAPPHIREAMX, NEONFHM, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED;
 *  the serial implementation is used when none of them is enabled.
 *  The priority order is the same one used by `nk_dots_pack_e5m2` and `nk_dots_packed_e5m2`.
 *
 *  @return The value produced by the selected backend
 *          (presumably the byte size of the packed buffer - confirm in the backend headers).
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2(nk_size_t width, nk_size_t depth) {
    nk_size_t packed_size;
#if NK_TARGET_SME
    packed_size = nk_dots_packed_size_e5m2_sme(width, depth);
#elif NK_TARGET_SAPPHIREAMX
    packed_size = nk_dots_packed_size_e5m2_sapphireamx(width, depth);
#elif NK_TARGET_NEONFHM
    packed_size = nk_dots_packed_size_e5m2_neonfhm(width, depth);
#elif NK_TARGET_GENOA
    packed_size = nk_dots_packed_size_e5m2_genoa(width, depth);
#elif NK_TARGET_SKYLAKE
    packed_size = nk_dots_packed_size_e5m2_skylake(width, depth);
#elif NK_TARGET_HASWELL
    packed_size = nk_dots_packed_size_e5m2_haswell(width, depth);
#elif NK_TARGET_RVV
    packed_size = nk_dots_packed_size_e5m2_rvv(width, depth);
#elif NK_TARGET_V128RELAXED
    packed_size = nk_dots_packed_size_e5m2_v128relaxed(width, depth);
#else
    packed_size = nk_dots_packed_size_e5m2_serial(width, depth);
#endif
    return packed_size;
}
|
|
2201
|
+
|
|
2202
|
+
/**
 *  @brief Compile-time dispatcher for the e5m2 `pack` kernel.
 *
 *  Forwards `b`, `width`, `depth`, `b_stride` and `b_packed` unchanged to the first backend whose
 *  `NK_TARGET_*` macro is non-zero, tested in the order SME, SAPPHIREAMX, NEONFHM, GENOA, SKYLAKE,
 *  HASWELL, RVV, V128RELAXED; the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_e5m2` and `nk_dots_packed_e5m2`.
 */
NK_PUBLIC void nk_dots_pack_e5m2(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth,
                                 nk_size_t b_stride, void *b_packed) {
#if NK_TARGET_SME
    nk_dots_pack_e5m2_sme(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_SAPPHIREAMX
    nk_dots_pack_e5m2_sapphireamx(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_NEONFHM
    nk_dots_pack_e5m2_neonfhm(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_GENOA
    nk_dots_pack_e5m2_genoa(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_SKYLAKE
    nk_dots_pack_e5m2_skylake(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_HASWELL
    nk_dots_pack_e5m2_haswell(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_RVV
    nk_dots_pack_e5m2_rvv(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_V128RELAXED
    nk_dots_pack_e5m2_v128relaxed(b, width, depth, b_stride, b_packed);
#else
    nk_dots_pack_e5m2_serial(b, width, depth, b_stride, b_packed);
#endif
}
|
|
2224
|
+
|
|
2225
|
+
/**
 *  @brief Compile-time dispatcher for the e5m2 `packed` kernel, writing `nk_f32_t` results into `c`.
 *
 *  Forwards all eight arguments unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, SAPPHIREAMX, NEONFHM, GENOA, SKYLAKE, HASWELL, RVV, V128RELAXED;
 *  the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_e5m2` and `nk_dots_pack_e5m2`.
 */
NK_PUBLIC void nk_dots_packed_e5m2(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
                                   nk_size_t height, nk_size_t width, nk_size_t depth,
                                   nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
    nk_dots_packed_e5m2_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SAPPHIREAMX
    nk_dots_packed_e5m2_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_NEONFHM
    nk_dots_packed_e5m2_neonfhm(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_GENOA
    nk_dots_packed_e5m2_genoa(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SKYLAKE
    nk_dots_packed_e5m2_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_HASWELL
    nk_dots_packed_e5m2_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_RVV
    nk_dots_packed_e5m2_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_V128RELAXED
    nk_dots_packed_e5m2_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
#else
    nk_dots_packed_e5m2_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
#endif
}
|
|
2247
|
+
|
|
2248
|
+
/**
 *  @brief Compile-time dispatcher for the e2m3 packed-size query.
 *
 *  Passes `width` and `depth` unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, SAPPHIREAMX, SKYLAKE, SIERRA, ALDER, HASWELL, RVV, V128RELAXED;
 *  the serial implementation is used when none of them is enabled.
 *  The priority order is the same one used by `nk_dots_pack_e2m3` and `nk_dots_packed_e2m3`.
 *
 *  @return The value produced by the selected backend
 *          (presumably the byte size of the packed buffer - confirm in the backend headers).
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3(nk_size_t width, nk_size_t depth) {
    nk_size_t packed_size;
#if NK_TARGET_SME
    packed_size = nk_dots_packed_size_e2m3_sme(width, depth);
#elif NK_TARGET_SAPPHIREAMX
    packed_size = nk_dots_packed_size_e2m3_sapphireamx(width, depth);
#elif NK_TARGET_SKYLAKE
    packed_size = nk_dots_packed_size_e2m3_skylake(width, depth);
#elif NK_TARGET_SIERRA
    packed_size = nk_dots_packed_size_e2m3_sierra(width, depth);
#elif NK_TARGET_ALDER
    packed_size = nk_dots_packed_size_e2m3_alder(width, depth);
#elif NK_TARGET_HASWELL
    packed_size = nk_dots_packed_size_e2m3_haswell(width, depth);
#elif NK_TARGET_RVV
    packed_size = nk_dots_packed_size_e2m3_rvv(width, depth);
#elif NK_TARGET_V128RELAXED
    packed_size = nk_dots_packed_size_e2m3_v128relaxed(width, depth);
#else
    packed_size = nk_dots_packed_size_e2m3_serial(width, depth);
#endif
    return packed_size;
}
|
|
2269
|
+
|
|
2270
|
+
/**
 *  @brief Compile-time dispatcher for the e2m3 `pack` kernel.
 *
 *  Forwards `b`, `width`, `depth`, `b_stride` and `b_packed` unchanged to the first backend whose
 *  `NK_TARGET_*` macro is non-zero, tested in the order SME, SAPPHIREAMX, SKYLAKE, SIERRA, ALDER,
 *  HASWELL, RVV, V128RELAXED; the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_e2m3` and `nk_dots_packed_e2m3`.
 */
NK_PUBLIC void nk_dots_pack_e2m3(nk_e2m3_t const *b, nk_size_t width, nk_size_t depth,
                                 nk_size_t b_stride, void *b_packed) {
#if NK_TARGET_SME
    nk_dots_pack_e2m3_sme(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_SAPPHIREAMX
    nk_dots_pack_e2m3_sapphireamx(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_SKYLAKE
    nk_dots_pack_e2m3_skylake(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_SIERRA
    nk_dots_pack_e2m3_sierra(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_ALDER
    nk_dots_pack_e2m3_alder(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_HASWELL
    nk_dots_pack_e2m3_haswell(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_RVV
    nk_dots_pack_e2m3_rvv(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_V128RELAXED
    nk_dots_pack_e2m3_v128relaxed(b, width, depth, b_stride, b_packed);
#else
    nk_dots_pack_e2m3_serial(b, width, depth, b_stride, b_packed);
#endif
}
|
|
2292
|
+
|
|
2293
|
+
/**
 *  @brief Compile-time dispatcher for the e2m3 `packed` kernel, writing `nk_f32_t` results into `c`.
 *
 *  Forwards all eight arguments unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, SAPPHIREAMX, SKYLAKE, SIERRA, ALDER, HASWELL, RVV, V128RELAXED;
 *  the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_e2m3` and `nk_dots_pack_e2m3`.
 */
NK_PUBLIC void nk_dots_packed_e2m3(nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c,
                                   nk_size_t height, nk_size_t width, nk_size_t depth,
                                   nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
    nk_dots_packed_e2m3_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SAPPHIREAMX
    nk_dots_packed_e2m3_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SKYLAKE
    nk_dots_packed_e2m3_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SIERRA
    nk_dots_packed_e2m3_sierra(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_ALDER
    nk_dots_packed_e2m3_alder(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_HASWELL
    nk_dots_packed_e2m3_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_RVV
    nk_dots_packed_e2m3_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_V128RELAXED
    nk_dots_packed_e2m3_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
#else
    nk_dots_packed_e2m3_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
#endif
}
|
|
2315
|
+
|
|
2316
|
+
/**
 *  @brief Compile-time dispatcher for the e3m2 packed-size query.
 *
 *  Passes `width` and `depth` unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, SAPPHIREAMX, SKYLAKE, HASWELL, RVV; the serial implementation is used
 *  when none of them is enabled.
 *  The priority order is the same one used by `nk_dots_pack_e3m2` and `nk_dots_packed_e3m2`.
 *
 *  @return The value produced by the selected backend
 *          (presumably the byte size of the packed buffer - confirm in the backend headers).
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2(nk_size_t width, nk_size_t depth) {
    nk_size_t packed_size;
#if NK_TARGET_SME
    packed_size = nk_dots_packed_size_e3m2_sme(width, depth);
#elif NK_TARGET_SAPPHIREAMX
    packed_size = nk_dots_packed_size_e3m2_sapphireamx(width, depth);
#elif NK_TARGET_SKYLAKE
    packed_size = nk_dots_packed_size_e3m2_skylake(width, depth);
#elif NK_TARGET_HASWELL
    packed_size = nk_dots_packed_size_e3m2_haswell(width, depth);
#elif NK_TARGET_RVV
    packed_size = nk_dots_packed_size_e3m2_rvv(width, depth);
#else
    packed_size = nk_dots_packed_size_e3m2_serial(width, depth);
#endif
    return packed_size;
}
|
|
2331
|
+
|
|
2332
|
+
/**
 *  @brief Compile-time dispatcher for the e3m2 `pack` kernel.
 *
 *  Forwards `b`, `width`, `depth`, `b_stride` and `b_packed` unchanged to the first backend whose
 *  `NK_TARGET_*` macro is non-zero, tested in the order SME, SAPPHIREAMX, SKYLAKE, HASWELL, RVV;
 *  the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_e3m2` and `nk_dots_packed_e3m2`.
 */
NK_PUBLIC void nk_dots_pack_e3m2(nk_e3m2_t const *b, nk_size_t width, nk_size_t depth,
                                 nk_size_t b_stride, void *b_packed) {
#if NK_TARGET_SME
    nk_dots_pack_e3m2_sme(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_SAPPHIREAMX
    nk_dots_pack_e3m2_sapphireamx(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_SKYLAKE
    nk_dots_pack_e3m2_skylake(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_HASWELL
    nk_dots_pack_e3m2_haswell(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_RVV
    nk_dots_pack_e3m2_rvv(b, width, depth, b_stride, b_packed);
#else
    nk_dots_pack_e3m2_serial(b, width, depth, b_stride, b_packed);
#endif
}
|
|
2348
|
+
|
|
2349
|
+
/**
 *  @brief Compile-time dispatcher for the e3m2 `packed` kernel, writing `nk_f32_t` results into `c`.
 *
 *  Forwards all eight arguments unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, SAPPHIREAMX, SKYLAKE, HASWELL, RVV; the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_e3m2` and `nk_dots_pack_e3m2`.
 */
NK_PUBLIC void nk_dots_packed_e3m2(nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c,
                                   nk_size_t height, nk_size_t width, nk_size_t depth,
                                   nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
    nk_dots_packed_e3m2_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SAPPHIREAMX
    nk_dots_packed_e3m2_sapphireamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_SKYLAKE
    nk_dots_packed_e3m2_skylake(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_HASWELL
    nk_dots_packed_e3m2_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_RVV
    nk_dots_packed_e3m2_rvv(a, b_packed, c, height, width, depth, a_stride, c_stride);
#else
    nk_dots_packed_e3m2_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
#endif
}
|
|
2365
|
+
|
|
2366
|
+
/**
 *  @brief Compile-time dispatcher for the u4 packed-size query.
 *
 *  Passes `width` and `depth` unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, ICELAKE, NEONSDOT, HASWELL, V128RELAXED; the serial implementation is
 *  used when none of them is enabled.
 *  The priority order is the same one used by `nk_dots_pack_u4` and `nk_dots_packed_u4`.
 *
 *  @return The value produced by the selected backend
 *          (presumably the byte size of the packed buffer - confirm in the backend headers).
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_u4(nk_size_t width, nk_size_t depth) {
    nk_size_t packed_size;
#if NK_TARGET_SME
    packed_size = nk_dots_packed_size_u4_sme(width, depth);
#elif NK_TARGET_ICELAKE
    packed_size = nk_dots_packed_size_u4_icelake(width, depth);
#elif NK_TARGET_NEONSDOT
    packed_size = nk_dots_packed_size_u4_neonsdot(width, depth);
#elif NK_TARGET_HASWELL
    packed_size = nk_dots_packed_size_u4_haswell(width, depth);
#elif NK_TARGET_V128RELAXED
    packed_size = nk_dots_packed_size_u4_v128relaxed(width, depth);
#else
    packed_size = nk_dots_packed_size_u4_serial(width, depth);
#endif
    return packed_size;
}
|
|
2381
|
+
|
|
2382
|
+
/**
 *  @brief Compile-time dispatcher for the u4 `pack` kernel (input typed as `nk_u4x2_t`).
 *
 *  Forwards `b`, `width`, `depth`, `b_stride` and `b_packed` unchanged to the first backend whose
 *  `NK_TARGET_*` macro is non-zero, tested in the order SME, ICELAKE, NEONSDOT, HASWELL, V128RELAXED;
 *  the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_u4` and `nk_dots_packed_u4`.
 */
NK_PUBLIC void nk_dots_pack_u4(nk_u4x2_t const *b, nk_size_t width, nk_size_t depth,
                               nk_size_t b_stride, void *b_packed) {
#if NK_TARGET_SME
    nk_dots_pack_u4_sme(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_ICELAKE
    nk_dots_pack_u4_icelake(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_NEONSDOT
    nk_dots_pack_u4_neonsdot(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_HASWELL
    nk_dots_pack_u4_haswell(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_V128RELAXED
    nk_dots_pack_u4_v128relaxed(b, width, depth, b_stride, b_packed);
#else
    nk_dots_pack_u4_serial(b, width, depth, b_stride, b_packed);
#endif
}
|
|
2398
|
+
|
|
2399
|
+
/**
 *  @brief Compile-time dispatcher for the u4 `packed` kernel, writing `nk_u32_t` results into `c`.
 *
 *  Forwards all eight arguments unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, ICELAKE, NEONSDOT, HASWELL, V128RELAXED; the serial implementation is
 *  the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_u4` and `nk_dots_pack_u4`.
 */
NK_PUBLIC void nk_dots_packed_u4(nk_u4x2_t const *a, void const *b_packed, nk_u32_t *c,
                                 nk_size_t height, nk_size_t width, nk_size_t depth,
                                 nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
    nk_dots_packed_u4_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_ICELAKE
    nk_dots_packed_u4_icelake(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_NEONSDOT
    nk_dots_packed_u4_neonsdot(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_HASWELL
    nk_dots_packed_u4_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_V128RELAXED
    nk_dots_packed_u4_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
#else
    nk_dots_packed_u4_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
#endif
}
|
|
2415
|
+
|
|
2416
|
+
/**
 *  @brief Compile-time dispatcher for the u1 packed-size query.
 *
 *  Passes `width` and `depth` unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SMEBI32, ICELAKE, HASWELL, NEON, V128RELAXED; the serial implementation is
 *  used when none of them is enabled.
 *  The priority order is the same one used by `nk_dots_pack_u1` and `nk_dots_packed_u1`.
 *
 *  @return The value produced by the selected backend
 *          (presumably the byte size of the packed buffer - confirm in the backend headers).
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_u1(nk_size_t width, nk_size_t depth) {
    nk_size_t packed_size;
#if NK_TARGET_SMEBI32
    packed_size = nk_dots_packed_size_u1_smebi32(width, depth);
#elif NK_TARGET_ICELAKE
    packed_size = nk_dots_packed_size_u1_icelake(width, depth);
#elif NK_TARGET_HASWELL
    packed_size = nk_dots_packed_size_u1_haswell(width, depth);
#elif NK_TARGET_NEON
    packed_size = nk_dots_packed_size_u1_neon(width, depth);
#elif NK_TARGET_V128RELAXED
    packed_size = nk_dots_packed_size_u1_v128relaxed(width, depth);
#else
    packed_size = nk_dots_packed_size_u1_serial(width, depth);
#endif
    return packed_size;
}
|
|
2431
|
+
|
|
2432
|
+
/**
 *  @brief Compile-time dispatcher for the u1 `pack` kernel (input typed as `nk_u1x8_t`).
 *
 *  Forwards `b`, `width`, `depth`, `b_stride` and `b_packed` unchanged to the first backend whose
 *  `NK_TARGET_*` macro is non-zero, tested in the order SMEBI32, ICELAKE, HASWELL, NEON, V128RELAXED;
 *  the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_u1` and `nk_dots_packed_u1`.
 */
NK_PUBLIC void nk_dots_pack_u1(nk_u1x8_t const *b, nk_size_t width, nk_size_t depth,
                               nk_size_t b_stride, void *b_packed) {
#if NK_TARGET_SMEBI32
    nk_dots_pack_u1_smebi32(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_ICELAKE
    nk_dots_pack_u1_icelake(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_HASWELL
    nk_dots_pack_u1_haswell(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_NEON
    nk_dots_pack_u1_neon(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_V128RELAXED
    nk_dots_pack_u1_v128relaxed(b, width, depth, b_stride, b_packed);
#else
    nk_dots_pack_u1_serial(b, width, depth, b_stride, b_packed);
#endif
}
|
|
2448
|
+
|
|
2449
|
+
/**
 *  @brief Compile-time dispatcher for the u1 `packed` kernel, writing `nk_u32_t` results into `c`.
 *
 *  Forwards all eight arguments unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SMEBI32, ICELAKE, HASWELL, NEON, V128RELAXED; the serial implementation is
 *  the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_u1` and `nk_dots_pack_u1`.
 */
NK_PUBLIC void nk_dots_packed_u1(nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c,
                                 nk_size_t height, nk_size_t width, nk_size_t depth,
                                 nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SMEBI32
    nk_dots_packed_u1_smebi32(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_ICELAKE
    nk_dots_packed_u1_icelake(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_HASWELL
    nk_dots_packed_u1_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_NEON
    nk_dots_packed_u1_neon(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_V128RELAXED
    nk_dots_packed_u1_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
#else
    nk_dots_packed_u1_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
#endif
}
|
|
2465
|
+
|
|
2466
|
+
/**
 *  @brief Compile-time dispatcher for the i4 packed-size query.
 *
 *  Passes `width` and `depth` unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, ICELAKE, NEONSDOT, HASWELL, V128RELAXED; the serial implementation is
 *  used when none of them is enabled.
 *  The priority order is the same one used by `nk_dots_pack_i4` and `nk_dots_packed_i4`.
 *
 *  @return The value produced by the selected backend
 *          (presumably the byte size of the packed buffer - confirm in the backend headers).
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_i4(nk_size_t width, nk_size_t depth) {
    nk_size_t packed_size;
#if NK_TARGET_SME
    packed_size = nk_dots_packed_size_i4_sme(width, depth);
#elif NK_TARGET_ICELAKE
    packed_size = nk_dots_packed_size_i4_icelake(width, depth);
#elif NK_TARGET_NEONSDOT
    packed_size = nk_dots_packed_size_i4_neonsdot(width, depth);
#elif NK_TARGET_HASWELL
    packed_size = nk_dots_packed_size_i4_haswell(width, depth);
#elif NK_TARGET_V128RELAXED
    packed_size = nk_dots_packed_size_i4_v128relaxed(width, depth);
#else
    packed_size = nk_dots_packed_size_i4_serial(width, depth);
#endif
    return packed_size;
}
|
|
2481
|
+
|
|
2482
|
+
/**
 *  @brief Compile-time dispatcher for the i4 `pack` kernel (input typed as `nk_i4x2_t`).
 *
 *  Forwards `b`, `width`, `depth`, `b_stride` and `b_packed` unchanged to the first backend whose
 *  `NK_TARGET_*` macro is non-zero, tested in the order SME, ICELAKE, NEONSDOT, HASWELL, V128RELAXED;
 *  the serial implementation is the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_i4` and `nk_dots_packed_i4`.
 */
NK_PUBLIC void nk_dots_pack_i4(nk_i4x2_t const *b, nk_size_t width, nk_size_t depth,
                               nk_size_t b_stride, void *b_packed) {
#if NK_TARGET_SME
    nk_dots_pack_i4_sme(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_ICELAKE
    nk_dots_pack_i4_icelake(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_NEONSDOT
    nk_dots_pack_i4_neonsdot(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_HASWELL
    nk_dots_pack_i4_haswell(b, width, depth, b_stride, b_packed);
#elif NK_TARGET_V128RELAXED
    nk_dots_pack_i4_v128relaxed(b, width, depth, b_stride, b_packed);
#else
    nk_dots_pack_i4_serial(b, width, depth, b_stride, b_packed);
#endif
}
|
|
2498
|
+
|
|
2499
|
+
/**
 *  @brief Compile-time dispatcher for the i4 `packed` kernel, writing `nk_i32_t` results into `c`.
 *
 *  Forwards all eight arguments unchanged to the first backend whose `NK_TARGET_*` macro is non-zero,
 *  tested in the order SME, ICELAKE, NEONSDOT, HASWELL, V128RELAXED; the serial implementation is
 *  the fallback.
 *  The priority order is the same one used by `nk_dots_packed_size_i4` and `nk_dots_pack_i4`.
 */
NK_PUBLIC void nk_dots_packed_i4(nk_i4x2_t const *a, void const *b_packed, nk_i32_t *c,
                                 nk_size_t height, nk_size_t width, nk_size_t depth,
                                 nk_size_t a_stride, nk_size_t c_stride) {
#if NK_TARGET_SME
    nk_dots_packed_i4_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_ICELAKE
    nk_dots_packed_i4_icelake(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_NEONSDOT
    nk_dots_packed_i4_neonsdot(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_HASWELL
    nk_dots_packed_i4_haswell(a, b_packed, c, height, width, depth, a_stride, c_stride);
#elif NK_TARGET_V128RELAXED
    nk_dots_packed_i4_v128relaxed(a, b_packed, c, height, width, depth, a_stride, c_stride);
#else
    nk_dots_packed_i4_serial(a, b_packed, c, height, width, depth, a_stride, c_stride);
#endif
}
|
|
2515
|
+
|
|
2516
|
+
NK_PUBLIC void nk_dots_symmetric_f16(nk_f16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2517
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2518
|
+
nk_size_t row_count) {
|
|
2519
|
+
#if NK_TARGET_SME
|
|
2520
|
+
nk_dots_symmetric_f16_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2521
|
+
#elif NK_TARGET_NEONHALF
|
|
2522
|
+
nk_dots_symmetric_f16_neonhalf(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2523
|
+
#elif NK_TARGET_NEON
|
|
2524
|
+
nk_dots_symmetric_f16_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2525
|
+
#elif NK_TARGET_NEONFHM
|
|
2526
|
+
nk_dots_symmetric_f16_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2527
|
+
#elif NK_TARGET_SKYLAKE
|
|
2528
|
+
nk_dots_symmetric_f16_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2529
|
+
#elif NK_TARGET_HASWELL
|
|
2530
|
+
nk_dots_symmetric_f16_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2531
|
+
#elif NK_TARGET_RVV
|
|
2532
|
+
nk_dots_symmetric_f16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2533
|
+
#else
|
|
2534
|
+
nk_dots_symmetric_f16_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2535
|
+
#endif
|
|
2536
|
+
}
|
|
2537
|
+
|
|
2538
|
+
NK_PUBLIC void nk_dots_symmetric_bf16(nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2539
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2540
|
+
nk_size_t row_count) {
|
|
2541
|
+
#if NK_TARGET_SME
|
|
2542
|
+
nk_dots_symmetric_bf16_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2543
|
+
#elif NK_TARGET_SAPPHIREAMX
|
|
2544
|
+
nk_dots_symmetric_bf16_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2545
|
+
#elif NK_TARGET_NEONBFDOT
|
|
2546
|
+
nk_dots_symmetric_bf16_neonbfdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2547
|
+
#elif NK_TARGET_GENOA
|
|
2548
|
+
nk_dots_symmetric_bf16_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2549
|
+
#elif NK_TARGET_SKYLAKE
|
|
2550
|
+
nk_dots_symmetric_bf16_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2551
|
+
#elif NK_TARGET_HASWELL
|
|
2552
|
+
nk_dots_symmetric_bf16_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2553
|
+
#elif NK_TARGET_RVV
|
|
2554
|
+
nk_dots_symmetric_bf16_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2555
|
+
#elif NK_TARGET_V128RELAXED
|
|
2556
|
+
nk_dots_symmetric_bf16_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2557
|
+
#else
|
|
2558
|
+
nk_dots_symmetric_bf16_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2559
|
+
#endif
|
|
2560
|
+
}
|
|
2561
|
+
|
|
2562
|
+
NK_PUBLIC void nk_dots_symmetric_i8(nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2563
|
+
nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2564
|
+
nk_size_t row_count) {
|
|
2565
|
+
#if NK_TARGET_SME
|
|
2566
|
+
nk_dots_symmetric_i8_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2567
|
+
#elif NK_TARGET_SAPPHIREAMX
|
|
2568
|
+
nk_dots_symmetric_i8_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2569
|
+
#elif NK_TARGET_NEONSDOT
|
|
2570
|
+
nk_dots_symmetric_i8_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2571
|
+
#elif NK_TARGET_ICELAKE
|
|
2572
|
+
nk_dots_symmetric_i8_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2573
|
+
#elif NK_TARGET_SIERRA
|
|
2574
|
+
nk_dots_symmetric_i8_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2575
|
+
#elif NK_TARGET_ALDER
|
|
2576
|
+
nk_dots_symmetric_i8_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2577
|
+
#elif NK_TARGET_HASWELL
|
|
2578
|
+
nk_dots_symmetric_i8_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2579
|
+
#elif NK_TARGET_RVV
|
|
2580
|
+
nk_dots_symmetric_i8_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2581
|
+
#elif NK_TARGET_V128RELAXED
|
|
2582
|
+
nk_dots_symmetric_i8_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2583
|
+
#else
|
|
2584
|
+
nk_dots_symmetric_i8_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2585
|
+
#endif
|
|
2586
|
+
}
|
|
2587
|
+
|
|
2588
|
+
NK_PUBLIC void nk_dots_symmetric_u8(nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2589
|
+
nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2590
|
+
nk_size_t row_count) {
|
|
2591
|
+
#if NK_TARGET_SME
|
|
2592
|
+
nk_dots_symmetric_u8_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2593
|
+
#elif NK_TARGET_SAPPHIREAMX
|
|
2594
|
+
nk_dots_symmetric_u8_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2595
|
+
#elif NK_TARGET_ICELAKE
|
|
2596
|
+
nk_dots_symmetric_u8_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2597
|
+
#elif NK_TARGET_SIERRA
|
|
2598
|
+
nk_dots_symmetric_u8_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2599
|
+
#elif NK_TARGET_ALDER
|
|
2600
|
+
nk_dots_symmetric_u8_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2601
|
+
#elif NK_TARGET_NEONSDOT
|
|
2602
|
+
nk_dots_symmetric_u8_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2603
|
+
#elif NK_TARGET_HASWELL
|
|
2604
|
+
nk_dots_symmetric_u8_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2605
|
+
#elif NK_TARGET_RVV
|
|
2606
|
+
nk_dots_symmetric_u8_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2607
|
+
#elif NK_TARGET_V128RELAXED
|
|
2608
|
+
nk_dots_symmetric_u8_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2609
|
+
#else
|
|
2610
|
+
nk_dots_symmetric_u8_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2611
|
+
#endif
|
|
2612
|
+
}
|
|
2613
|
+
|
|
2614
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3(nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2615
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2616
|
+
nk_size_t row_count) {
|
|
2617
|
+
#if NK_TARGET_SME
|
|
2618
|
+
nk_dots_symmetric_e4m3_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2619
|
+
#elif NK_TARGET_NEONFHM
|
|
2620
|
+
nk_dots_symmetric_e4m3_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2621
|
+
#elif NK_TARGET_SAPPHIREAMX
|
|
2622
|
+
nk_dots_symmetric_e4m3_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2623
|
+
#elif NK_TARGET_GENOA
|
|
2624
|
+
nk_dots_symmetric_e4m3_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2625
|
+
#elif NK_TARGET_SKYLAKE
|
|
2626
|
+
nk_dots_symmetric_e4m3_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2627
|
+
#elif NK_TARGET_HASWELL
|
|
2628
|
+
nk_dots_symmetric_e4m3_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2629
|
+
#elif NK_TARGET_RVV
|
|
2630
|
+
nk_dots_symmetric_e4m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2631
|
+
#elif NK_TARGET_V128RELAXED
|
|
2632
|
+
nk_dots_symmetric_e4m3_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2633
|
+
#else
|
|
2634
|
+
nk_dots_symmetric_e4m3_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2635
|
+
#endif
|
|
2636
|
+
}
|
|
2637
|
+
|
|
2638
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2(nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2639
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2640
|
+
nk_size_t row_count) {
|
|
2641
|
+
#if NK_TARGET_SME
|
|
2642
|
+
nk_dots_symmetric_e5m2_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2643
|
+
#elif NK_TARGET_NEONFHM
|
|
2644
|
+
nk_dots_symmetric_e5m2_neonfhm(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2645
|
+
#elif NK_TARGET_SAPPHIREAMX
|
|
2646
|
+
nk_dots_symmetric_e5m2_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2647
|
+
#elif NK_TARGET_GENOA
|
|
2648
|
+
nk_dots_symmetric_e5m2_genoa(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2649
|
+
#elif NK_TARGET_SKYLAKE
|
|
2650
|
+
nk_dots_symmetric_e5m2_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2651
|
+
#elif NK_TARGET_HASWELL
|
|
2652
|
+
nk_dots_symmetric_e5m2_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2653
|
+
#elif NK_TARGET_RVV
|
|
2654
|
+
nk_dots_symmetric_e5m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2655
|
+
#elif NK_TARGET_V128RELAXED
|
|
2656
|
+
nk_dots_symmetric_e5m2_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2657
|
+
#else
|
|
2658
|
+
nk_dots_symmetric_e5m2_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2659
|
+
#endif
|
|
2660
|
+
}
|
|
2661
|
+
|
|
2662
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3(nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2663
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2664
|
+
nk_size_t row_count) {
|
|
2665
|
+
#if NK_TARGET_SME
|
|
2666
|
+
nk_dots_symmetric_e2m3_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2667
|
+
#elif NK_TARGET_SAPPHIREAMX
|
|
2668
|
+
nk_dots_symmetric_e2m3_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2669
|
+
#elif NK_TARGET_SKYLAKE
|
|
2670
|
+
nk_dots_symmetric_e2m3_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2671
|
+
#elif NK_TARGET_SIERRA
|
|
2672
|
+
nk_dots_symmetric_e2m3_sierra(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2673
|
+
#elif NK_TARGET_ALDER
|
|
2674
|
+
nk_dots_symmetric_e2m3_alder(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2675
|
+
#elif NK_TARGET_HASWELL
|
|
2676
|
+
nk_dots_symmetric_e2m3_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2677
|
+
#elif NK_TARGET_RVV
|
|
2678
|
+
nk_dots_symmetric_e2m3_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2679
|
+
#elif NK_TARGET_V128RELAXED
|
|
2680
|
+
nk_dots_symmetric_e2m3_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2681
|
+
#else
|
|
2682
|
+
nk_dots_symmetric_e2m3_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2683
|
+
#endif
|
|
2684
|
+
}
|
|
2685
|
+
|
|
2686
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2(nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2687
|
+
nk_f32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2688
|
+
nk_size_t row_count) {
|
|
2689
|
+
#if NK_TARGET_SME
|
|
2690
|
+
nk_dots_symmetric_e3m2_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2691
|
+
#elif NK_TARGET_SAPPHIREAMX
|
|
2692
|
+
nk_dots_symmetric_e3m2_sapphireamx(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2693
|
+
#elif NK_TARGET_SKYLAKE
|
|
2694
|
+
nk_dots_symmetric_e3m2_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2695
|
+
#elif NK_TARGET_HASWELL
|
|
2696
|
+
nk_dots_symmetric_e3m2_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2697
|
+
#elif NK_TARGET_RVV
|
|
2698
|
+
nk_dots_symmetric_e3m2_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2699
|
+
#else
|
|
2700
|
+
nk_dots_symmetric_e3m2_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2701
|
+
#endif
|
|
2702
|
+
}
|
|
2703
|
+
|
|
2704
|
+
NK_PUBLIC void nk_dots_symmetric_u4(nk_u4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2705
|
+
nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2706
|
+
nk_size_t row_count) {
|
|
2707
|
+
#if NK_TARGET_SME
|
|
2708
|
+
nk_dots_symmetric_u4_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2709
|
+
#elif NK_TARGET_ICELAKE
|
|
2710
|
+
nk_dots_symmetric_u4_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2711
|
+
#elif NK_TARGET_NEONSDOT
|
|
2712
|
+
nk_dots_symmetric_u4_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2713
|
+
#elif NK_TARGET_HASWELL
|
|
2714
|
+
nk_dots_symmetric_u4_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2715
|
+
#elif NK_TARGET_V128RELAXED
|
|
2716
|
+
nk_dots_symmetric_u4_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2717
|
+
#else
|
|
2718
|
+
nk_dots_symmetric_u4_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2719
|
+
#endif
|
|
2720
|
+
}
|
|
2721
|
+
|
|
2722
|
+
NK_PUBLIC void nk_dots_symmetric_u1(nk_u1x8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2723
|
+
nk_u32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2724
|
+
nk_size_t row_count) {
|
|
2725
|
+
#if NK_TARGET_SMEBI32
|
|
2726
|
+
nk_dots_symmetric_u1_smebi32(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2727
|
+
#elif NK_TARGET_ICELAKE
|
|
2728
|
+
nk_dots_symmetric_u1_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2729
|
+
#elif NK_TARGET_HASWELL
|
|
2730
|
+
nk_dots_symmetric_u1_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2731
|
+
#elif NK_TARGET_NEON
|
|
2732
|
+
nk_dots_symmetric_u1_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2733
|
+
#elif NK_TARGET_V128RELAXED
|
|
2734
|
+
nk_dots_symmetric_u1_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2735
|
+
#else
|
|
2736
|
+
nk_dots_symmetric_u1_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2737
|
+
#endif
|
|
2738
|
+
}
|
|
2739
|
+
|
|
2740
|
+
NK_PUBLIC void nk_dots_symmetric_i4(nk_i4x2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2741
|
+
nk_i32_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2742
|
+
nk_size_t row_count) {
|
|
2743
|
+
#if NK_TARGET_SME
|
|
2744
|
+
nk_dots_symmetric_i4_sme(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2745
|
+
#elif NK_TARGET_ICELAKE
|
|
2746
|
+
nk_dots_symmetric_i4_icelake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2747
|
+
#elif NK_TARGET_NEONSDOT
|
|
2748
|
+
nk_dots_symmetric_i4_neonsdot(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2749
|
+
#elif NK_TARGET_HASWELL
|
|
2750
|
+
nk_dots_symmetric_i4_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2751
|
+
#elif NK_TARGET_V128RELAXED
|
|
2752
|
+
nk_dots_symmetric_i4_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2753
|
+
#else
|
|
2754
|
+
nk_dots_symmetric_i4_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2755
|
+
#endif
|
|
2756
|
+
}
|
|
2757
|
+
|
|
2758
|
+
NK_PUBLIC void nk_dots_symmetric_f32(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride,
|
|
2759
|
+
nk_f64_t *result, nk_size_t result_stride, nk_size_t row_start,
|
|
2760
|
+
nk_size_t row_count) {
|
|
2761
|
+
#if NK_TARGET_SMEF64
|
|
2762
|
+
nk_dots_symmetric_f32_smef64(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2763
|
+
#elif NK_TARGET_SKYLAKE
|
|
2764
|
+
nk_dots_symmetric_f32_skylake(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2765
|
+
#elif NK_TARGET_HASWELL
|
|
2766
|
+
nk_dots_symmetric_f32_haswell(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2767
|
+
#elif NK_TARGET_NEON
|
|
2768
|
+
nk_dots_symmetric_f32_neon(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2769
|
+
#elif NK_TARGET_RVV
|
|
2770
|
+
nk_dots_symmetric_f32_rvv(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2771
|
+
#elif NK_TARGET_V128RELAXED
|
|
2772
|
+
nk_dots_symmetric_f32_v128relaxed(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2773
|
+
#else
|
|
2774
|
+
nk_dots_symmetric_f32_serial(vectors, n_vectors, depth, stride, result, result_stride, row_start, row_count);
|
|
2775
|
+
#endif
|
|
2776
|
+
}
|
|
2777
|
+
|
|
2778
|
+
/**
 *  @brief  Compile-time dispatcher for `nk_dots_symmetric_f64`.
 *
 *  The backend is chosen by the preprocessor: the first `NK_TARGET_*` macro in the list below that
 *  evaluates to non-zero wins, and all eight arguments are forwarded to that kernel unchanged.
 *  Priority order: NK_TARGET_SMEF64, NK_TARGET_SKYLAKE, NK_TARGET_HASWELL, NK_TARGET_NEON,
 *  NK_TARGET_RVV, NK_TARGET_V128RELAXED, otherwise the serial kernel.
 *
 *  Both the input and the output buffer are typed as `nk_f64_t`. No argument is inspected or
 *  validated here; the meaning of `depth`, `stride`, `result_stride`, `row_start` and `row_count`
 *  is defined by the backend kernels.
 *  NOTE(review): presumably fills a pairwise dot-product matrix over the rows of `vectors` - confirm
 *  against the backend documentation.
 */
NK_PUBLIC void nk_dots_symmetric_f64(nk_f64_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                     nk_size_t stride, nk_f64_t *result, nk_size_t result_stride,
                                     nk_size_t row_start, nk_size_t row_count) {
    // Resolve the backend symbol once, so the forwarding call below is written a single time.
#if NK_TARGET_SMEF64
#define nk_dots_symmetric_f64_selected_ nk_dots_symmetric_f64_smef64
#elif NK_TARGET_SKYLAKE
#define nk_dots_symmetric_f64_selected_ nk_dots_symmetric_f64_skylake
#elif NK_TARGET_HASWELL
#define nk_dots_symmetric_f64_selected_ nk_dots_symmetric_f64_haswell
#elif NK_TARGET_NEON
#define nk_dots_symmetric_f64_selected_ nk_dots_symmetric_f64_neon
#elif NK_TARGET_RVV
#define nk_dots_symmetric_f64_selected_ nk_dots_symmetric_f64_rvv
#elif NK_TARGET_V128RELAXED
#define nk_dots_symmetric_f64_selected_ nk_dots_symmetric_f64_v128relaxed
#else
#define nk_dots_symmetric_f64_selected_ nk_dots_symmetric_f64_serial
#endif
    nk_dots_symmetric_f64_selected_(vectors, n_vectors, depth, stride, result, result_stride, row_start,
                                    row_count);
#undef nk_dots_symmetric_f64_selected_ // keep the alias private to this function
}
|
|
2797
|
+
|
|
2798
|
+
#endif // !NK_DYNAMIC_DISPATCH
|
|
2799
|
+
|
|
2800
|
+
#if defined(__cplusplus)
|
|
2801
|
+
} // extern "C"
|
|
2802
|
+
#endif
|
|
2803
|
+
|
|
2804
|
+
#endif
|