numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,2146 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Elementwise Arithmetic.
|
|
3
|
+
* @file include/numkong/each.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date October 16, 2024
|
|
6
|
+
*
|
|
7
|
+
* Contains following element-wise operations:
|
|
8
|
+
*
|
|
9
|
+
* - Scale (Multiply) with shift: result[i] = alpha * a[i] + beta
|
|
10
|
+
* - Sum (Add): result[i] = a[i] + b[i]
|
|
11
|
+
* - Blend: result[i] = alpha * a[i] + beta * b[i]
|
|
12
|
+
* - FMA (Fused Multiply-Add): result[i] = alpha * a[i] * b[i] + beta * c[i]
|
|
13
|
+
*
|
|
14
|
+
* Beyond their obvious usecases, those can be reused for vector-scalar math and other operations:
|
|
15
|
+
*
|
|
16
|
+
* - Scale with beta = 0 for a pure multiply.
|
|
17
|
+
 * - Sum is equivalent to Blend with alpha = beta = 1.
|
|
18
|
+
 * - Average is Blend with alpha = beta = 0.5.
|
|
19
|
+
* - Elementwise multiply is FMA with beta = 0.
|
|
20
|
+
*
|
|
21
|
+
* For dtypes:
|
|
22
|
+
*
|
|
23
|
+
* - f64: 64-bit IEEE floating point numbers × 64-bit scales
|
|
24
|
+
* - f32: 32-bit IEEE floating point numbers × 32-bit scales
|
|
25
|
+
* - f16: 16-bit IEEE floating point numbers × 32-bit scales
|
|
26
|
+
* - bf16: 16-bit brain floating point numbers × 32-bit scales
|
|
27
|
+
* - e4m3: 8-bit e4m3 floating point numbers × 32-bit scales
|
|
28
|
+
* - e5m2: 8-bit e5m2 floating point numbers × 32-bit scales
|
|
29
|
+
* - e2m3: 8-bit e2m3 floating point numbers (MX) × 32-bit scales
|
|
30
|
+
* - e3m2: 8-bit e3m2 floating point numbers (MX) × 32-bit scales
|
|
31
|
+
* - i8/u8: 8-bit signed and unsigned integers × 32-bit scales
|
|
32
|
+
* - i16/u16: 16-bit signed and unsigned integers × 32-bit scales
|
|
33
|
+
* - i32/u32: 32-bit signed and unsigned integers × 64-bit scales
|
|
34
|
+
* - i64/u64: 64-bit signed and unsigned integers × 64-bit scales
|
|
35
|
+
*
|
|
36
|
+
* For hardware architectures:
|
|
37
|
+
*
|
|
38
|
+
* - Arm: NEON, NEON+F16, NEON+BF16
|
|
39
|
+
* - x86: Haswell, Skylake, Ice Lake, Sapphire Rapids
|
|
40
|
+
* - RISC-V: RVV
|
|
41
|
+
*
|
|
42
|
+
*
|
|
43
|
+
* @section numerical_stability Numerical Stability
|
|
44
|
+
*
|
|
45
|
+
* Integer sum is elementwise a[i]+b[i] clamped to the type's range. Serial widens to
|
|
46
|
+
* i64 then clamps on store. NEON uses hardware saturating adds (SQADD/UQADD).
|
|
47
|
+
* f16/bf16/FP8 sum: promoted to f32, added, truncated back — double rounding possible.
|
|
48
|
+
* Scale/blend/fma: float alpha/beta arithmetic, result rounded to nearest, ties to even, then clamped.
|
|
49
|
+
* f32/f64 operations are native precision with no widening.
|
|
50
|
+
*
|
|
51
|
+
* @section x86_instructions Relevant x86 Instructions
|
|
52
|
+
*
|
|
53
|
+
* FP16 conversions (VCVTPH2PS/VCVTPS2PH) are used for f16 scale/sum/blend/fma operations, converting
|
|
54
|
+
* to f32 for arithmetic then back. The 6-7 cycle latency is amortized over vector-width elements.
|
|
55
|
+
* Saturating integer adds (VPADDSW/VPADDUSW) provide overflow protection for i16/u16 sums without
|
|
56
|
+
* branching. FMA (VFMADD231PS) is the workhorse for scale (alpha*x+beta) and blend (alpha*a+beta*b).
|
|
57
|
+
*
|
|
58
|
+
* Intrinsic Instruction Ice Genoa
|
|
59
|
+
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 7c @ p0+p5 6c @ p12+p23
|
|
60
|
+
* _mm512_cvtps_ph VCVTPS2PH (YMM, ZMM, I8) 7c @ p0+p5 7c @ p12+p23
|
|
61
|
+
* _mm256_adds_epi16 VPADDSW (YMM, YMM, YMM) 1c @ p01 N/A
|
|
62
|
+
* _mm256_adds_epu16 VPADDUSW (YMM, YMM, YMM) 1c @ p01 N/A
|
|
63
|
+
* _mm512_fpclass_ps_mask VFPCLASSPS (K, ZMM, I8) 3c @ p5 5c @ p01
|
|
64
|
+
* _mm256_fmadd_ps VFMADD231PS (YMM, YMM, YMM) 4c @ p01 4c @ p01
|
|
65
|
+
*
|
|
66
|
+
* @section arm_instructions Relevant ARM NEON/SVE Instructions
|
|
67
|
+
*
|
|
68
|
+
* On ARM, i8/u8 elementwise operations convert to f16 intermediates using FCVT to maintain high
|
|
69
|
+
* vector throughput (8 elements per 128-bit register vs 4 for f32). Saturating adds (SQADD/UQADD)
|
|
70
|
+
* handle integer overflow. FMLA provides fused multiply-add for floating-point scale/blend/fma.
|
|
71
|
+
*
|
|
72
|
+
* Intrinsic Instruction M1 Firestorm Graviton 3 Graviton 4
|
|
73
|
+
* vfmaq_f32 FMLA.S (vec) 4c @ V0123 4c @ V0123 4c @ V0123
|
|
74
|
+
* vqaddq_s16 SQADD (vec) 3c @ V0123 2c @ V0123 2c @ V0123
|
|
75
|
+
* vqaddq_u16 UQADD (vec) 3c @ V0123 2c @ V0123 2c @ V0123
|
|
76
|
+
* vcvtq_f32_s32 SCVTF (vec) 3c @ V0123 3c @ V01 3c @ V01
|
|
77
|
+
* vcvtnq_s32_f32 FCVTNS (vec) 3c @ V0123 3c @ V01 3c @ V01
|
|
78
|
+
*
|
|
79
|
+
* @section references References
|
|
80
|
+
*
|
|
81
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
82
|
+
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
83
|
+
*
|
|
84
|
+
*/
|
|
85
|
+
#ifndef NK_EACH_H
|
|
86
|
+
#define NK_EACH_H
|
|
87
|
+
|
|
88
|
+
#include "numkong/types.h"
|
|
89
|
+
|
|
90
|
+
#if defined(__cplusplus)
|
|
91
|
+
extern "C" {
|
|
92
|
+
#endif
|
|
93
|
+
|
|
94
|
+
/**
|
|
95
|
+
* @brief Element-wise scale with shift: result[i] = alpha * a[i] + beta.
|
|
96
|
+
*
|
|
97
|
+
* @param[in] a The input vector.
|
|
98
|
+
* @param[in] n The number of elements in the vector.
|
|
99
|
+
* @param[in] alpha Pointer to the scaling factor (type depends on input precision).
|
|
100
|
+
* @param[in] beta Pointer to the shift (bias) value (type depends on input precision).
|
|
101
|
+
* @param[out] result The output vector.
|
|
102
|
+
*/
|
|
103
|
+
NK_DYNAMIC void nk_each_scale_f64(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
104
|
+
nk_f64_t *result);
|
|
105
|
+
/** @copydoc nk_each_scale_f64 */
|
|
106
|
+
NK_DYNAMIC void nk_each_scale_f32(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
107
|
+
nk_f32_t *result);
|
|
108
|
+
/** @copydoc nk_each_scale_f64 */
|
|
109
|
+
NK_DYNAMIC void nk_each_scale_f16(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
110
|
+
nk_f16_t *result);
|
|
111
|
+
/** @copydoc nk_each_scale_f64 */
|
|
112
|
+
NK_DYNAMIC void nk_each_scale_bf16(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
113
|
+
nk_bf16_t *result);
|
|
114
|
+
/** @copydoc nk_each_scale_f64 */
|
|
115
|
+
NK_DYNAMIC void nk_each_scale_i8(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
116
|
+
nk_i8_t *result);
|
|
117
|
+
/** @copydoc nk_each_scale_f64 */
|
|
118
|
+
NK_DYNAMIC void nk_each_scale_u8(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
119
|
+
nk_u8_t *result);
|
|
120
|
+
/** @copydoc nk_each_scale_f64 */
|
|
121
|
+
NK_DYNAMIC void nk_each_scale_i16(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
122
|
+
nk_i16_t *result);
|
|
123
|
+
/** @copydoc nk_each_scale_f64 */
|
|
124
|
+
NK_DYNAMIC void nk_each_scale_u16(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
125
|
+
nk_u16_t *result);
|
|
126
|
+
/** @copydoc nk_each_scale_f64 */
|
|
127
|
+
NK_DYNAMIC void nk_each_scale_i32(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
128
|
+
nk_i32_t *result);
|
|
129
|
+
/** @copydoc nk_each_scale_f64 */
|
|
130
|
+
NK_DYNAMIC void nk_each_scale_u32(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
131
|
+
nk_u32_t *result);
|
|
132
|
+
/** @copydoc nk_each_scale_f64 */
|
|
133
|
+
NK_DYNAMIC void nk_each_scale_i64(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
134
|
+
nk_i64_t *result);
|
|
135
|
+
/** @copydoc nk_each_scale_f64 */
|
|
136
|
+
NK_DYNAMIC void nk_each_scale_u64(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
137
|
+
nk_u64_t *result);
|
|
138
|
+
|
|
139
|
+
/**
|
|
140
|
+
* @brief Element-wise sum: result[i] = a[i] + b[i].
|
|
141
|
+
*
|
|
142
|
+
* @param[in] a The first input vector.
|
|
143
|
+
* @param[in] b The second input vector.
|
|
144
|
+
* @param[in] n The number of elements in the vectors.
|
|
145
|
+
* @param[out] result The output vector.
|
|
146
|
+
*/
|
|
147
|
+
NK_DYNAMIC void nk_each_sum_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
148
|
+
/** @copydoc nk_each_sum_f64 */
|
|
149
|
+
NK_DYNAMIC void nk_each_sum_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
|
|
150
|
+
/** @copydoc nk_each_sum_f64 */
|
|
151
|
+
NK_DYNAMIC void nk_each_sum_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
|
|
152
|
+
/** @copydoc nk_each_sum_f64 */
|
|
153
|
+
NK_DYNAMIC void nk_each_sum_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
|
|
154
|
+
/** @copydoc nk_each_sum_f64 */
|
|
155
|
+
NK_DYNAMIC void nk_each_sum_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
|
|
156
|
+
/** @copydoc nk_each_sum_f64 */
|
|
157
|
+
NK_DYNAMIC void nk_each_sum_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
|
|
158
|
+
/** @copydoc nk_each_sum_f64 */
|
|
159
|
+
NK_DYNAMIC void nk_each_sum_i16(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
|
|
160
|
+
/** @copydoc nk_each_sum_f64 */
|
|
161
|
+
NK_DYNAMIC void nk_each_sum_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
|
|
162
|
+
/** @copydoc nk_each_sum_f64 */
|
|
163
|
+
NK_DYNAMIC void nk_each_sum_i32(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
|
|
164
|
+
/** @copydoc nk_each_sum_f64 */
|
|
165
|
+
NK_DYNAMIC void nk_each_sum_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
|
|
166
|
+
/** @copydoc nk_each_sum_f64 */
|
|
167
|
+
NK_DYNAMIC void nk_each_sum_i64(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
|
|
168
|
+
/** @copydoc nk_each_sum_f64 */
|
|
169
|
+
NK_DYNAMIC void nk_each_sum_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
|
|
170
|
+
|
|
171
|
+
/**
|
|
172
|
+
* @brief Weighted sum: result[i] = alpha * a[i] + beta * b[i].
|
|
173
|
+
*
|
|
174
|
+
* @param[in] a The first input vector.
|
|
175
|
+
* @param[in] b The second input vector.
|
|
176
|
+
* @param[in] n The number of elements in the vectors.
|
|
177
|
+
* @param[in] alpha Pointer to the first weight (type depends on input precision).
|
|
178
|
+
* @param[in] beta Pointer to the second weight (type depends on input precision).
|
|
179
|
+
* @param[out] result The output vector.
|
|
180
|
+
*/
|
|
181
|
+
NK_DYNAMIC void nk_each_blend_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
182
|
+
nk_f64_t const *beta, nk_f64_t *result);
|
|
183
|
+
/** @copydoc nk_each_blend_f64 */
|
|
184
|
+
NK_DYNAMIC void nk_each_blend_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
185
|
+
nk_f32_t const *beta, nk_f32_t *result);
|
|
186
|
+
/** @copydoc nk_each_blend_f64 */
|
|
187
|
+
NK_DYNAMIC void nk_each_blend_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
188
|
+
nk_f32_t const *beta, nk_f16_t *result);
|
|
189
|
+
/** @copydoc nk_each_blend_f64 */
|
|
190
|
+
NK_DYNAMIC void nk_each_blend_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
191
|
+
nk_f32_t const *beta, nk_bf16_t *result);
|
|
192
|
+
/** @copydoc nk_each_blend_f64 */
|
|
193
|
+
NK_DYNAMIC void nk_each_blend_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
194
|
+
nk_f32_t const *beta, nk_i8_t *result);
|
|
195
|
+
/** @copydoc nk_each_blend_f64 */
|
|
196
|
+
NK_DYNAMIC void nk_each_blend_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
197
|
+
nk_f32_t const *beta, nk_u8_t *result);
|
|
198
|
+
/** @copydoc nk_each_blend_f64 */
|
|
199
|
+
NK_DYNAMIC void nk_each_blend_i16(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
200
|
+
nk_f32_t const *beta, nk_i16_t *result);
|
|
201
|
+
/** @copydoc nk_each_blend_f64 */
|
|
202
|
+
NK_DYNAMIC void nk_each_blend_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
203
|
+
nk_f32_t const *beta, nk_u16_t *result);
|
|
204
|
+
/** @copydoc nk_each_blend_f64 */
|
|
205
|
+
NK_DYNAMIC void nk_each_blend_i32(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
206
|
+
nk_f64_t const *beta, nk_i32_t *result);
|
|
207
|
+
/** @copydoc nk_each_blend_f64 */
|
|
208
|
+
NK_DYNAMIC void nk_each_blend_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
209
|
+
nk_f64_t const *beta, nk_u32_t *result);
|
|
210
|
+
/** @copydoc nk_each_blend_f64 */
|
|
211
|
+
NK_DYNAMIC void nk_each_blend_i64(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
212
|
+
nk_f64_t const *beta, nk_i64_t *result);
|
|
213
|
+
/** @copydoc nk_each_blend_f64 */
|
|
214
|
+
NK_DYNAMIC void nk_each_blend_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
215
|
+
nk_f64_t const *beta, nk_u64_t *result);
|
|
216
|
+
|
|
217
|
+
/**
|
|
218
|
+
* @brief Fused multiply-add: result[i] = alpha * a[i] * b[i] + beta * c[i].
|
|
219
|
+
*
|
|
220
|
+
* @param[in] a The first input vector.
|
|
221
|
+
* @param[in] b The second input vector.
|
|
222
|
+
* @param[in] c The third input vector.
|
|
223
|
+
* @param[in] n The number of elements in the vectors.
|
|
224
|
+
* @param[in] alpha Pointer to the scaling factor for a[i] * b[i] (type depends on input precision).
|
|
225
|
+
* @param[in] beta Pointer to the scaling factor for c[i] (type depends on input precision).
|
|
226
|
+
* @param[out] result The output vector.
|
|
227
|
+
*/
|
|
228
|
+
NK_DYNAMIC void nk_each_fma_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
229
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
|
|
230
|
+
/** @copydoc nk_each_fma_f64 */
|
|
231
|
+
NK_DYNAMIC void nk_each_fma_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
232
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
|
|
233
|
+
/** @copydoc nk_each_fma_f64 */
|
|
234
|
+
NK_DYNAMIC void nk_each_fma_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
235
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
|
|
236
|
+
/** @copydoc nk_each_fma_f64 */
|
|
237
|
+
NK_DYNAMIC void nk_each_fma_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
238
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
|
|
239
|
+
/** @copydoc nk_each_fma_f64 */
|
|
240
|
+
NK_DYNAMIC void nk_each_fma_i8(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, nk_f32_t const *alpha,
|
|
241
|
+
nk_f32_t const *beta, nk_i8_t *result);
|
|
242
|
+
/** @copydoc nk_each_fma_f64 */
|
|
243
|
+
NK_DYNAMIC void nk_each_fma_u8(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, nk_f32_t const *alpha,
|
|
244
|
+
nk_f32_t const *beta, nk_u8_t *result);
|
|
245
|
+
/** @copydoc nk_each_fma_f64 */
|
|
246
|
+
NK_DYNAMIC void nk_each_fma_i16(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
|
|
247
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
|
|
248
|
+
|
|
249
|
+
/** @copydoc nk_each_sum_f64 */
|
|
250
|
+
NK_DYNAMIC void nk_each_sum_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
|
|
251
|
+
/** @copydoc nk_each_sum_f64 */
|
|
252
|
+
NK_DYNAMIC void nk_each_sum_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
|
|
253
|
+
/** @copydoc nk_each_scale_f64 */
|
|
254
|
+
NK_DYNAMIC void nk_each_scale_e4m3(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
255
|
+
nk_e4m3_t *result);
|
|
256
|
+
/** @copydoc nk_each_scale_f64 */
|
|
257
|
+
NK_DYNAMIC void nk_each_scale_e5m2(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
258
|
+
nk_e5m2_t *result);
|
|
259
|
+
/** @copydoc nk_each_blend_f64 */
|
|
260
|
+
NK_DYNAMIC void nk_each_blend_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
261
|
+
nk_f32_t const *beta, nk_e4m3_t *result);
|
|
262
|
+
/** @copydoc nk_each_blend_f64 */
|
|
263
|
+
NK_DYNAMIC void nk_each_blend_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
264
|
+
nk_f32_t const *beta, nk_e5m2_t *result);
|
|
265
|
+
/** @copydoc nk_each_fma_f64 */
|
|
266
|
+
NK_DYNAMIC void nk_each_fma_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
|
|
267
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
|
|
268
|
+
/** @copydoc nk_each_fma_f64 */
|
|
269
|
+
NK_DYNAMIC void nk_each_fma_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
|
|
270
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
|
|
271
|
+
/** @copydoc nk_each_sum_f64 */
|
|
272
|
+
NK_DYNAMIC void nk_each_sum_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_e2m3_t *result);
|
|
273
|
+
/** @copydoc nk_each_sum_f64 */
|
|
274
|
+
NK_DYNAMIC void nk_each_sum_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_e3m2_t *result);
|
|
275
|
+
/** @copydoc nk_each_scale_f64 */
|
|
276
|
+
NK_DYNAMIC void nk_each_scale_e2m3(nk_e2m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
277
|
+
nk_e2m3_t *result);
|
|
278
|
+
/** @copydoc nk_each_scale_f64 */
|
|
279
|
+
NK_DYNAMIC void nk_each_scale_e3m2(nk_e3m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
280
|
+
nk_e3m2_t *result);
|
|
281
|
+
/** @copydoc nk_each_blend_f64 */
|
|
282
|
+
NK_DYNAMIC void nk_each_blend_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
283
|
+
nk_f32_t const *beta, nk_e2m3_t *result);
|
|
284
|
+
/** @copydoc nk_each_blend_f64 */
|
|
285
|
+
NK_DYNAMIC void nk_each_blend_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
286
|
+
nk_f32_t const *beta, nk_e3m2_t *result);
|
|
287
|
+
/** @copydoc nk_each_fma_f64 */
|
|
288
|
+
NK_DYNAMIC void nk_each_fma_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_e2m3_t const *c, nk_size_t n,
|
|
289
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e2m3_t *result);
|
|
290
|
+
/** @copydoc nk_each_fma_f64 */
|
|
291
|
+
NK_DYNAMIC void nk_each_fma_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_e3m2_t const *c, nk_size_t n,
|
|
292
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e3m2_t *result);
|
|
293
|
+
/** @copydoc nk_each_fma_f64 */
|
|
294
|
+
NK_DYNAMIC void nk_each_fma_u16(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
|
|
295
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
|
|
296
|
+
/** @copydoc nk_each_fma_f64 */
|
|
297
|
+
NK_DYNAMIC void nk_each_fma_i32(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
|
|
298
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
|
|
299
|
+
/** @copydoc nk_each_fma_f64 */
|
|
300
|
+
NK_DYNAMIC void nk_each_fma_u32(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
|
|
301
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
|
|
302
|
+
/** @copydoc nk_each_fma_f64 */
|
|
303
|
+
NK_DYNAMIC void nk_each_fma_i64(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
|
|
304
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
|
|
305
|
+
/** @copydoc nk_each_fma_f64 */
|
|
306
|
+
NK_DYNAMIC void nk_each_fma_u64(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
|
|
307
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
|
|
308
|
+
|
|
309
|
+
/** @copydoc nk_each_sum_f64 */
|
|
310
|
+
NK_DYNAMIC void nk_each_sum_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
311
|
+
/** @copydoc nk_each_sum_f64 */
|
|
312
|
+
NK_DYNAMIC void nk_each_sum_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
313
|
+
/** @copydoc nk_each_scale_f64 */
|
|
314
|
+
NK_DYNAMIC void nk_each_scale_f32c(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
|
|
315
|
+
nk_f32c_t *result);
|
|
316
|
+
/** @copydoc nk_each_scale_f64 */
|
|
317
|
+
NK_DYNAMIC void nk_each_scale_f64c(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
|
|
318
|
+
nk_f64c_t *result);
|
|
319
|
+
/** @copydoc nk_each_blend_f64 */
|
|
320
|
+
NK_DYNAMIC void nk_each_blend_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
|
|
321
|
+
nk_f32c_t const *beta, nk_f32c_t *result);
|
|
322
|
+
/** @copydoc nk_each_blend_f64 */
|
|
323
|
+
NK_DYNAMIC void nk_each_blend_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
|
|
324
|
+
nk_f64c_t const *beta, nk_f64c_t *result);
|
|
325
|
+
/** @copydoc nk_each_fma_f64 */
|
|
326
|
+
NK_DYNAMIC void nk_each_fma_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
327
|
+
nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
|
|
328
|
+
/** @copydoc nk_each_fma_f64 */
|
|
329
|
+
NK_DYNAMIC void nk_each_fma_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
|
|
330
|
+
nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
|
|
331
|
+
|
|
332
|
+
/** @copydoc nk_each_scale_f64 */
|
|
333
|
+
NK_PUBLIC void nk_each_scale_f64_serial(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
334
|
+
nk_f64_t *result);
|
|
335
|
+
/** @copydoc nk_each_scale_f64 */
|
|
336
|
+
NK_PUBLIC void nk_each_scale_f32_serial(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
337
|
+
nk_f32_t *result);
|
|
338
|
+
/** @copydoc nk_each_scale_f64 */
|
|
339
|
+
NK_PUBLIC void nk_each_scale_f16_serial(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
340
|
+
nk_f16_t *result);
|
|
341
|
+
/** @copydoc nk_each_scale_f64 */
|
|
342
|
+
NK_PUBLIC void nk_each_scale_bf16_serial(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
343
|
+
nk_bf16_t *result);
|
|
344
|
+
/** @copydoc nk_each_scale_f64 */
|
|
345
|
+
NK_PUBLIC void nk_each_scale_i8_serial(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
346
|
+
nk_i8_t *result);
|
|
347
|
+
/** @copydoc nk_each_scale_f64 */
|
|
348
|
+
NK_PUBLIC void nk_each_scale_u8_serial(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
349
|
+
nk_u8_t *result);
|
|
350
|
+
/** @copydoc nk_each_scale_f64 */
|
|
351
|
+
NK_PUBLIC void nk_each_scale_i16_serial(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
352
|
+
nk_i16_t *result);
|
|
353
|
+
/** @copydoc nk_each_scale_f64 */
|
|
354
|
+
NK_PUBLIC void nk_each_scale_u16_serial(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
355
|
+
nk_u16_t *result);
|
|
356
|
+
/** @copydoc nk_each_scale_f64 */
|
|
357
|
+
NK_PUBLIC void nk_each_scale_i32_serial(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
358
|
+
nk_i32_t *result);
|
|
359
|
+
/** @copydoc nk_each_scale_f64 */
|
|
360
|
+
NK_PUBLIC void nk_each_scale_u32_serial(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
361
|
+
nk_u32_t *result);
|
|
362
|
+
/** @copydoc nk_each_scale_f64 */
|
|
363
|
+
NK_PUBLIC void nk_each_scale_i64_serial(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
364
|
+
nk_i64_t *result);
|
|
365
|
+
/** @copydoc nk_each_scale_f64 */
|
|
366
|
+
NK_PUBLIC void nk_each_scale_u64_serial(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
367
|
+
nk_u64_t *result);
|
|
368
|
+
|
|
369
|
+
/** @copydoc nk_each_sum_f64 */
|
|
370
|
+
NK_PUBLIC void nk_each_sum_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
371
|
+
/** @copydoc nk_each_sum_f64 */
|
|
372
|
+
NK_PUBLIC void nk_each_sum_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
|
|
373
|
+
/** @copydoc nk_each_sum_f64 */
|
|
374
|
+
NK_PUBLIC void nk_each_sum_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
|
|
375
|
+
/** @copydoc nk_each_sum_f64 */
|
|
376
|
+
NK_PUBLIC void nk_each_sum_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
|
|
377
|
+
/** @copydoc nk_each_sum_f64 */
|
|
378
|
+
NK_PUBLIC void nk_each_sum_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
|
|
379
|
+
/** @copydoc nk_each_sum_f64 */
|
|
380
|
+
NK_PUBLIC void nk_each_sum_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
|
|
381
|
+
/** @copydoc nk_each_sum_f64 */
|
|
382
|
+
NK_PUBLIC void nk_each_sum_i16_serial(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
|
|
383
|
+
/** @copydoc nk_each_sum_f64 */
|
|
384
|
+
NK_PUBLIC void nk_each_sum_u16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
|
|
385
|
+
/** @copydoc nk_each_sum_f64 */
|
|
386
|
+
NK_PUBLIC void nk_each_sum_i32_serial(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
|
|
387
|
+
/** @copydoc nk_each_sum_f64 */
|
|
388
|
+
NK_PUBLIC void nk_each_sum_u32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
|
|
389
|
+
/** @copydoc nk_each_sum_f64 */
|
|
390
|
+
NK_PUBLIC void nk_each_sum_i64_serial(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
|
|
391
|
+
/** @copydoc nk_each_sum_f64 */
|
|
392
|
+
NK_PUBLIC void nk_each_sum_u64_serial(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
|
|
393
|
+
|
|
394
|
+
/** @copydoc nk_each_blend_f64 */
|
|
395
|
+
NK_PUBLIC void nk_each_blend_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
396
|
+
nk_f64_t const *beta, nk_f64_t *result);
|
|
397
|
+
/** @copydoc nk_each_blend_f64 */
|
|
398
|
+
NK_PUBLIC void nk_each_blend_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
399
|
+
nk_f32_t const *beta, nk_f32_t *result);
|
|
400
|
+
/** @copydoc nk_each_blend_f64 */
|
|
401
|
+
NK_PUBLIC void nk_each_blend_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
402
|
+
nk_f32_t const *beta, nk_f16_t *result);
|
|
403
|
+
/** @copydoc nk_each_blend_f64 */
|
|
404
|
+
NK_PUBLIC void nk_each_blend_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
405
|
+
nk_f32_t const *beta, nk_bf16_t *result);
|
|
406
|
+
/** @copydoc nk_each_blend_f64 */
|
|
407
|
+
NK_PUBLIC void nk_each_blend_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
408
|
+
nk_f32_t const *beta, nk_i8_t *result);
|
|
409
|
+
/** @copydoc nk_each_blend_f64 */
|
|
410
|
+
NK_PUBLIC void nk_each_blend_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
411
|
+
nk_f32_t const *beta, nk_u8_t *result);
|
|
412
|
+
/** @copydoc nk_each_blend_f64 */
|
|
413
|
+
NK_PUBLIC void nk_each_blend_i16_serial(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
414
|
+
nk_f32_t const *beta, nk_i16_t *result);
|
|
415
|
+
/** @copydoc nk_each_blend_f64 */
|
|
416
|
+
NK_PUBLIC void nk_each_blend_u16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
417
|
+
nk_f32_t const *beta, nk_u16_t *result);
|
|
418
|
+
/** @copydoc nk_each_blend_f64 */
|
|
419
|
+
NK_PUBLIC void nk_each_blend_i32_serial(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
420
|
+
nk_f64_t const *beta, nk_i32_t *result);
|
|
421
|
+
/** @copydoc nk_each_blend_f64 */
|
|
422
|
+
NK_PUBLIC void nk_each_blend_u32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
423
|
+
nk_f64_t const *beta, nk_u32_t *result);
|
|
424
|
+
/** @copydoc nk_each_blend_f64 */
|
|
425
|
+
NK_PUBLIC void nk_each_blend_i64_serial(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
426
|
+
nk_f64_t const *beta, nk_i64_t *result);
|
|
427
|
+
/** @copydoc nk_each_blend_f64 */
|
|
428
|
+
NK_PUBLIC void nk_each_blend_u64_serial(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
429
|
+
nk_f64_t const *beta, nk_u64_t *result);
|
|
430
|
+
|
|
431
|
+
/** @copydoc nk_each_fma_f64 */
|
|
432
|
+
NK_PUBLIC void nk_each_fma_f64_serial(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
433
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
|
|
434
|
+
/** @copydoc nk_each_fma_f64 */
|
|
435
|
+
NK_PUBLIC void nk_each_fma_f32_serial(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
436
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
|
|
437
|
+
/** @copydoc nk_each_fma_f64 */
|
|
438
|
+
NK_PUBLIC void nk_each_fma_f16_serial(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
439
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
|
|
440
|
+
/** @copydoc nk_each_fma_f64 */
|
|
441
|
+
NK_PUBLIC void nk_each_fma_bf16_serial(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
442
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
|
|
443
|
+
/** @copydoc nk_each_fma_f64 */
|
|
444
|
+
NK_PUBLIC void nk_each_fma_i8_serial(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
|
|
445
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
|
|
446
|
+
/** @copydoc nk_each_fma_f64 */
|
|
447
|
+
NK_PUBLIC void nk_each_fma_u8_serial(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
|
|
448
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
|
|
449
|
+
/** @copydoc nk_each_fma_f64 */
|
|
450
|
+
NK_PUBLIC void nk_each_fma_i16_serial(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
|
|
451
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
|
|
452
|
+
/** @copydoc nk_each_fma_f64 */
|
|
453
|
+
NK_PUBLIC void nk_each_fma_u16_serial(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
|
|
454
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
|
|
455
|
+
/** @copydoc nk_each_fma_f64 */
|
|
456
|
+
NK_PUBLIC void nk_each_fma_i32_serial(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
|
|
457
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
|
|
458
|
+
/** @copydoc nk_each_fma_f64 */
|
|
459
|
+
NK_PUBLIC void nk_each_fma_u32_serial(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
|
|
460
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
|
|
461
|
+
/** @copydoc nk_each_fma_f64 */
|
|
462
|
+
NK_PUBLIC void nk_each_fma_i64_serial(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
|
|
463
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
|
|
464
|
+
/** @copydoc nk_each_fma_f64 */
|
|
465
|
+
NK_PUBLIC void nk_each_fma_u64_serial(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
|
|
466
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
|
|
467
|
+
|
|
468
|
+
/** @copydoc nk_each_sum_e4m3 */
|
|
469
|
+
NK_PUBLIC void nk_each_sum_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
|
|
470
|
+
/** @copydoc nk_each_sum_e5m2 */
|
|
471
|
+
NK_PUBLIC void nk_each_sum_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
|
|
472
|
+
/** @copydoc nk_each_scale_e4m3 */
|
|
473
|
+
NK_PUBLIC void nk_each_scale_e4m3_serial(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
474
|
+
nk_e4m3_t *result);
|
|
475
|
+
/** @copydoc nk_each_scale_e5m2 */
|
|
476
|
+
NK_PUBLIC void nk_each_scale_e5m2_serial(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
477
|
+
nk_e5m2_t *result);
|
|
478
|
+
/** @copydoc nk_each_blend_e4m3 */
|
|
479
|
+
NK_PUBLIC void nk_each_blend_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
480
|
+
nk_f32_t const *beta, nk_e4m3_t *result);
|
|
481
|
+
/** @copydoc nk_each_blend_e5m2 */
|
|
482
|
+
NK_PUBLIC void nk_each_blend_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
483
|
+
nk_f32_t const *beta, nk_e5m2_t *result);
|
|
484
|
+
/** @copydoc nk_each_fma_e4m3 */
|
|
485
|
+
NK_PUBLIC void nk_each_fma_e4m3_serial(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
|
|
486
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
|
|
487
|
+
/** @copydoc nk_each_fma_e5m2 */
|
|
488
|
+
NK_PUBLIC void nk_each_fma_e5m2_serial(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
|
|
489
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
|
|
490
|
+
|
|
491
|
+
/** @copydoc nk_each_sum_e2m3 */
|
|
492
|
+
NK_PUBLIC void nk_each_sum_e2m3_serial(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_e2m3_t *result);
|
|
493
|
+
/** @copydoc nk_each_sum_e3m2 */
|
|
494
|
+
NK_PUBLIC void nk_each_sum_e3m2_serial(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_e3m2_t *result);
|
|
495
|
+
/** @copydoc nk_each_scale_e2m3 */
|
|
496
|
+
NK_PUBLIC void nk_each_scale_e2m3_serial(nk_e2m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
497
|
+
nk_e2m3_t *result);
|
|
498
|
+
/** @copydoc nk_each_scale_e3m2 */
|
|
499
|
+
NK_PUBLIC void nk_each_scale_e3m2_serial(nk_e3m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
500
|
+
nk_e3m2_t *result);
|
|
501
|
+
/** @copydoc nk_each_blend_e2m3 */
|
|
502
|
+
NK_PUBLIC void nk_each_blend_e2m3_serial(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
503
|
+
nk_f32_t const *beta, nk_e2m3_t *result);
|
|
504
|
+
/** @copydoc nk_each_blend_e3m2 */
|
|
505
|
+
NK_PUBLIC void nk_each_blend_e3m2_serial(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
506
|
+
nk_f32_t const *beta, nk_e3m2_t *result);
|
|
507
|
+
/** @copydoc nk_each_fma_e2m3 */
|
|
508
|
+
NK_PUBLIC void nk_each_fma_e2m3_serial(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_e2m3_t const *c, nk_size_t n,
|
|
509
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e2m3_t *result);
|
|
510
|
+
/** @copydoc nk_each_fma_e3m2 */
|
|
511
|
+
NK_PUBLIC void nk_each_fma_e3m2_serial(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_e3m2_t const *c, nk_size_t n,
|
|
512
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e3m2_t *result);
|
|
513
|
+
|
|
514
|
+
/** @copydoc nk_each_sum_f64 */
|
|
515
|
+
NK_PUBLIC void nk_each_sum_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t *result);
|
|
516
|
+
/** @copydoc nk_each_sum_f64 */
|
|
517
|
+
NK_PUBLIC void nk_each_sum_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *result);
|
|
518
|
+
/** @copydoc nk_each_scale_f64 */
|
|
519
|
+
NK_PUBLIC void nk_each_scale_f32c_serial(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
|
|
520
|
+
nk_f32c_t *result);
|
|
521
|
+
/** @copydoc nk_each_scale_f64 */
|
|
522
|
+
NK_PUBLIC void nk_each_scale_f64c_serial(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
|
|
523
|
+
nk_f64c_t *result);
|
|
524
|
+
/** @copydoc nk_each_blend_f64 */
|
|
525
|
+
NK_PUBLIC void nk_each_blend_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
|
|
526
|
+
nk_f32c_t const *beta, nk_f32c_t *result);
|
|
527
|
+
/** @copydoc nk_each_blend_f64 */
|
|
528
|
+
NK_PUBLIC void nk_each_blend_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
|
|
529
|
+
nk_f64c_t const *beta, nk_f64c_t *result);
|
|
530
|
+
/** @copydoc nk_each_fma_f64 */
|
|
531
|
+
NK_PUBLIC void nk_each_fma_f32c_serial(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
532
|
+
nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
|
|
533
|
+
/** @copydoc nk_each_fma_f64 */
|
|
534
|
+
NK_PUBLIC void nk_each_fma_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
|
|
535
|
+
nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
|
|
536
|
+
|
|
537
|
+
#if NK_TARGET_NEON
|
|
538
|
+
/** @copydoc nk_each_scale_f32 */
|
|
539
|
+
NK_PUBLIC void nk_each_scale_f32_neon(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
540
|
+
nk_f32_t *result);
|
|
541
|
+
/** @copydoc nk_each_scale_i16 */
|
|
542
|
+
NK_PUBLIC void nk_each_scale_i16_neon(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
543
|
+
nk_i16_t *result);
|
|
544
|
+
/** @copydoc nk_each_scale_u16 */
|
|
545
|
+
NK_PUBLIC void nk_each_scale_u16_neon(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
546
|
+
nk_u16_t *result);
|
|
547
|
+
/** @copydoc nk_each_scale_i32 */
|
|
548
|
+
NK_PUBLIC void nk_each_scale_i32_neon(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
549
|
+
nk_i32_t *result);
|
|
550
|
+
/** @copydoc nk_each_scale_u32 */
|
|
551
|
+
NK_PUBLIC void nk_each_scale_u32_neon(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
552
|
+
nk_u32_t *result);
|
|
553
|
+
/** @copydoc nk_each_scale_i64 */
|
|
554
|
+
NK_PUBLIC void nk_each_scale_i64_neon(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
555
|
+
nk_i64_t *result);
|
|
556
|
+
/** @copydoc nk_each_scale_u64 */
|
|
557
|
+
NK_PUBLIC void nk_each_scale_u64_neon(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
558
|
+
nk_u64_t *result);
|
|
559
|
+
|
|
560
|
+
/** @copydoc nk_each_sum_f32 */
|
|
561
|
+
NK_PUBLIC void nk_each_sum_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
|
|
562
|
+
/** @copydoc nk_each_sum_i16 */
|
|
563
|
+
NK_PUBLIC void nk_each_sum_i16_neon(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
|
|
564
|
+
/** @copydoc nk_each_sum_u16 */
|
|
565
|
+
NK_PUBLIC void nk_each_sum_u16_neon(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
|
|
566
|
+
/** @copydoc nk_each_sum_i32 */
|
|
567
|
+
NK_PUBLIC void nk_each_sum_i32_neon(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
|
|
568
|
+
/** @copydoc nk_each_sum_u32 */
|
|
569
|
+
NK_PUBLIC void nk_each_sum_u32_neon(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
|
|
570
|
+
/** @copydoc nk_each_sum_i64 */
|
|
571
|
+
NK_PUBLIC void nk_each_sum_i64_neon(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
|
|
572
|
+
/** @copydoc nk_each_sum_u64 */
|
|
573
|
+
NK_PUBLIC void nk_each_sum_u64_neon(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
|
|
574
|
+
|
|
575
|
+
/** @copydoc nk_each_blend_f32 */
|
|
576
|
+
NK_PUBLIC void nk_each_blend_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
577
|
+
nk_f32_t const *beta, nk_f32_t *result);
|
|
578
|
+
|
|
579
|
+
/** @copydoc nk_each_fma_f32 */
|
|
580
|
+
NK_PUBLIC void nk_each_fma_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
581
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
|
|
582
|
+
/** @copydoc nk_each_fma_i16 */
|
|
583
|
+
NK_PUBLIC void nk_each_fma_i16_neon(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
|
|
584
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
|
|
585
|
+
/** @copydoc nk_each_fma_u16 */
|
|
586
|
+
NK_PUBLIC void nk_each_fma_u16_neon(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
|
|
587
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
|
|
588
|
+
/** @copydoc nk_each_fma_i32 */
|
|
589
|
+
NK_PUBLIC void nk_each_fma_i32_neon(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
|
|
590
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
|
|
591
|
+
/** @copydoc nk_each_fma_u32 */
|
|
592
|
+
NK_PUBLIC void nk_each_fma_u32_neon(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
|
|
593
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
|
|
594
|
+
/** @copydoc nk_each_fma_i64 */
|
|
595
|
+
NK_PUBLIC void nk_each_fma_i64_neon(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
|
|
596
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
|
|
597
|
+
/** @copydoc nk_each_fma_u64 */
|
|
598
|
+
NK_PUBLIC void nk_each_fma_u64_neon(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
|
|
599
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
|
|
600
|
+
|
|
601
|
+
/** @copydoc nk_each_sum_f64 */
|
|
602
|
+
NK_PUBLIC void nk_each_sum_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
603
|
+
/** @copydoc nk_each_scale_f64 */
|
|
604
|
+
NK_PUBLIC void nk_each_scale_f64_neon(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
605
|
+
nk_f64_t *result);
|
|
606
|
+
/** @copydoc nk_each_blend_f64 */
|
|
607
|
+
NK_PUBLIC void nk_each_blend_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
608
|
+
nk_f64_t const *beta, nk_f64_t *result);
|
|
609
|
+
/** @copydoc nk_each_fma_f64 */
|
|
610
|
+
NK_PUBLIC void nk_each_fma_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
611
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
|
|
612
|
+
|
|
613
|
+
/** @copydoc nk_each_sum_e4m3 */
|
|
614
|
+
NK_PUBLIC void nk_each_sum_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
|
|
615
|
+
/** @copydoc nk_each_sum_e5m2 */
|
|
616
|
+
NK_PUBLIC void nk_each_sum_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
|
|
617
|
+
/** @copydoc nk_each_scale_e4m3 */
|
|
618
|
+
NK_PUBLIC void nk_each_scale_e4m3_neon(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
619
|
+
nk_e4m3_t *result);
|
|
620
|
+
/** @copydoc nk_each_scale_e5m2 */
|
|
621
|
+
NK_PUBLIC void nk_each_scale_e5m2_neon(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
622
|
+
nk_e5m2_t *result);
|
|
623
|
+
/** @copydoc nk_each_blend_e4m3 */
|
|
624
|
+
NK_PUBLIC void nk_each_blend_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
625
|
+
nk_f32_t const *beta, nk_e4m3_t *result);
|
|
626
|
+
/** @copydoc nk_each_blend_e5m2 */
|
|
627
|
+
NK_PUBLIC void nk_each_blend_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
628
|
+
nk_f32_t const *beta, nk_e5m2_t *result);
|
|
629
|
+
/** @copydoc nk_each_fma_e4m3 */
|
|
630
|
+
NK_PUBLIC void nk_each_fma_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
|
|
631
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
|
|
632
|
+
/** @copydoc nk_each_fma_e5m2 */
|
|
633
|
+
NK_PUBLIC void nk_each_fma_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
|
|
634
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
|
|
635
|
+
|
|
636
|
+
/** @copydoc nk_each_scale_f64 */
|
|
637
|
+
NK_PUBLIC void nk_each_scale_f32c_neon(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
|
|
638
|
+
nk_f32c_t *result);
|
|
639
|
+
/** @copydoc nk_each_scale_f64 */
|
|
640
|
+
NK_PUBLIC void nk_each_scale_f64c_neon(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
|
|
641
|
+
nk_f64c_t *result);
|
|
642
|
+
/** @copydoc nk_each_blend_f64 */
|
|
643
|
+
NK_PUBLIC void nk_each_blend_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
|
|
644
|
+
nk_f32c_t const *beta, nk_f32c_t *result);
|
|
645
|
+
/** @copydoc nk_each_blend_f64 */
|
|
646
|
+
NK_PUBLIC void nk_each_blend_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
|
|
647
|
+
nk_f64c_t const *beta, nk_f64c_t *result);
|
|
648
|
+
/** @copydoc nk_each_fma_f64 */
|
|
649
|
+
NK_PUBLIC void nk_each_fma_f32c_neon(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
650
|
+
nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
|
|
651
|
+
/** @copydoc nk_each_fma_f64 */
|
|
652
|
+
NK_PUBLIC void nk_each_fma_f64c_neon(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
|
|
653
|
+
nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
|
|
654
|
+
#endif // NK_TARGET_NEON
|
|
655
|
+
|
|
656
|
+
#if NK_TARGET_NEONBFDOT
|
|
657
|
+
/** @copydoc nk_each_sum_bf16 */
|
|
658
|
+
NK_PUBLIC void nk_each_sum_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
|
|
659
|
+
/** @copydoc nk_each_scale_bf16 */
|
|
660
|
+
NK_PUBLIC void nk_each_scale_bf16_neonbfdot(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha,
|
|
661
|
+
nk_f32_t const *beta, nk_bf16_t *result);
|
|
662
|
+
/** @copydoc nk_each_blend_bf16 */
|
|
663
|
+
NK_PUBLIC void nk_each_blend_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
664
|
+
nk_f32_t const *beta, nk_bf16_t *result);
|
|
665
|
+
/** @copydoc nk_each_fma_bf16 */
|
|
666
|
+
NK_PUBLIC void nk_each_fma_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
667
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
|
|
668
|
+
#endif // NK_TARGET_NEONBFDOT
|
|
669
|
+
|
|
670
|
+
#if NK_TARGET_NEONHALF
|
|
671
|
+
/** @copydoc nk_each_sum_f16 */
|
|
672
|
+
NK_PUBLIC void nk_each_sum_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
|
|
673
|
+
/** @copydoc nk_each_scale_f16 */
|
|
674
|
+
NK_PUBLIC void nk_each_scale_f16_neonhalf(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
675
|
+
nk_f16_t *result);
|
|
676
|
+
/** @copydoc nk_each_blend_f16 */
|
|
677
|
+
NK_PUBLIC void nk_each_blend_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
678
|
+
nk_f32_t const *beta, nk_f16_t *result);
|
|
679
|
+
/** @copydoc nk_each_fma_f16 */
|
|
680
|
+
NK_PUBLIC void nk_each_fma_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
681
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
|
|
682
|
+
|
|
683
|
+
/** @copydoc nk_each_sum_i8 */
|
|
684
|
+
NK_PUBLIC void nk_each_sum_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
|
|
685
|
+
/** @copydoc nk_each_sum_u8 */
|
|
686
|
+
NK_PUBLIC void nk_each_sum_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
|
|
687
|
+
/** @copydoc nk_each_scale_i8 */
|
|
688
|
+
NK_PUBLIC void nk_each_scale_i8_neonhalf(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
689
|
+
nk_i8_t *result);
|
|
690
|
+
/** @copydoc nk_each_scale_u8 */
|
|
691
|
+
NK_PUBLIC void nk_each_scale_u8_neonhalf(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
692
|
+
nk_u8_t *result);
|
|
693
|
+
/** @copydoc nk_each_blend_i8 */
|
|
694
|
+
NK_PUBLIC void nk_each_blend_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
695
|
+
nk_f32_t const *beta, nk_i8_t *result);
|
|
696
|
+
/** @copydoc nk_each_blend_u8 */
|
|
697
|
+
NK_PUBLIC void nk_each_blend_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
698
|
+
nk_f32_t const *beta, nk_u8_t *result);
|
|
699
|
+
/** @copydoc nk_each_fma_i8 */
|
|
700
|
+
NK_PUBLIC void nk_each_fma_i8_neonhalf(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
|
|
701
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
|
|
702
|
+
/** @copydoc nk_each_fma_u8 */
|
|
703
|
+
NK_PUBLIC void nk_each_fma_u8_neonhalf(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
|
|
704
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
|
|
705
|
+
#endif // NK_TARGET_NEONHALF
|
|
706
|
+
|
|
707
|
+
#if NK_TARGET_HASWELL
|
|
708
|
+
/** @copydoc nk_each_scale_f64 */
|
|
709
|
+
NK_PUBLIC void nk_each_scale_f64_haswell(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
710
|
+
nk_f64_t *result);
|
|
711
|
+
/** @copydoc nk_each_scale_f32 */
|
|
712
|
+
NK_PUBLIC void nk_each_scale_f32_haswell(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
713
|
+
nk_f32_t *result);
|
|
714
|
+
/** @copydoc nk_each_scale_f16 */
|
|
715
|
+
NK_PUBLIC void nk_each_scale_f16_haswell(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
716
|
+
nk_f16_t *result);
|
|
717
|
+
/** @copydoc nk_each_scale_bf16 */
|
|
718
|
+
NK_PUBLIC void nk_each_scale_bf16_haswell(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
719
|
+
nk_bf16_t *result);
|
|
720
|
+
/** @copydoc nk_each_scale_i8 */
|
|
721
|
+
NK_PUBLIC void nk_each_scale_i8_haswell(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
722
|
+
nk_i8_t *result);
|
|
723
|
+
/** @copydoc nk_each_scale_u8 */
|
|
724
|
+
NK_PUBLIC void nk_each_scale_u8_haswell(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
725
|
+
nk_u8_t *result);
|
|
726
|
+
/** @copydoc nk_each_scale_i16 */
|
|
727
|
+
NK_PUBLIC void nk_each_scale_i16_haswell(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
728
|
+
nk_i16_t *result);
|
|
729
|
+
/** @copydoc nk_each_scale_u16 */
|
|
730
|
+
NK_PUBLIC void nk_each_scale_u16_haswell(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
731
|
+
nk_u16_t *result);
|
|
732
|
+
/** @copydoc nk_each_scale_i32 */
|
|
733
|
+
NK_PUBLIC void nk_each_scale_i32_haswell(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
734
|
+
nk_i32_t *result);
|
|
735
|
+
/** @copydoc nk_each_scale_u32 */
|
|
736
|
+
NK_PUBLIC void nk_each_scale_u32_haswell(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
737
|
+
nk_u32_t *result);
|
|
738
|
+
|
|
739
|
+
/** @copydoc nk_each_sum_f64 */
|
|
740
|
+
NK_PUBLIC void nk_each_sum_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
741
|
+
/** @copydoc nk_each_sum_f32 */
|
|
742
|
+
NK_PUBLIC void nk_each_sum_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
|
|
743
|
+
/** @copydoc nk_each_sum_f16 */
|
|
744
|
+
NK_PUBLIC void nk_each_sum_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
|
|
745
|
+
/** @copydoc nk_each_sum_bf16 */
|
|
746
|
+
NK_PUBLIC void nk_each_sum_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
|
|
747
|
+
/** @copydoc nk_each_sum_i8 */
|
|
748
|
+
NK_PUBLIC void nk_each_sum_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
|
|
749
|
+
/** @copydoc nk_each_sum_u8 */
|
|
750
|
+
NK_PUBLIC void nk_each_sum_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
|
|
751
|
+
/** @copydoc nk_each_sum_i16 */
|
|
752
|
+
NK_PUBLIC void nk_each_sum_i16_haswell(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
|
|
753
|
+
/** @copydoc nk_each_sum_u16 */
|
|
754
|
+
NK_PUBLIC void nk_each_sum_u16_haswell(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
|
|
755
|
+
/** @copydoc nk_each_sum_i32 */
|
|
756
|
+
NK_PUBLIC void nk_each_sum_i32_haswell(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
|
|
757
|
+
/** @copydoc nk_each_sum_u32 */
|
|
758
|
+
NK_PUBLIC void nk_each_sum_u32_haswell(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
|
|
759
|
+
|
|
760
|
+
/** @copydoc nk_each_blend_f64 */
|
|
761
|
+
NK_PUBLIC void nk_each_blend_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
762
|
+
nk_f64_t const *beta, nk_f64_t *result);
|
|
763
|
+
/** @copydoc nk_each_blend_f32 */
|
|
764
|
+
NK_PUBLIC void nk_each_blend_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
765
|
+
nk_f32_t const *beta, nk_f32_t *result);
|
|
766
|
+
/** @copydoc nk_each_blend_f16 */
|
|
767
|
+
NK_PUBLIC void nk_each_blend_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
768
|
+
nk_f32_t const *beta, nk_f16_t *result);
|
|
769
|
+
/** @copydoc nk_each_blend_bf16 */
|
|
770
|
+
NK_PUBLIC void nk_each_blend_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
771
|
+
nk_f32_t const *beta, nk_bf16_t *result);
|
|
772
|
+
/** @copydoc nk_each_blend_i8 */
|
|
773
|
+
NK_PUBLIC void nk_each_blend_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
774
|
+
nk_f32_t const *beta, nk_i8_t *result);
|
|
775
|
+
/** @copydoc nk_each_blend_u8 */
|
|
776
|
+
NK_PUBLIC void nk_each_blend_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
777
|
+
nk_f32_t const *beta, nk_u8_t *result);
|
|
778
|
+
|
|
779
|
+
/** @copydoc nk_each_fma_f64 */
|
|
780
|
+
NK_PUBLIC void nk_each_fma_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
781
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
|
|
782
|
+
/** @copydoc nk_each_fma_f32 */
|
|
783
|
+
NK_PUBLIC void nk_each_fma_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
784
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
|
|
785
|
+
/** @copydoc nk_each_fma_f16 */
|
|
786
|
+
NK_PUBLIC void nk_each_fma_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
787
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
|
|
788
|
+
/** @copydoc nk_each_fma_bf16 */
|
|
789
|
+
NK_PUBLIC void nk_each_fma_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
790
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
|
|
791
|
+
/** @copydoc nk_each_fma_i8 */
|
|
792
|
+
NK_PUBLIC void nk_each_fma_i8_haswell(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
|
|
793
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
|
|
794
|
+
/** @copydoc nk_each_fma_u8 */
|
|
795
|
+
NK_PUBLIC void nk_each_fma_u8_haswell(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
|
|
796
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
|
|
797
|
+
/** @copydoc nk_each_fma_i16 */
|
|
798
|
+
NK_PUBLIC void nk_each_fma_i16_haswell(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
|
|
799
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
|
|
800
|
+
/** @copydoc nk_each_fma_u16 */
|
|
801
|
+
NK_PUBLIC void nk_each_fma_u16_haswell(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
|
|
802
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
|
|
803
|
+
/** @copydoc nk_each_fma_i32 */
|
|
804
|
+
NK_PUBLIC void nk_each_fma_i32_haswell(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
|
|
805
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
|
|
806
|
+
/** @copydoc nk_each_fma_u32 */
|
|
807
|
+
NK_PUBLIC void nk_each_fma_u32_haswell(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
|
|
808
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
|
|
809
|
+
|
|
810
|
+
/** @copydoc nk_each_sum_e4m3 */
|
|
811
|
+
NK_PUBLIC void nk_each_sum_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
|
|
812
|
+
/** @copydoc nk_each_sum_e5m2 */
|
|
813
|
+
NK_PUBLIC void nk_each_sum_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
|
|
814
|
+
/** @copydoc nk_each_scale_e4m3 */
|
|
815
|
+
NK_PUBLIC void nk_each_scale_e4m3_haswell(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
816
|
+
nk_e4m3_t *result);
|
|
817
|
+
/** @copydoc nk_each_scale_e5m2 */
|
|
818
|
+
NK_PUBLIC void nk_each_scale_e5m2_haswell(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
819
|
+
nk_e5m2_t *result);
|
|
820
|
+
/** @copydoc nk_each_blend_e4m3 */
|
|
821
|
+
NK_PUBLIC void nk_each_blend_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
822
|
+
nk_f32_t const *beta, nk_e4m3_t *result);
|
|
823
|
+
/** @copydoc nk_each_blend_e5m2 */
|
|
824
|
+
NK_PUBLIC void nk_each_blend_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
825
|
+
nk_f32_t const *beta, nk_e5m2_t *result);
|
|
826
|
+
/** @copydoc nk_each_fma_e4m3 */
|
|
827
|
+
NK_PUBLIC void nk_each_fma_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
|
|
828
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
|
|
829
|
+
/** @copydoc nk_each_fma_e5m2 */
|
|
830
|
+
NK_PUBLIC void nk_each_fma_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
|
|
831
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
|
|
832
|
+
|
|
833
|
+
/** @copydoc nk_each_scale_f32c */
|
|
834
|
+
NK_PUBLIC void nk_each_scale_f32c_haswell(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha,
|
|
835
|
+
nk_f32c_t const *beta, nk_f32c_t *result);
|
|
836
|
+
/** @copydoc nk_each_scale_f64c */
|
|
837
|
+
NK_PUBLIC void nk_each_scale_f64c_haswell(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha,
|
|
838
|
+
nk_f64c_t const *beta, nk_f64c_t *result);
|
|
839
|
+
/** @copydoc nk_each_blend_f32c */
|
|
840
|
+
NK_PUBLIC void nk_each_blend_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
|
|
841
|
+
nk_f32c_t const *beta, nk_f32c_t *result);
|
|
842
|
+
/** @copydoc nk_each_blend_f64c */
|
|
843
|
+
NK_PUBLIC void nk_each_blend_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
|
|
844
|
+
nk_f64c_t const *beta, nk_f64c_t *result);
|
|
845
|
+
/** @copydoc nk_each_fma_f32c */
|
|
846
|
+
NK_PUBLIC void nk_each_fma_f32c_haswell(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
847
|
+
nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
|
|
848
|
+
/** @copydoc nk_each_fma_f64c */
|
|
849
|
+
NK_PUBLIC void nk_each_fma_f64c_haswell(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
|
|
850
|
+
nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
|
|
851
|
+
#endif // NK_TARGET_HASWELL
|
|
852
|
+
|
|
853
|
+
#if NK_TARGET_SKYLAKE
|
|
854
|
+
/** @copydoc nk_each_scale_f64 */
|
|
855
|
+
NK_PUBLIC void nk_each_scale_f64_skylake(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
856
|
+
nk_f64_t *result);
|
|
857
|
+
/** @copydoc nk_each_scale_f32 */
|
|
858
|
+
NK_PUBLIC void nk_each_scale_f32_skylake(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
859
|
+
nk_f32_t *result);
|
|
860
|
+
/** @copydoc nk_each_scale_f16 */
|
|
861
|
+
NK_PUBLIC void nk_each_scale_f16_skylake(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
862
|
+
nk_f16_t *result);
|
|
863
|
+
/** @copydoc nk_each_scale_bf16 */
|
|
864
|
+
NK_PUBLIC void nk_each_scale_bf16_skylake(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
865
|
+
nk_bf16_t *result);
|
|
866
|
+
/** @copydoc nk_each_scale_i8 */
|
|
867
|
+
NK_PUBLIC void nk_each_scale_i8_skylake(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
868
|
+
nk_i8_t *result);
|
|
869
|
+
/** @copydoc nk_each_scale_u8 */
|
|
870
|
+
NK_PUBLIC void nk_each_scale_u8_skylake(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
871
|
+
nk_u8_t *result);
|
|
872
|
+
/** @copydoc nk_each_scale_i16 */
|
|
873
|
+
NK_PUBLIC void nk_each_scale_i16_skylake(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
874
|
+
nk_i16_t *result);
|
|
875
|
+
/** @copydoc nk_each_scale_u16 */
|
|
876
|
+
NK_PUBLIC void nk_each_scale_u16_skylake(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
877
|
+
nk_u16_t *result);
|
|
878
|
+
/** @copydoc nk_each_scale_i32 */
|
|
879
|
+
NK_PUBLIC void nk_each_scale_i32_skylake(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
880
|
+
nk_i32_t *result);
|
|
881
|
+
/** @copydoc nk_each_scale_u32 */
|
|
882
|
+
NK_PUBLIC void nk_each_scale_u32_skylake(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
883
|
+
nk_u32_t *result);
|
|
884
|
+
/** @copydoc nk_each_scale_i64 */
|
|
885
|
+
NK_PUBLIC void nk_each_scale_i64_skylake(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
886
|
+
nk_i64_t *result);
|
|
887
|
+
/** @copydoc nk_each_scale_u64 */
|
|
888
|
+
NK_PUBLIC void nk_each_scale_u64_skylake(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
889
|
+
nk_u64_t *result);
|
|
890
|
+
|
|
891
|
+
/** @copydoc nk_each_sum_f64 */
|
|
892
|
+
NK_PUBLIC void nk_each_sum_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
893
|
+
/** @copydoc nk_each_sum_f32 */
|
|
894
|
+
NK_PUBLIC void nk_each_sum_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
|
|
895
|
+
/** @copydoc nk_each_sum_bf16 */
|
|
896
|
+
NK_PUBLIC void nk_each_sum_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
|
|
897
|
+
|
|
898
|
+
/** @copydoc nk_each_blend_f64 */
|
|
899
|
+
NK_PUBLIC void nk_each_blend_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
900
|
+
nk_f64_t const *beta, nk_f64_t *result);
|
|
901
|
+
/** @copydoc nk_each_blend_f32 */
|
|
902
|
+
NK_PUBLIC void nk_each_blend_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
903
|
+
nk_f32_t const *beta, nk_f32_t *result);
|
|
904
|
+
/** @copydoc nk_each_blend_f16 */
|
|
905
|
+
NK_PUBLIC void nk_each_blend_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
906
|
+
nk_f32_t const *beta, nk_f16_t *result);
|
|
907
|
+
/** @copydoc nk_each_blend_bf16 */
|
|
908
|
+
NK_PUBLIC void nk_each_blend_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
909
|
+
nk_f32_t const *beta, nk_bf16_t *result);
|
|
910
|
+
|
|
911
|
+
/** @copydoc nk_each_fma_f64 */
|
|
912
|
+
NK_PUBLIC void nk_each_fma_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
913
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
|
|
914
|
+
/** @copydoc nk_each_fma_f32 */
|
|
915
|
+
NK_PUBLIC void nk_each_fma_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
916
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
|
|
917
|
+
/** @copydoc nk_each_fma_f16 */
|
|
918
|
+
NK_PUBLIC void nk_each_fma_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
919
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
|
|
920
|
+
/** @copydoc nk_each_fma_bf16 */
|
|
921
|
+
NK_PUBLIC void nk_each_fma_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
922
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
|
|
923
|
+
/** @copydoc nk_each_fma_i8 */
|
|
924
|
+
NK_PUBLIC void nk_each_fma_i8_skylake(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
|
|
925
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
|
|
926
|
+
/** @copydoc nk_each_fma_u8 */
|
|
927
|
+
NK_PUBLIC void nk_each_fma_u8_skylake(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
|
|
928
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
|
|
929
|
+
/** @copydoc nk_each_fma_i16 */
|
|
930
|
+
NK_PUBLIC void nk_each_fma_i16_skylake(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
|
|
931
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
|
|
932
|
+
/** @copydoc nk_each_fma_u16 */
|
|
933
|
+
NK_PUBLIC void nk_each_fma_u16_skylake(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
|
|
934
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
|
|
935
|
+
/** @copydoc nk_each_fma_i32 */
|
|
936
|
+
NK_PUBLIC void nk_each_fma_i32_skylake(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
|
|
937
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
|
|
938
|
+
/** @copydoc nk_each_fma_u32 */
|
|
939
|
+
NK_PUBLIC void nk_each_fma_u32_skylake(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
|
|
940
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
|
|
941
|
+
/** @copydoc nk_each_fma_i64 */
|
|
942
|
+
NK_PUBLIC void nk_each_fma_i64_skylake(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
|
|
943
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
|
|
944
|
+
/** @copydoc nk_each_fma_u64 */
|
|
945
|
+
NK_PUBLIC void nk_each_fma_u64_skylake(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
|
|
946
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
|
|
947
|
+
/** @copydoc nk_each_sum_e4m3 */
|
|
948
|
+
NK_PUBLIC void nk_each_sum_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
|
|
949
|
+
/** @copydoc nk_each_sum_e5m2 */
|
|
950
|
+
NK_PUBLIC void nk_each_sum_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
|
|
951
|
+
/** @copydoc nk_each_scale_e4m3 */
|
|
952
|
+
NK_PUBLIC void nk_each_scale_e4m3_skylake(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
953
|
+
nk_e4m3_t *result);
|
|
954
|
+
/** @copydoc nk_each_scale_e5m2 */
|
|
955
|
+
NK_PUBLIC void nk_each_scale_e5m2_skylake(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
956
|
+
nk_e5m2_t *result);
|
|
957
|
+
/** @copydoc nk_each_blend_e4m3 */
|
|
958
|
+
NK_PUBLIC void nk_each_blend_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
959
|
+
nk_f32_t const *beta, nk_e4m3_t *result);
|
|
960
|
+
/** @copydoc nk_each_blend_e5m2 */
|
|
961
|
+
NK_PUBLIC void nk_each_blend_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
962
|
+
nk_f32_t const *beta, nk_e5m2_t *result);
|
|
963
|
+
/** @copydoc nk_each_fma_e4m3 */
|
|
964
|
+
NK_PUBLIC void nk_each_fma_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
|
|
965
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
|
|
966
|
+
/** @copydoc nk_each_fma_e5m2 */
|
|
967
|
+
NK_PUBLIC void nk_each_fma_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
|
|
968
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
|
|
969
|
+
|
|
970
|
+
/** @copydoc nk_each_scale_f32c */
|
|
971
|
+
NK_PUBLIC void nk_each_scale_f32c_skylake(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha,
|
|
972
|
+
nk_f32c_t const *beta, nk_f32c_t *result);
|
|
973
|
+
/** @copydoc nk_each_scale_f64c */
|
|
974
|
+
NK_PUBLIC void nk_each_scale_f64c_skylake(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha,
|
|
975
|
+
nk_f64c_t const *beta, nk_f64c_t *result);
|
|
976
|
+
/** @copydoc nk_each_blend_f32c */
|
|
977
|
+
NK_PUBLIC void nk_each_blend_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
|
|
978
|
+
nk_f32c_t const *beta, nk_f32c_t *result);
|
|
979
|
+
/** @copydoc nk_each_blend_f64c */
|
|
980
|
+
NK_PUBLIC void nk_each_blend_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
|
|
981
|
+
nk_f64c_t const *beta, nk_f64c_t *result);
|
|
982
|
+
/** @copydoc nk_each_fma_f32c */
|
|
983
|
+
NK_PUBLIC void nk_each_fma_f32c_skylake(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
984
|
+
nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
|
|
985
|
+
/** @copydoc nk_each_fma_f64c */
|
|
986
|
+
NK_PUBLIC void nk_each_fma_f64c_skylake(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
|
|
987
|
+
nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
|
|
988
|
+
#endif // NK_TARGET_SKYLAKE
|
|
989
|
+
|
|
990
|
+
#if NK_TARGET_ICELAKE
|
|
991
|
+
/** @copydoc nk_each_sum_i8 */
|
|
992
|
+
NK_PUBLIC void nk_each_sum_i8_icelake(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
|
|
993
|
+
/** @copydoc nk_each_sum_u8 */
|
|
994
|
+
NK_PUBLIC void nk_each_sum_u8_icelake(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
|
|
995
|
+
/** @copydoc nk_each_sum_i16 */
|
|
996
|
+
NK_PUBLIC void nk_each_sum_i16_icelake(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
|
|
997
|
+
/** @copydoc nk_each_sum_u16 */
|
|
998
|
+
NK_PUBLIC void nk_each_sum_u16_icelake(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
|
|
999
|
+
/** @copydoc nk_each_sum_i32 */
|
|
1000
|
+
NK_PUBLIC void nk_each_sum_i32_icelake(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
|
|
1001
|
+
/** @copydoc nk_each_sum_u32 */
|
|
1002
|
+
NK_PUBLIC void nk_each_sum_u32_icelake(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
|
|
1003
|
+
/** @copydoc nk_each_sum_i64 */
|
|
1004
|
+
NK_PUBLIC void nk_each_sum_i64_icelake(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
|
|
1005
|
+
/** @copydoc nk_each_sum_u64 */
|
|
1006
|
+
NK_PUBLIC void nk_each_sum_u64_icelake(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
|
|
1007
|
+
#endif // NK_TARGET_ICELAKE
|
|
1008
|
+
|
|
1009
|
+
#if NK_TARGET_SAPPHIRE
|
|
1010
|
+
/** @copydoc nk_each_scale_i8 */
|
|
1011
|
+
NK_PUBLIC void nk_each_scale_i8_sapphire(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1012
|
+
nk_i8_t *result);
|
|
1013
|
+
/** @copydoc nk_each_scale_u8 */
|
|
1014
|
+
NK_PUBLIC void nk_each_scale_u8_sapphire(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1015
|
+
nk_u8_t *result);
|
|
1016
|
+
|
|
1017
|
+
/** @copydoc nk_each_sum_f16 */
|
|
1018
|
+
NK_PUBLIC void nk_each_sum_f16_sapphire(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
|
|
1019
|
+
/** @copydoc nk_each_sum_e4m3 */
|
|
1020
|
+
NK_PUBLIC void nk_each_sum_e4m3_sapphire(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
|
|
1021
|
+
|
|
1022
|
+
/** @copydoc nk_each_blend_i8 */
|
|
1023
|
+
NK_PUBLIC void nk_each_blend_i8_sapphire(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1024
|
+
nk_f32_t const *beta, nk_i8_t *result);
|
|
1025
|
+
/** @copydoc nk_each_blend_u8 */
|
|
1026
|
+
NK_PUBLIC void nk_each_blend_u8_sapphire(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1027
|
+
nk_f32_t const *beta, nk_u8_t *result);
|
|
1028
|
+
|
|
1029
|
+
/** @copydoc nk_each_fma_i8 */
|
|
1030
|
+
NK_PUBLIC void nk_each_fma_i8_sapphire(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
|
|
1031
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
|
|
1032
|
+
/** @copydoc nk_each_fma_u8 */
|
|
1033
|
+
NK_PUBLIC void nk_each_fma_u8_sapphire(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
|
|
1034
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
|
|
1035
|
+
#endif // NK_TARGET_SAPPHIRE
|
|
1036
|
+
|
|
1037
|
+
#if NK_TARGET_RVV
|
|
1038
|
+
/** @copydoc nk_each_sum_f64 */
|
|
1039
|
+
NK_PUBLIC void nk_each_sum_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result);
|
|
1040
|
+
/** @copydoc nk_each_sum_f32 */
|
|
1041
|
+
NK_PUBLIC void nk_each_sum_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *result);
|
|
1042
|
+
/** @copydoc nk_each_sum_f16 */
|
|
1043
|
+
NK_PUBLIC void nk_each_sum_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *result);
|
|
1044
|
+
/** @copydoc nk_each_sum_bf16 */
|
|
1045
|
+
NK_PUBLIC void nk_each_sum_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *result);
|
|
1046
|
+
/** @copydoc nk_each_sum_i8 */
|
|
1047
|
+
NK_PUBLIC void nk_each_sum_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *result);
|
|
1048
|
+
/** @copydoc nk_each_sum_u8 */
|
|
1049
|
+
NK_PUBLIC void nk_each_sum_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *result);
|
|
1050
|
+
/** @copydoc nk_each_sum_i16 */
|
|
1051
|
+
NK_PUBLIC void nk_each_sum_i16_rvv(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *result);
|
|
1052
|
+
/** @copydoc nk_each_sum_u16 */
|
|
1053
|
+
NK_PUBLIC void nk_each_sum_u16_rvv(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *result);
|
|
1054
|
+
/** @copydoc nk_each_sum_i32 */
|
|
1055
|
+
NK_PUBLIC void nk_each_sum_i32_rvv(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *result);
|
|
1056
|
+
/** @copydoc nk_each_sum_u32 */
|
|
1057
|
+
NK_PUBLIC void nk_each_sum_u32_rvv(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *result);
|
|
1058
|
+
/** @copydoc nk_each_sum_i64 */
|
|
1059
|
+
NK_PUBLIC void nk_each_sum_i64_rvv(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *result);
|
|
1060
|
+
/** @copydoc nk_each_sum_u64 */
|
|
1061
|
+
NK_PUBLIC void nk_each_sum_u64_rvv(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *result);
|
|
1062
|
+
/** @copydoc nk_each_sum_e4m3 */
|
|
1063
|
+
NK_PUBLIC void nk_each_sum_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result);
|
|
1064
|
+
/** @copydoc nk_each_sum_e5m2 */
|
|
1065
|
+
NK_PUBLIC void nk_each_sum_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result);
|
|
1066
|
+
|
|
1067
|
+
/** @copydoc nk_each_scale_f64 */
|
|
1068
|
+
NK_PUBLIC void nk_each_scale_f64_rvv(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
1069
|
+
nk_f64_t *result);
|
|
1070
|
+
/** @copydoc nk_each_scale_f32 */
|
|
1071
|
+
NK_PUBLIC void nk_each_scale_f32_rvv(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1072
|
+
nk_f32_t *result);
|
|
1073
|
+
/** @copydoc nk_each_scale_f16 */
|
|
1074
|
+
NK_PUBLIC void nk_each_scale_f16_rvv(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1075
|
+
nk_f16_t *result);
|
|
1076
|
+
/** @copydoc nk_each_scale_bf16 */
|
|
1077
|
+
NK_PUBLIC void nk_each_scale_bf16_rvv(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1078
|
+
nk_bf16_t *result);
|
|
1079
|
+
/** @copydoc nk_each_scale_i8 */
|
|
1080
|
+
NK_PUBLIC void nk_each_scale_i8_rvv(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1081
|
+
nk_i8_t *result);
|
|
1082
|
+
/** @copydoc nk_each_scale_u8 */
|
|
1083
|
+
NK_PUBLIC void nk_each_scale_u8_rvv(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1084
|
+
nk_u8_t *result);
|
|
1085
|
+
/** @copydoc nk_each_scale_i16 */
|
|
1086
|
+
NK_PUBLIC void nk_each_scale_i16_rvv(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1087
|
+
nk_i16_t *result);
|
|
1088
|
+
/** @copydoc nk_each_scale_u16 */
|
|
1089
|
+
NK_PUBLIC void nk_each_scale_u16_rvv(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1090
|
+
nk_u16_t *result);
|
|
1091
|
+
/** @copydoc nk_each_scale_i32 */
|
|
1092
|
+
NK_PUBLIC void nk_each_scale_i32_rvv(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
1093
|
+
nk_i32_t *result);
|
|
1094
|
+
/** @copydoc nk_each_scale_u32 */
|
|
1095
|
+
NK_PUBLIC void nk_each_scale_u32_rvv(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
1096
|
+
nk_u32_t *result);
|
|
1097
|
+
/** @copydoc nk_each_scale_i64 */
|
|
1098
|
+
NK_PUBLIC void nk_each_scale_i64_rvv(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
1099
|
+
nk_i64_t *result);
|
|
1100
|
+
/** @copydoc nk_each_scale_u64 */
|
|
1101
|
+
NK_PUBLIC void nk_each_scale_u64_rvv(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
|
|
1102
|
+
nk_u64_t *result);
|
|
1103
|
+
/** @copydoc nk_each_scale_e4m3 */
|
|
1104
|
+
NK_PUBLIC void nk_each_scale_e4m3_rvv(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1105
|
+
nk_e4m3_t *result);
|
|
1106
|
+
/** @copydoc nk_each_scale_e5m2 */
|
|
1107
|
+
NK_PUBLIC void nk_each_scale_e5m2_rvv(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
|
|
1108
|
+
nk_e5m2_t *result);
|
|
1109
|
+
|
|
1110
|
+
/** @copydoc nk_each_blend_f64 */
|
|
1111
|
+
NK_PUBLIC void nk_each_blend_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
|
|
1112
|
+
nk_f64_t const *beta, nk_f64_t *result);
|
|
1113
|
+
/** @copydoc nk_each_blend_f32 */
|
|
1114
|
+
NK_PUBLIC void nk_each_blend_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1115
|
+
nk_f32_t const *beta, nk_f32_t *result);
|
|
1116
|
+
/** @copydoc nk_each_blend_f16 */
|
|
1117
|
+
NK_PUBLIC void nk_each_blend_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1118
|
+
nk_f32_t const *beta, nk_f16_t *result);
|
|
1119
|
+
/** @copydoc nk_each_blend_bf16 */
|
|
1120
|
+
NK_PUBLIC void nk_each_blend_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1121
|
+
nk_f32_t const *beta, nk_bf16_t *result);
|
|
1122
|
+
/** @copydoc nk_each_blend_i8 */
|
|
1123
|
+
NK_PUBLIC void nk_each_blend_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1124
|
+
nk_f32_t const *beta, nk_i8_t *result);
|
|
1125
|
+
/** @copydoc nk_each_blend_u8 */
|
|
1126
|
+
NK_PUBLIC void nk_each_blend_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1127
|
+
nk_f32_t const *beta, nk_u8_t *result);
|
|
1128
|
+
/** @copydoc nk_each_blend_e4m3 */
|
|
1129
|
+
NK_PUBLIC void nk_each_blend_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1130
|
+
nk_f32_t const *beta, nk_e4m3_t *result);
|
|
1131
|
+
/** @copydoc nk_each_blend_e5m2 */
|
|
1132
|
+
NK_PUBLIC void nk_each_blend_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
|
|
1133
|
+
nk_f32_t const *beta, nk_e5m2_t *result);
|
|
1134
|
+
|
|
1135
|
+
/** @copydoc nk_each_fma_f64 */
|
|
1136
|
+
NK_PUBLIC void nk_each_fma_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
1137
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *result);
|
|
1138
|
+
/** @copydoc nk_each_fma_f32 */
|
|
1139
|
+
NK_PUBLIC void nk_each_fma_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
1140
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *result);
|
|
1141
|
+
/** @copydoc nk_each_fma_f16 */
|
|
1142
|
+
NK_PUBLIC void nk_each_fma_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
1143
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *result);
|
|
1144
|
+
/** @copydoc nk_each_fma_bf16 */
|
|
1145
|
+
NK_PUBLIC void nk_each_fma_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
1146
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *result);
|
|
1147
|
+
/** @copydoc nk_each_fma_i8 */
|
|
1148
|
+
NK_PUBLIC void nk_each_fma_i8_rvv(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n,
|
|
1149
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i8_t *result);
|
|
1150
|
+
/** @copydoc nk_each_fma_u8 */
|
|
1151
|
+
NK_PUBLIC void nk_each_fma_u8_rvv(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n,
|
|
1152
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u8_t *result);
|
|
1153
|
+
/** @copydoc nk_each_fma_i16 */
|
|
1154
|
+
NK_PUBLIC void nk_each_fma_i16_rvv(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
|
|
1155
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *result);
|
|
1156
|
+
/** @copydoc nk_each_fma_u16 */
|
|
1157
|
+
NK_PUBLIC void nk_each_fma_u16_rvv(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
|
|
1158
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *result);
|
|
1159
|
+
/** @copydoc nk_each_fma_i32 */
|
|
1160
|
+
NK_PUBLIC void nk_each_fma_i32_rvv(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
|
|
1161
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *result);
|
|
1162
|
+
/** @copydoc nk_each_fma_u32 */
|
|
1163
|
+
NK_PUBLIC void nk_each_fma_u32_rvv(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
|
|
1164
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *result);
|
|
1165
|
+
/** @copydoc nk_each_fma_i64 */
|
|
1166
|
+
NK_PUBLIC void nk_each_fma_i64_rvv(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
|
|
1167
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *result);
|
|
1168
|
+
/** @copydoc nk_each_fma_u64 */
|
|
1169
|
+
NK_PUBLIC void nk_each_fma_u64_rvv(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
|
|
1170
|
+
nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *result);
|
|
1171
|
+
/** @copydoc nk_each_fma_e4m3 */
|
|
1172
|
+
NK_PUBLIC void nk_each_fma_e4m3_rvv(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
|
|
1173
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result);
|
|
1174
|
+
/** @copydoc nk_each_fma_e5m2 */
|
|
1175
|
+
NK_PUBLIC void nk_each_fma_e5m2_rvv(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
|
|
1176
|
+
nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result);
|
|
1177
|
+
/** @copydoc nk_each_scale_f32c */
|
|
1178
|
+
NK_PUBLIC void nk_each_scale_f32c_rvv(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
|
|
1179
|
+
nk_f32c_t *result);
|
|
1180
|
+
/** @copydoc nk_each_scale_f64c */
|
|
1181
|
+
NK_PUBLIC void nk_each_scale_f64c_rvv(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
|
|
1182
|
+
nk_f64c_t *result);
|
|
1183
|
+
/** @copydoc nk_each_blend_f32c */
|
|
1184
|
+
NK_PUBLIC void nk_each_blend_f32c_rvv(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
|
|
1185
|
+
nk_f32c_t const *beta, nk_f32c_t *result);
|
|
1186
|
+
/** @copydoc nk_each_blend_f64c */
|
|
1187
|
+
NK_PUBLIC void nk_each_blend_f64c_rvv(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
|
|
1188
|
+
nk_f64c_t const *beta, nk_f64c_t *result);
|
|
1189
|
+
/** @copydoc nk_each_fma_f32c */
|
|
1190
|
+
NK_PUBLIC void nk_each_fma_f32c_rvv(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
|
|
1191
|
+
nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *result);
|
|
1192
|
+
/** @copydoc nk_each_fma_f64c */
|
|
1193
|
+
NK_PUBLIC void nk_each_fma_f64c_rvv(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
|
|
1194
|
+
nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *result);
|
|
1195
|
+
#endif // NK_TARGET_RVV
|
|
1196
|
+
|
|
1197
|
+
/**
|
|
1198
|
+
* @brief Returns the scalar parameter dtype for elementwise scale/blend/fma operations.
|
|
1199
|
+
*/
|
|
1200
|
+
/**
 *  @brief Returns the scalar parameter dtype for elementwise scale/blend/fma operations.
 *
 *  Complex and full-width float inputs keep their own scalar type; narrow floats and
 *  8/16-bit integers take `f32` scalars; 32/64-bit integers take `f64` scalars.
 *  Unrecognized dtypes map to @b nk_dtype_unknown_k.
 */
NK_INTERNAL nk_dtype_t nk_each_scale_input_dtype(nk_dtype_t dtype) {
    switch (dtype) {
    // Complex and full-width floats scale with their own type:
    case nk_f64c_k: return nk_f64c_k;
    case nk_f32c_k: return nk_f32c_k;
    case nk_f64_k: return nk_f64_k;
    case nk_f32_k: return nk_f32_k;
    // Half-precision floats and 8/16-bit integers use single-precision scalars:
    case nk_f16_k:
    case nk_bf16_k:
    case nk_i16_k:
    case nk_u16_k:
    case nk_i8_k:
    case nk_u8_k: return nk_f32_k;
    // 32/64-bit integers use double-precision scalars:
    case nk_i64_k:
    case nk_u64_k:
    case nk_i32_k:
    case nk_u32_k: return nk_f64_k;
    default: return nk_dtype_unknown_k;
    }
}
|
|
1219
|
+
|
|
1220
|
+
#if defined(__cplusplus)
|
|
1221
|
+
} // extern "C"
|
|
1222
|
+
#endif
|
|
1223
|
+
|
|
1224
|
+
#include "numkong/each/serial.h"
|
|
1225
|
+
#include "numkong/each/neon.h"
|
|
1226
|
+
#include "numkong/each/neonhalf.h"
|
|
1227
|
+
#include "numkong/each/neonbfdot.h"
|
|
1228
|
+
#include "numkong/each/haswell.h"
|
|
1229
|
+
#include "numkong/each/skylake.h"
|
|
1230
|
+
#include "numkong/each/icelake.h"
|
|
1231
|
+
#include "numkong/each/sapphire.h"
|
|
1232
|
+
#include "numkong/each/rvv.h"
|
|
1233
|
+
|
|
1234
|
+
#if defined(__cplusplus)
|
|
1235
|
+
extern "C" {
|
|
1236
|
+
#endif
|
|
1237
|
+
|
|
1238
|
+
#if !NK_DYNAMIC_DISPATCH
|
|
1239
|
+
|
|
1240
|
+
/** @brief Compile-time dispatch of the f64 elementwise sum to the strongest enabled backend (falls back to serial). */
NK_PUBLIC void nk_each_sum_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_sum_f64_skylake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_f64_haswell(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_f64_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_f64_rvv(a, b, n, r);
#else
    nk_each_sum_f64_serial(a, b, n, r);
#endif
}
|
|
1253
|
+
|
|
1254
|
+
/** @brief Compile-time dispatch of the f32 elementwise sum to the strongest enabled backend. */
NK_PUBLIC void nk_each_sum_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_sum_f32_skylake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_f32_haswell(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_f32_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_f32_rvv(a, b, n, r);
#else
    nk_each_sum_f32_serial(a, b, n, r);
#endif
}
|
|
1267
|
+
|
|
1268
|
+
/** @brief Compile-time dispatch of the bf16 elementwise sum; Arm path requires the BF16 dot-product extension. */
NK_PUBLIC void nk_each_sum_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_bf16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_sum_bf16_skylake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_bf16_haswell(a, b, n, r);
#elif NK_TARGET_NEONBFDOT
    nk_each_sum_bf16_neonbfdot(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_bf16_rvv(a, b, n, r);
#else
    nk_each_sum_bf16_serial(a, b, n, r);
#endif
}
|
|
1281
|
+
|
|
1282
|
+
/** @brief Compile-time dispatch of the f16 elementwise sum; prefers Sapphire on x86 and half-precision NEON on Arm. */
NK_PUBLIC void nk_each_sum_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f16_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_sum_f16_sapphire(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_f16_haswell(a, b, n, r);
#elif NK_TARGET_NEONHALF
    nk_each_sum_f16_neonhalf(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_f16_rvv(a, b, n, r);
#else
    nk_each_sum_f16_serial(a, b, n, r);
#endif
}
|
|
1295
|
+
|
|
1296
|
+
/** @brief Compile-time dispatch of the i8 elementwise sum to the strongest enabled backend. */
NK_PUBLIC void nk_each_sum_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_i8_t *r) {
#if NK_TARGET_ICELAKE
    nk_each_sum_i8_icelake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_i8_haswell(a, b, n, r);
#elif NK_TARGET_NEONHALF
    nk_each_sum_i8_neonhalf(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_i8_rvv(a, b, n, r);
#else
    nk_each_sum_i8_serial(a, b, n, r);
#endif
}
|
|
1309
|
+
|
|
1310
|
+
/** @brief Compile-time dispatch of the u8 elementwise sum to the strongest enabled backend. */
NK_PUBLIC void nk_each_sum_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u8_t *r) {
#if NK_TARGET_ICELAKE
    nk_each_sum_u8_icelake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_u8_haswell(a, b, n, r);
#elif NK_TARGET_NEONHALF
    nk_each_sum_u8_neonhalf(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_u8_rvv(a, b, n, r);
#else
    nk_each_sum_u8_serial(a, b, n, r);
#endif
}
|
|
1323
|
+
|
|
1324
|
+
/** @brief Compile-time dispatch of the i16 elementwise sum to the strongest enabled backend. */
NK_PUBLIC void nk_each_sum_i16(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_i16_t *r) {
#if NK_TARGET_ICELAKE
    nk_each_sum_i16_icelake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_i16_haswell(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_i16_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_i16_rvv(a, b, n, r);
#else
    nk_each_sum_i16_serial(a, b, n, r);
#endif
}
|
|
1337
|
+
|
|
1338
|
+
/** @brief Compile-time dispatch of the u16 elementwise sum to the strongest enabled backend. */
NK_PUBLIC void nk_each_sum_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_u16_t *r) {
#if NK_TARGET_ICELAKE
    nk_each_sum_u16_icelake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_u16_haswell(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_u16_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_u16_rvv(a, b, n, r);
#else
    nk_each_sum_u16_serial(a, b, n, r);
#endif
}
|
|
1351
|
+
|
|
1352
|
+
/** @brief Compile-time dispatch of the i32 elementwise sum to the strongest enabled backend. */
NK_PUBLIC void nk_each_sum_i32(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_i32_t *r) {
#if NK_TARGET_ICELAKE
    nk_each_sum_i32_icelake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_i32_haswell(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_i32_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_i32_rvv(a, b, n, r);
#else
    nk_each_sum_i32_serial(a, b, n, r);
#endif
}
|
|
1365
|
+
|
|
1366
|
+
/** @brief Compile-time dispatch of the u32 elementwise sum to the strongest enabled backend. */
NK_PUBLIC void nk_each_sum_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_u32_t *r) {
#if NK_TARGET_ICELAKE
    nk_each_sum_u32_icelake(a, b, n, r);
#elif NK_TARGET_HASWELL
    nk_each_sum_u32_haswell(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_u32_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_u32_rvv(a, b, n, r);
#else
    nk_each_sum_u32_serial(a, b, n, r);
#endif
}
|
|
1379
|
+
|
|
1380
|
+
/** @brief Compile-time dispatch of the i64 elementwise sum; no Haswell branch for 64-bit lanes here. */
NK_PUBLIC void nk_each_sum_i64(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_i64_t *r) {
#if NK_TARGET_ICELAKE
    nk_each_sum_i64_icelake(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_i64_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_i64_rvv(a, b, n, r);
#else
    nk_each_sum_i64_serial(a, b, n, r);
#endif
}
|
|
1391
|
+
|
|
1392
|
+
/** @brief Compile-time dispatch of the u64 elementwise sum; no Haswell branch for 64-bit lanes here. */
NK_PUBLIC void nk_each_sum_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_u64_t *r) {
#if NK_TARGET_ICELAKE
    nk_each_sum_u64_icelake(a, b, n, r);
#elif NK_TARGET_NEON
    nk_each_sum_u64_neon(a, b, n, r);
#elif NK_TARGET_RVV
    nk_each_sum_u64_rvv(a, b, n, r);
#else
    nk_each_sum_u64_serial(a, b, n, r);
#endif
}
|
|
1403
|
+
|
|
1404
|
+
/** @brief Compile-time dispatch of the f64 elementwise scale (alpha/beta are f64 scalars passed by pointer). */
NK_PUBLIC void nk_each_scale_f64(nk_f64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_f64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_f64_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_f64_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_f64_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_f64_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_f64_serial(a, n, alpha, beta, r);
#endif
}
|
|
1418
|
+
|
|
1419
|
+
/** @brief Compile-time dispatch of the f32 elementwise scale (alpha/beta are f32 scalars passed by pointer). */
NK_PUBLIC void nk_each_scale_f32(nk_f32_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                 nk_f32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_f32_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_f32_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_f32_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_f32_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_f32_serial(a, n, alpha, beta, r);
#endif
}
|
|
1433
|
+
|
|
1434
|
+
/** @brief Compile-time dispatch of the bf16 elementwise scale; scalars are f32, matching nk_each_scale_input_dtype. */
NK_PUBLIC void nk_each_scale_bf16(nk_bf16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_bf16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_bf16_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_bf16_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEONBFDOT
    nk_each_scale_bf16_neonbfdot(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_bf16_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_bf16_serial(a, n, alpha, beta, r);
#endif
}
|
|
1448
|
+
|
|
1449
|
+
/** @brief Compile-time dispatch of the f16 elementwise scale; scalars are f32. */
NK_PUBLIC void nk_each_scale_f16(nk_f16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                 nk_f16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_f16_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_f16_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_scale_f16_neonhalf(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_f16_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_f16_serial(a, n, alpha, beta, r);
#endif
}
|
|
1463
|
+
|
|
1464
|
+
/** @brief Compile-time dispatch of the i8 elementwise scale; Sapphire outranks Skylake in this chain. */
NK_PUBLIC void nk_each_scale_i8(nk_i8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                nk_i8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_scale_i8_sapphire(a, n, alpha, beta, r);
#elif NK_TARGET_SKYLAKE
    nk_each_scale_i8_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_i8_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_scale_i8_neonhalf(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_i8_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_i8_serial(a, n, alpha, beta, r);
#endif
}
|
|
1480
|
+
|
|
1481
|
+
/** @brief Compile-time dispatch of the u8 elementwise scale; Sapphire outranks Skylake in this chain. */
NK_PUBLIC void nk_each_scale_u8(nk_u8_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                nk_u8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_scale_u8_sapphire(a, n, alpha, beta, r);
#elif NK_TARGET_SKYLAKE
    nk_each_scale_u8_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_u8_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_scale_u8_neonhalf(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_u8_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_u8_serial(a, n, alpha, beta, r);
#endif
}
|
|
1497
|
+
|
|
1498
|
+
/** @brief Compile-time dispatch of the i16 elementwise scale; scalars are f32. */
NK_PUBLIC void nk_each_scale_i16(nk_i16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                 nk_i16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_i16_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_i16_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_i16_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_i16_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_i16_serial(a, n, alpha, beta, r);
#endif
}
|
|
1512
|
+
|
|
1513
|
+
/** @brief Compile-time dispatch of the u16 elementwise scale; scalars are f32. */
NK_PUBLIC void nk_each_scale_u16(nk_u16_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                 nk_u16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_u16_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_u16_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_u16_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_u16_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_u16_serial(a, n, alpha, beta, r);
#endif
}
|
|
1527
|
+
|
|
1528
|
+
/** @brief Compile-time dispatch of the i32 elementwise scale; scalars widen to f64 for 32-bit integers. */
NK_PUBLIC void nk_each_scale_i32(nk_i32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_i32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_i32_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_i32_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_i32_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_i32_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_i32_serial(a, n, alpha, beta, r);
#endif
}
|
|
1542
|
+
|
|
1543
|
+
/** @brief Compile-time dispatch of the u32 elementwise scale; scalars widen to f64 for 32-bit integers. */
NK_PUBLIC void nk_each_scale_u32(nk_u32_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_u32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_u32_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_u32_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_u32_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_u32_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_u32_serial(a, n, alpha, beta, r);
#endif
}
|
|
1557
|
+
|
|
1558
|
+
/** @brief Compile-time dispatch of the i64 elementwise scale; no Haswell branch for 64-bit lanes here. */
NK_PUBLIC void nk_each_scale_i64(nk_i64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_i64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_i64_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_i64_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_i64_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_i64_serial(a, n, alpha, beta, r);
#endif
}
|
|
1570
|
+
|
|
1571
|
+
/** @brief Compile-time dispatch of the u64 elementwise scale; no Haswell branch for 64-bit lanes here. */
NK_PUBLIC void nk_each_scale_u64(nk_u64_t const *a, nk_size_t n, nk_f64_t const *alpha, nk_f64_t const *beta,
                                 nk_u64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_u64_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_u64_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_u64_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_u64_serial(a, n, alpha, beta, r);
#endif
}
|
|
1583
|
+
|
|
1584
|
+
/** @brief Compile-time dispatch of the f64 two-vector blend (alpha/beta weights passed by pointer). */
NK_PUBLIC void nk_each_blend_f64(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_f64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f64_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f64_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_blend_f64_neon(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f64_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f64_serial(a, b, n, alpha, beta, r);
#endif
}
|
|
1598
|
+
|
|
1599
|
+
/** @brief Compile-time dispatch of the f32 two-vector blend. */
NK_PUBLIC void nk_each_blend_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                 nk_f32_t const *beta, nk_f32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f32_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f32_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_blend_f32_neon(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f32_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f32_serial(a, b, n, alpha, beta, r);
#endif
}
|
|
1613
|
+
|
|
1614
|
+
/** @brief Compile-time dispatch of the bf16 two-vector blend; f32 weights, NEON path needs BF16 dot-product. */
NK_PUBLIC void nk_each_blend_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_bf16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_bf16_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_bf16_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEONBFDOT
    nk_each_blend_bf16_neonbfdot(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_bf16_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_bf16_serial(a, b, n, alpha, beta, r);
#endif
}
|
|
1628
|
+
|
|
1629
|
+
/** @brief Compile-time dispatch of the f16 two-vector blend; f32 weights. */
NK_PUBLIC void nk_each_blend_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                 nk_f32_t const *beta, nk_f16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f16_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f16_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_blend_f16_neonhalf(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f16_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f16_serial(a, b, n, alpha, beta, r);
#endif
}
|
|
1643
|
+
|
|
1644
|
+
/** @brief Compile-time dispatch of the i8 two-vector blend.
 *  NOTE(review): unlike nk_each_scale_i8, this chain has no Skylake branch — confirm whether a
 *  nk_each_blend_i8_skylake kernel exists and was intentionally omitted. */
NK_PUBLIC void nk_each_blend_i8(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                nk_f32_t const *beta, nk_i8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_blend_i8_sapphire(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_i8_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_blend_i8_neonhalf(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_i8_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_i8_serial(a, b, n, alpha, beta, r);
#endif
}
|
|
1658
|
+
|
|
1659
|
+
/** @brief Compile-time dispatch of the u8 two-vector blend (no Skylake branch, mirroring blend_i8). */
NK_PUBLIC void nk_each_blend_u8(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                nk_f32_t const *beta, nk_u8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_blend_u8_sapphire(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_u8_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_blend_u8_neonhalf(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_u8_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_u8_serial(a, b, n, alpha, beta, r);
#endif
}
|
|
1673
|
+
|
|
1674
|
+
/** @brief i16 two-vector blend — serial only; no SIMD backends are wired up for this case (presumably
 *  not yet implemented — TODO confirm against the kernel headers). */
NK_PUBLIC void nk_each_blend_i16(nk_i16_t const *a, nk_i16_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                 nk_f32_t const *beta, nk_i16_t *r) {
    nk_each_blend_i16_serial(a, b, n, alpha, beta, r);
}
|
|
1678
|
+
|
|
1679
|
+
/** @brief u16 two-vector blend — serial only (no SIMD backends wired up for this case). */
NK_PUBLIC void nk_each_blend_u16(nk_u16_t const *a, nk_u16_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                 nk_f32_t const *beta, nk_u16_t *r) {
    nk_each_blend_u16_serial(a, b, n, alpha, beta, r);
}
|
|
1683
|
+
|
|
1684
|
+
/** @brief i32 two-vector blend — serial only; f64 weights, per nk_each_scale_input_dtype. */
NK_PUBLIC void nk_each_blend_i32(nk_i32_t const *a, nk_i32_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_i32_t *r) {
    nk_each_blend_i32_serial(a, b, n, alpha, beta, r);
}
|
|
1688
|
+
|
|
1689
|
+
/** @brief u32 two-vector blend — serial only; f64 weights. */
NK_PUBLIC void nk_each_blend_u32(nk_u32_t const *a, nk_u32_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_u32_t *r) {
    nk_each_blend_u32_serial(a, b, n, alpha, beta, r);
}
|
|
1693
|
+
|
|
1694
|
+
/** @brief i64 two-vector blend — serial only; f64 weights. */
NK_PUBLIC void nk_each_blend_i64(nk_i64_t const *a, nk_i64_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_i64_t *r) {
    nk_each_blend_i64_serial(a, b, n, alpha, beta, r);
}
|
|
1698
|
+
|
|
1699
|
+
/** @brief u64 two-vector blend — serial only; f64 weights. */
NK_PUBLIC void nk_each_blend_u64(nk_u64_t const *a, nk_u64_t const *b, nk_size_t n, nk_f64_t const *alpha,
                                 nk_f64_t const *beta, nk_u64_t *r) {
    nk_each_blend_u64_serial(a, b, n, alpha, beta, r);
}
|
|
1703
|
+
|
|
1704
|
+
/** @brief Compile-time dispatch of the f64 three-vector fused-multiply-add kernel. */
NK_PUBLIC void nk_each_fma_f64(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_f64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f64_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f64_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_f64_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f64_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f64_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1718
|
+
|
|
1719
|
+
/** @brief Compile-time dispatch of the f32 three-vector fused-multiply-add kernel. */
NK_PUBLIC void nk_each_fma_f32(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
                               nk_f32_t const *alpha, nk_f32_t const *beta, nk_f32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f32_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f32_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_f32_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f32_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f32_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1733
|
+
|
|
1734
|
+
/** @brief Compile-time dispatch of the bf16 fused-multiply-add; f32 scalars, NEON path needs BF16 dot-product. */
NK_PUBLIC void nk_each_fma_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_bf16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_bf16_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_bf16_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEONBFDOT
    nk_each_fma_bf16_neonbfdot(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_bf16_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_bf16_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1748
|
+
|
|
1749
|
+
/** @brief Compile-time dispatch of the f16 fused-multiply-add; f32 scalars. */
NK_PUBLIC void nk_each_fma_f16(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
                               nk_f32_t const *alpha, nk_f32_t const *beta, nk_f16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f16_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f16_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_fma_f16_neonhalf(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f16_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f16_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1763
|
+
|
|
1764
|
+
/** @brief Compile-time dispatch of the i8 fused-multiply-add; Sapphire outranks Skylake here. */
NK_PUBLIC void nk_each_fma_i8(nk_i8_t const *a, nk_i8_t const *b, nk_i8_t const *c, nk_size_t n, nk_f32_t const *alpha,
                              nk_f32_t const *beta, nk_i8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_fma_i8_sapphire(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_SKYLAKE
    nk_each_fma_i8_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_i8_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_fma_i8_neonhalf(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_i8_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_i8_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1780
|
+
|
|
1781
|
+
/** @brief Compile-time dispatch of the u8 fused-multiply-add; Sapphire outranks Skylake here. */
NK_PUBLIC void nk_each_fma_u8(nk_u8_t const *a, nk_u8_t const *b, nk_u8_t const *c, nk_size_t n, nk_f32_t const *alpha,
                              nk_f32_t const *beta, nk_u8_t *r) {
#if NK_TARGET_SAPPHIRE
    nk_each_fma_u8_sapphire(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_SKYLAKE
    nk_each_fma_u8_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_u8_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEONHALF
    nk_each_fma_u8_neonhalf(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_u8_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_u8_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1797
|
+
|
|
1798
|
+
/** @brief Compile-time dispatch of the i16 fused-multiply-add; f32 scalars. */
NK_PUBLIC void nk_each_fma_i16(nk_i16_t const *a, nk_i16_t const *b, nk_i16_t const *c, nk_size_t n,
                               nk_f32_t const *alpha, nk_f32_t const *beta, nk_i16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_i16_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_i16_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_i16_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_i16_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_i16_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1812
|
+
|
|
1813
|
+
/** @brief Compile-time dispatch of the u16 fused-multiply-add; f32 scalars. */
NK_PUBLIC void nk_each_fma_u16(nk_u16_t const *a, nk_u16_t const *b, nk_u16_t const *c, nk_size_t n,
                               nk_f32_t const *alpha, nk_f32_t const *beta, nk_u16_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_u16_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_u16_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_u16_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_u16_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_u16_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1827
|
+
|
|
1828
|
+
/** @brief Compile-time dispatch of the i32 fused-multiply-add; scalars widen to f64. */
NK_PUBLIC void nk_each_fma_i32(nk_i32_t const *a, nk_i32_t const *b, nk_i32_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_i32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_i32_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_i32_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_i32_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_i32_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_i32_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1842
|
+
|
|
1843
|
+
/** @brief Compile-time dispatch of the u32 fused-multiply-add; scalars widen to f64. */
NK_PUBLIC void nk_each_fma_u32(nk_u32_t const *a, nk_u32_t const *b, nk_u32_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_u32_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_u32_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_u32_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_u32_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_u32_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_u32_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1857
|
+
|
|
1858
|
+
/** @brief Compile-time dispatch of the i64 fused-multiply-add; no Haswell branch for 64-bit lanes here. */
NK_PUBLIC void nk_each_fma_i64(nk_i64_t const *a, nk_i64_t const *b, nk_i64_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_i64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_i64_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_i64_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_i64_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_i64_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1870
|
+
|
|
1871
|
+
/** @brief Compile-time dispatch of the u64 fused-multiply-add; no Haswell branch for 64-bit lanes here. */
NK_PUBLIC void nk_each_fma_u64(nk_u64_t const *a, nk_u64_t const *b, nk_u64_t const *c, nk_size_t n,
                               nk_f64_t const *alpha, nk_f64_t const *beta, nk_u64_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_u64_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_u64_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_u64_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_u64_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
1883
|
+
|
|
1884
|
+
/** @brief Compile-time dispatch of the e4m3 (FP8) elementwise sum; Sapphire is preferred on x86. */
NK_PUBLIC void nk_each_sum_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_e4m3_t *result) {
#if NK_TARGET_SAPPHIRE
    nk_each_sum_e4m3_sapphire(a, b, n, result);
#elif NK_TARGET_SKYLAKE
    nk_each_sum_e4m3_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_each_sum_e4m3_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_each_sum_e4m3_neon(a, b, n, result);
#elif NK_TARGET_RVV
    nk_each_sum_e4m3_rvv(a, b, n, result);
#else
    nk_each_sum_e4m3_serial(a, b, n, result);
#endif
}
|
|
1899
|
+
|
|
1900
|
+
/** @brief Compile-time dispatch of the e5m2 (FP8) elementwise sum.
 *  NOTE(review): unlike nk_each_sum_e4m3 there is no Sapphire branch — confirm whether a
 *  nk_each_sum_e5m2_sapphire kernel exists and was intentionally left out. */
NK_PUBLIC void nk_each_sum_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_e5m2_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_sum_e5m2_skylake(a, b, n, result);
#elif NK_TARGET_HASWELL
    nk_each_sum_e5m2_haswell(a, b, n, result);
#elif NK_TARGET_NEON
    nk_each_sum_e5m2_neon(a, b, n, result);
#elif NK_TARGET_RVV
    nk_each_sum_e5m2_rvv(a, b, n, result);
#else
    nk_each_sum_e5m2_serial(a, b, n, result);
#endif
}
|
|
1913
|
+
|
|
1914
|
+
/** @brief Compile-time dispatch of the e4m3 (FP8) elementwise scale; f32 scalars. */
NK_PUBLIC void nk_each_scale_e4m3(nk_e4m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_e4m3_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_e4m3_skylake(a, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_scale_e4m3_haswell(a, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_scale_e4m3_neon(a, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_scale_e4m3_rvv(a, n, alpha, beta, result);
#else
    nk_each_scale_e4m3_serial(a, n, alpha, beta, result);
#endif
}
|
|
1928
|
+
|
|
1929
|
+
// Compile-time dispatch for the e5m2 (FP8) elementwise `scale` kernel;
// same tier order as the e4m3 variant (SKYLAKE > HASWELL > NEON > RVV > serial).
NK_PUBLIC void nk_each_scale_e5m2(nk_e5m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_e5m2_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_e5m2_skylake(a, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_scale_e5m2_haswell(a, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_scale_e5m2_neon(a, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_scale_e5m2_rvv(a, n, alpha, beta, result);
#else
    nk_each_scale_e5m2_serial(a, n, alpha, beta, result);
#endif
}
|
|
1943
|
+
|
|
1944
|
+
// Compile-time dispatch for the e4m3 (FP8) elementwise `blend` kernel
// (two inputs, f32 alpha/beta weights). Backend fixed at build time.
NK_PUBLIC void nk_each_blend_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_e4m3_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_e4m3_skylake(a, b, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_blend_e4m3_haswell(a, b, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_blend_e4m3_neon(a, b, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_blend_e4m3_rvv(a, b, n, alpha, beta, result);
#else
    nk_each_blend_e4m3_serial(a, b, n, alpha, beta, result);
#endif
}
|
|
1958
|
+
|
|
1959
|
+
// Compile-time dispatch for the e5m2 (FP8) elementwise `blend` kernel;
// mirrors the e4m3 variant's tier order.
NK_PUBLIC void nk_each_blend_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_e5m2_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_e5m2_skylake(a, b, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_blend_e5m2_haswell(a, b, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_blend_e5m2_neon(a, b, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_blend_e5m2_rvv(a, b, n, alpha, beta, result);
#else
    nk_each_blend_e5m2_serial(a, b, n, alpha, beta, result);
#endif
}
|
|
1973
|
+
|
|
1974
|
+
// Compile-time dispatch for the e4m3 (FP8) elementwise `fma` kernel
// (three inputs, f32 alpha/beta coefficients). Backend fixed at build time.
NK_PUBLIC void nk_each_fma_e4m3(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_e4m3_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_e4m3_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_e4m3_skylake(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_fma_e4m3_haswell(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_fma_e4m3_neon(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_fma_e4m3_rvv(a, b, c, n, alpha, beta, result);
#else
    nk_each_fma_e4m3_serial(a, b, c, n, alpha, beta, result);
#endif
}
|
|
1988
|
+
|
|
1989
|
+
// Compile-time dispatch for the e5m2 (FP8) elementwise `fma` kernel;
// mirrors the e4m3 variant's tier order.
NK_PUBLIC void nk_each_fma_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_e5m2_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_e5m2_t *result) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_e5m2_skylake(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_HASWELL
    nk_each_fma_e5m2_haswell(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_NEON
    nk_each_fma_e5m2_neon(a, b, c, n, alpha, beta, result);
#elif NK_TARGET_RVV
    nk_each_fma_e5m2_rvv(a, b, c, n, alpha, beta, result);
#else
    nk_each_fma_e5m2_serial(a, b, c, n, alpha, beta, result);
#endif
}
|
|
2003
|
+
|
|
2004
|
+
// e2m3 (FP6) elementwise `sum`: no SIMD backend is dispatched for this type —
// always forwards to the serial kernel.
NK_PUBLIC void nk_each_sum_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_e2m3_t *result) {
    nk_each_sum_e2m3_serial(a, b, n, result);
}
|
|
2007
|
+
|
|
2008
|
+
// e3m2 (FP6) elementwise `sum`: no SIMD backend is dispatched for this type —
// always forwards to the serial kernel.
NK_PUBLIC void nk_each_sum_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_e3m2_t *result) {
    nk_each_sum_e3m2_serial(a, b, n, result);
}
|
|
2011
|
+
|
|
2012
|
+
// e2m3 (FP6) elementwise `scale` with f32 alpha/beta: serial-only, no SIMD
// dispatch for this type.
NK_PUBLIC void nk_each_scale_e2m3(nk_e2m3_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_e2m3_t *result) {
    nk_each_scale_e2m3_serial(a, n, alpha, beta, result);
}
|
|
2016
|
+
|
|
2017
|
+
// e3m2 (FP6) elementwise `scale` with f32 alpha/beta: serial-only, no SIMD
// dispatch for this type.
NK_PUBLIC void nk_each_scale_e3m2(nk_e3m2_t const *a, nk_size_t n, nk_f32_t const *alpha, nk_f32_t const *beta,
                                  nk_e3m2_t *result) {
    nk_each_scale_e3m2_serial(a, n, alpha, beta, result);
}
|
|
2021
|
+
|
|
2022
|
+
// e2m3 (FP6) elementwise `blend` with f32 alpha/beta: serial-only, no SIMD
// dispatch for this type.
NK_PUBLIC void nk_each_blend_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_e2m3_t *result) {
    nk_each_blend_e2m3_serial(a, b, n, alpha, beta, result);
}
|
|
2026
|
+
|
|
2027
|
+
// e3m2 (FP6) elementwise `blend` with f32 alpha/beta: serial-only, no SIMD
// dispatch for this type.
NK_PUBLIC void nk_each_blend_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t const *alpha,
                                  nk_f32_t const *beta, nk_e3m2_t *result) {
    nk_each_blend_e3m2_serial(a, b, n, alpha, beta, result);
}
|
|
2031
|
+
|
|
2032
|
+
// e2m3 (FP6) elementwise `fma` with f32 alpha/beta: serial-only, no SIMD
// dispatch for this type.
NK_PUBLIC void nk_each_fma_e2m3(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_e2m3_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_e2m3_t *result) {
    nk_each_fma_e2m3_serial(a, b, c, n, alpha, beta, result);
}
|
|
2036
|
+
|
|
2037
|
+
// e3m2 (FP6) elementwise `fma` with f32 alpha/beta: serial-only, no SIMD
// dispatch for this type.
NK_PUBLIC void nk_each_fma_e3m2(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_e3m2_t const *c, nk_size_t n,
                                nk_f32_t const *alpha, nk_f32_t const *beta, nk_e3m2_t *result) {
    nk_each_fma_e3m2_serial(a, b, c, n, alpha, beta, result);
}
|
|
2041
|
+
|
|
2042
|
+
// f32 complex elementwise `sum`: always forwards to the serial kernel.
// NOTE(review): scale/blend/fma f32c below have SIMD tiers but sum does not —
// confirm this is intentional (complex add may vectorize trivially elsewhere).
NK_PUBLIC void nk_each_sum_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t *r) {
    nk_each_sum_f32c_serial(a, b, n, r);
}
|
|
2045
|
+
|
|
2046
|
+
// f64 complex elementwise `sum`: always forwards to the serial kernel
// (no SIMD variant is dispatched here).
NK_PUBLIC void nk_each_sum_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t *r) {
    nk_each_sum_f64c_serial(a, b, n, r);
}
|
|
2049
|
+
|
|
2050
|
+
// Compile-time dispatch for the f32 complex elementwise `scale` kernel;
// alpha/beta are complex (f32c) coefficients. Backend fixed at build time.
NK_PUBLIC void nk_each_scale_f32c(nk_f32c_t const *a, nk_size_t n, nk_f32c_t const *alpha, nk_f32c_t const *beta,
                                  nk_f32c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_f32c_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_f32c_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_f32c_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_f32c_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_f32c_serial(a, n, alpha, beta, r);
#endif
}
|
|
2064
|
+
|
|
2065
|
+
// Compile-time dispatch for the f64 complex elementwise `scale` kernel;
// mirrors the f32c variant's tier order.
NK_PUBLIC void nk_each_scale_f64c(nk_f64c_t const *a, nk_size_t n, nk_f64c_t const *alpha, nk_f64c_t const *beta,
                                  nk_f64c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_scale_f64c_skylake(a, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_scale_f64c_haswell(a, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_scale_f64c_neon(a, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_scale_f64c_rvv(a, n, alpha, beta, r);
#else
    nk_each_scale_f64c_serial(a, n, alpha, beta, r);
#endif
}
|
|
2079
|
+
|
|
2080
|
+
// Compile-time dispatch for the f32 complex elementwise `blend` kernel
// (two inputs, complex alpha/beta weights). Backend fixed at build time.
NK_PUBLIC void nk_each_blend_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_size_t n, nk_f32c_t const *alpha,
                                  nk_f32c_t const *beta, nk_f32c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f32c_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f32c_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_blend_f32c_neon(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f32c_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f32c_serial(a, b, n, alpha, beta, r);
#endif
}
|
|
2094
|
+
|
|
2095
|
+
// Compile-time dispatch for the f64 complex elementwise `blend` kernel;
// mirrors the f32c variant's tier order.
NK_PUBLIC void nk_each_blend_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_size_t n, nk_f64c_t const *alpha,
                                  nk_f64c_t const *beta, nk_f64c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_blend_f64c_skylake(a, b, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_blend_f64c_haswell(a, b, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_blend_f64c_neon(a, b, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_blend_f64c_rvv(a, b, n, alpha, beta, r);
#else
    nk_each_blend_f64c_serial(a, b, n, alpha, beta, r);
#endif
}
|
|
2109
|
+
|
|
2110
|
+
// Compile-time dispatch for the f32 complex elementwise `fma` kernel
// (three inputs, complex alpha/beta coefficients). Backend fixed at build time.
NK_PUBLIC void nk_each_fma_f32c(nk_f32c_t const *a, nk_f32c_t const *b, nk_f32c_t const *c, nk_size_t n,
                                nk_f32c_t const *alpha, nk_f32c_t const *beta, nk_f32c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f32c_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f32c_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_f32c_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f32c_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f32c_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
2124
|
+
|
|
2125
|
+
// Compile-time dispatch for the f64 complex elementwise `fma` kernel;
// mirrors the f32c variant's tier order.
NK_PUBLIC void nk_each_fma_f64c(nk_f64c_t const *a, nk_f64c_t const *b, nk_f64c_t const *c, nk_size_t n,
                                nk_f64c_t const *alpha, nk_f64c_t const *beta, nk_f64c_t *r) {
#if NK_TARGET_SKYLAKE
    nk_each_fma_f64c_skylake(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_HASWELL
    nk_each_fma_f64c_haswell(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_NEON
    nk_each_fma_f64c_neon(a, b, c, n, alpha, beta, r);
#elif NK_TARGET_RVV
    nk_each_fma_f64c_rvv(a, b, c, n, alpha, beta, r);
#else
    nk_each_fma_f64c_serial(a, b, c, n, alpha, beta, r);
#endif
}
|
|
2139
|
+
|
|
2140
|
+
#endif // !NK_DYNAMIC_DISPATCH
|
|
2141
|
+
|
|
2142
|
+
#if defined(__cplusplus)
|
|
2143
|
+
} // extern "C"
|
|
2144
|
+
#endif
|
|
2145
|
+
|
|
2146
|
+
#endif // NK_EACH_H
|