numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
Binary file
|
|
@@ -1,212 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @brief SIMD-accelerated Curved Space Similarity for NEON FP16.
|
|
3
|
-
* @file include/numkong/curved/neonhalf.h
|
|
4
|
-
* @author Ash Vardanian
|
|
5
|
-
* @date January 14, 2026
|
|
6
|
-
*
|
|
7
|
-
* @sa include/numkong/curved.h
|
|
8
|
-
*
|
|
9
|
-
* Implements f16 bilinear forms and Mahalanobis distance using ARM NEON with FP16 extensions.
|
|
10
|
-
*
|
|
11
|
-
* @section curved_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
|
|
12
|
-
*
|
|
13
|
-
* Intrinsic Instruction Latency Throughput
|
|
14
|
-
* A76 M4+/V1+/Oryon
|
|
15
|
-
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
|
|
16
|
-
* vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
|
|
17
|
-
* vld1_f16 LD1 (V.4H) 4cy 2/cy 3/cy
|
|
18
|
-
* vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
|
|
19
|
-
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
|
|
20
|
-
*
|
|
21
|
-
* Bilinear forms involve nested summation O(n^2) operations. For numerical stability,
|
|
22
|
-
* f16 inputs are widened to f32 for accumulation. The matrix C is accessed row-by-row
|
|
23
|
-
* to maintain cache locality.
|
|
24
|
-
*
|
|
25
|
-
* Mathematical definitions:
|
|
26
|
-
* - Bilinear: result = ∑ᵢ ∑ⱼ aᵢ × cᵢⱼ × bⱼ
|
|
27
|
-
* - Mahalanobis: result = √((a - b)ᵀ × C × (a - b))
|
|
28
|
-
*/
|
|
29
|
-
#ifndef NK_CURVED_NEONHALF_H
|
|
30
|
-
#define NK_CURVED_NEONHALF_H
|
|
31
|
-
|
|
32
|
-
#if NK_TARGET_ARM_
|
|
33
|
-
#if NK_TARGET_NEONHALF
|
|
34
|
-
|
|
35
|
-
#include "numkong/types.h"
|
|
36
|
-
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
37
|
-
#include "numkong/cast/serial.h" // `nk_f16_to_f32_serial`
|
|
38
|
-
|
|
39
|
-
#if defined(__cplusplus)
|
|
40
|
-
extern "C" {
|
|
41
|
-
#endif
|
|
42
|
-
|
|
43
|
-
#if defined(__clang__)
|
|
44
|
-
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
45
|
-
#elif defined(__GNUC__)
|
|
46
|
-
#pragma GCC push_options
|
|
47
|
-
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
48
|
-
#endif
|
|
49
|
-
|
|
50
|
-
NK_PUBLIC void nk_bilinear_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
51
|
-
nk_f32_t *result) {
|
|
52
|
-
nk_f32_t outer_sum = 0;
|
|
53
|
-
|
|
54
|
-
// Process rows of the matrix
|
|
55
|
-
for (nk_size_t row = 0; row != n; ++row) {
|
|
56
|
-
nk_f16_t const *c_row = c + row * n;
|
|
57
|
-
|
|
58
|
-
// Load a[row] as f32
|
|
59
|
-
nk_f32_t a_row;
|
|
60
|
-
nk_f16_to_f32_serial(a + row, &a_row);
|
|
61
|
-
|
|
62
|
-
// Compute inner sum
|
|
63
|
-
float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
|
|
64
|
-
nk_size_t column = 0;
|
|
65
|
-
|
|
66
|
-
// Process 4 elements at a time
|
|
67
|
-
for (; column + 4 <= n; column += 4) {
|
|
68
|
-
float32x4_t b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(b + column)));
|
|
69
|
-
float32x4_t c_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(c_row + column)));
|
|
70
|
-
inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, b_f32x4);
|
|
71
|
-
}
|
|
72
|
-
|
|
73
|
-
// Reduce SIMD accumulator
|
|
74
|
-
nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
|
|
75
|
-
|
|
76
|
-
// Handle tail elements with scalar code
|
|
77
|
-
for (; column < n; ++column) {
|
|
78
|
-
nk_f32_t b_val, c_val;
|
|
79
|
-
nk_f16_to_f32_serial(b + column, &b_val);
|
|
80
|
-
nk_f16_to_f32_serial(c_row + column, &c_val);
|
|
81
|
-
inner_sum += c_val * b_val;
|
|
82
|
-
}
|
|
83
|
-
|
|
84
|
-
// Multiply by a[row] and accumulate
|
|
85
|
-
outer_sum += a_row * inner_sum;
|
|
86
|
-
}
|
|
87
|
-
|
|
88
|
-
*result = outer_sum;
|
|
89
|
-
}
|
|
90
|
-
|
|
91
|
-
NK_PUBLIC void nk_mahalanobis_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
92
|
-
nk_f32_t *result) {
|
|
93
|
-
nk_f32_t outer_sum = 0;
|
|
94
|
-
|
|
95
|
-
// Process rows of the matrix
|
|
96
|
-
for (nk_size_t row = 0; row != n; ++row) {
|
|
97
|
-
nk_f16_t const *c_row = c + row * n;
|
|
98
|
-
|
|
99
|
-
// Compute diff_row = a[row] - b[row] in f32
|
|
100
|
-
nk_f32_t a_row, b_row;
|
|
101
|
-
nk_f16_to_f32_serial(a + row, &a_row);
|
|
102
|
-
nk_f16_to_f32_serial(b + row, &b_row);
|
|
103
|
-
nk_f32_t diff_row = a_row - b_row;
|
|
104
|
-
|
|
105
|
-
// Compute inner sum
|
|
106
|
-
float32x4_t inner_sum_f32x4 = vdupq_n_f32(0);
|
|
107
|
-
nk_size_t column = 0;
|
|
108
|
-
|
|
109
|
-
// Process 4 elements at a time
|
|
110
|
-
for (; column + 4 <= n; column += 4) {
|
|
111
|
-
float32x4_t a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(a + column)));
|
|
112
|
-
float32x4_t b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(b + column)));
|
|
113
|
-
float32x4_t c_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)(c_row + column)));
|
|
114
|
-
float32x4_t diff_column_f32x4 = vsubq_f32(a_f32x4, b_f32x4);
|
|
115
|
-
inner_sum_f32x4 = vfmaq_f32(inner_sum_f32x4, c_f32x4, diff_column_f32x4);
|
|
116
|
-
}
|
|
117
|
-
|
|
118
|
-
// Reduce SIMD accumulator
|
|
119
|
-
nk_f32_t inner_sum = vaddvq_f32(inner_sum_f32x4);
|
|
120
|
-
|
|
121
|
-
// Handle tail elements with scalar code
|
|
122
|
-
for (; column < n; ++column) {
|
|
123
|
-
nk_f32_t a_val, b_val, c_val;
|
|
124
|
-
nk_f16_to_f32_serial(a + column, &a_val);
|
|
125
|
-
nk_f16_to_f32_serial(b + column, &b_val);
|
|
126
|
-
nk_f16_to_f32_serial(c_row + column, &c_val);
|
|
127
|
-
inner_sum += c_val * (a_val - b_val);
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
// Multiply by diff_row and accumulate
|
|
131
|
-
outer_sum += diff_row * inner_sum;
|
|
132
|
-
}
|
|
133
|
-
|
|
134
|
-
nk_f32_t quadratic = outer_sum;
|
|
135
|
-
*result = nk_f32_sqrt_neon(quadratic > 0 ? quadratic : 0);
|
|
136
|
-
}
|
|
137
|
-
|
|
138
|
-
NK_PUBLIC void nk_bilinear_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_f16c_t const *c_pairs,
|
|
139
|
-
nk_size_t n, nk_f32c_t *results) {
|
|
140
|
-
nk_f32_t outer_sum_real = 0;
|
|
141
|
-
nk_f32_t outer_sum_imag = 0;
|
|
142
|
-
|
|
143
|
-
// Process rows of the matrix
|
|
144
|
-
for (nk_size_t row = 0; row != n; ++row) {
|
|
145
|
-
nk_f16c_t const *c_row = c_pairs + row * n;
|
|
146
|
-
|
|
147
|
-
// Load a[row] complex value
|
|
148
|
-
nk_f32_t a_real, a_imag;
|
|
149
|
-
nk_f16_to_f32_serial(&(a_pairs + row)->real, &a_real);
|
|
150
|
-
nk_f16_to_f32_serial(&(a_pairs + row)->imag, &a_imag);
|
|
151
|
-
|
|
152
|
-
// Compute inner sum
|
|
153
|
-
float32x4_t inner_sum_real_f32x4 = vdupq_n_f32(0);
|
|
154
|
-
float32x4_t inner_sum_imag_f32x4 = vdupq_n_f32(0);
|
|
155
|
-
nk_size_t column = 0;
|
|
156
|
-
|
|
157
|
-
// Process 4 complex pairs at a time using deinterleaved loads
|
|
158
|
-
for (; column + 4 <= n; column += 4) {
|
|
159
|
-
// Deinterleave real/imaginary using vld2_s16 pattern from dot/neonhalf.h
|
|
160
|
-
int16x4x2_t b_i16x4x2 = vld2_s16((short const *)(b_pairs + column));
|
|
161
|
-
int16x4x2_t c_i16x4x2 = vld2_s16((short const *)(c_row + column));
|
|
162
|
-
float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
|
|
163
|
-
float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
|
|
164
|
-
float32x4_t c_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(c_i16x4x2.val[0]));
|
|
165
|
-
float32x4_t c_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(c_i16x4x2.val[1]));
|
|
166
|
-
|
|
167
|
-
// Complex multiply
|
|
168
|
-
inner_sum_real_f32x4 = vfmaq_f32(inner_sum_real_f32x4, c_real_f32x4, b_real_f32x4);
|
|
169
|
-
inner_sum_real_f32x4 = vfmsq_f32(inner_sum_real_f32x4, c_imag_f32x4, b_imag_f32x4);
|
|
170
|
-
inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_real_f32x4, b_imag_f32x4);
|
|
171
|
-
inner_sum_imag_f32x4 = vfmaq_f32(inner_sum_imag_f32x4, c_imag_f32x4, b_real_f32x4);
|
|
172
|
-
}
|
|
173
|
-
|
|
174
|
-
// Reduce SIMD accumulators
|
|
175
|
-
nk_f32_t inner_sum_real = vaddvq_f32(inner_sum_real_f32x4);
|
|
176
|
-
nk_f32_t inner_sum_imag = vaddvq_f32(inner_sum_imag_f32x4);
|
|
177
|
-
|
|
178
|
-
// Handle tail elements with scalar code
|
|
179
|
-
for (; column < n; ++column) {
|
|
180
|
-
nk_f32_t b_real, b_imag, c_real, c_imag;
|
|
181
|
-
nk_f16_to_f32_serial(&(b_pairs + column)->real, &b_real);
|
|
182
|
-
nk_f16_to_f32_serial(&(b_pairs + column)->imag, &b_imag);
|
|
183
|
-
nk_f16_to_f32_serial(&(c_row + column)->real, &c_real);
|
|
184
|
-
nk_f16_to_f32_serial(&(c_row + column)->imag, &c_imag);
|
|
185
|
-
|
|
186
|
-
// Complex multiply
|
|
187
|
-
inner_sum_real += c_real * b_real - c_imag * b_imag;
|
|
188
|
-
inner_sum_imag += c_real * b_imag + c_imag * b_real;
|
|
189
|
-
}
|
|
190
|
-
|
|
191
|
-
// Complex multiply
|
|
192
|
-
outer_sum_real += a_real * inner_sum_real - a_imag * inner_sum_imag;
|
|
193
|
-
outer_sum_imag += a_real * inner_sum_imag + a_imag * inner_sum_real;
|
|
194
|
-
}
|
|
195
|
-
|
|
196
|
-
results->real = outer_sum_real;
|
|
197
|
-
results->imag = outer_sum_imag;
|
|
198
|
-
}
|
|
199
|
-
|
|
200
|
-
#if defined(__clang__)
|
|
201
|
-
#pragma clang attribute pop
|
|
202
|
-
#elif defined(__GNUC__)
|
|
203
|
-
#pragma GCC pop_options
|
|
204
|
-
#endif
|
|
205
|
-
|
|
206
|
-
#if defined(__cplusplus)
|
|
207
|
-
} // extern "C"
|
|
208
|
-
#endif
|
|
209
|
-
|
|
210
|
-
#endif // NK_TARGET_NEONHALF
|
|
211
|
-
#endif // NK_TARGET_ARM_
|
|
212
|
-
#endif // NK_CURVED_NEONHALF_H
|
|
@@ -1,198 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @brief SIMD-accelerated Dot Products for NEON FP16.
|
|
3
|
-
* @file include/numkong/dot/neonhalf.h
|
|
4
|
-
* @author Ash Vardanian
|
|
5
|
-
* @date December 27, 2025
|
|
6
|
-
*
|
|
7
|
-
* @sa include/numkong/dot.h
|
|
8
|
-
*
|
|
9
|
-
* @section dot_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
|
|
10
|
-
*
|
|
11
|
-
* Intrinsic Instruction Latency Throughput
|
|
12
|
-
* A76 M4+/V1+/Oryon
|
|
13
|
-
* vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
|
|
14
|
-
* vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
|
|
15
|
-
* vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
|
|
16
|
-
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
|
|
17
|
-
* vfmsq_f16 FMLS (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
|
|
18
|
-
*
|
|
19
|
-
* The ARMv8.2-FP16 extension enables native half-precision arithmetic, doubling the element count
|
|
20
|
-
* per vector register (8x F16 vs 4x F32). This doubles theoretical throughput for bandwidth-bound
|
|
21
|
-
* workloads while halving memory footprint.
|
|
22
|
-
*
|
|
23
|
-
* For dot products, inputs are widened from F16 to F32 for accumulation to preserve numerical
|
|
24
|
-
* precision. The FCVTL instruction handles this widening, allowing the FMA operations
|
|
25
|
-
* to maintain full F32 precision in the accumulator.
|
|
26
|
-
*
|
|
27
|
-
* @section dot_neonhalf_stateful Stateful Streaming Logic
|
|
28
|
-
*
|
|
29
|
-
* To build memory-optimal tiled algorithms, this file defines following structures and force-inlined
|
|
30
|
-
* `NK_INTERNAL` functions:
|
|
31
|
-
*
|
|
32
|
-
* - nk_dot_f16x4 state with f16 inputs widened to f32 for accumulation.
|
|
33
|
-
*
|
|
34
|
-
* @code{c}
|
|
35
|
-
* nk_dot_f16x4_state_neonhalf_t state_first, state_second, state_third, state_fourth;
|
|
36
|
-
* float16x4_t query_f16x4, target_first_f16x4, target_second_f16x4, target_third_f16x4, target_fourth_f16x4;
|
|
37
|
-
* nk_dot_f16x4_init_neonhalf(&state_first);
|
|
38
|
-
* nk_dot_f16x4_init_neonhalf(&state_second);
|
|
39
|
-
* nk_dot_f16x4_init_neonhalf(&state_third);
|
|
40
|
-
* nk_dot_f16x4_init_neonhalf(&state_fourth);
|
|
41
|
-
* for (nk_size_t idx = 0; idx + 4 <= depth; idx += 4) {
|
|
42
|
-
* query_f16x4 = vld1_f16(query_ptr + idx);
|
|
43
|
-
* target_first_f16x4 = vld1_f16(target_first_ptr + idx);
|
|
44
|
-
* target_second_f16x4 = vld1_f16(target_second_ptr + idx);
|
|
45
|
-
* target_third_f16x4 = vld1_f16(target_third_ptr + idx);
|
|
46
|
-
* target_fourth_f16x4 = vld1_f16(target_fourth_ptr + idx);
|
|
47
|
-
* nk_dot_f16x4_update_neonhalf(&state_first, query_f16x4, target_first_f16x4, idx, 4);
|
|
48
|
-
* nk_dot_f16x4_update_neonhalf(&state_second, query_f16x4, target_second_f16x4, idx, 4);
|
|
49
|
-
* nk_dot_f16x4_update_neonhalf(&state_third, query_f16x4, target_third_f16x4, idx, 4);
|
|
50
|
-
* nk_dot_f16x4_update_neonhalf(&state_fourth, query_f16x4, target_fourth_f16x4, idx, 4);
|
|
51
|
-
* }
|
|
52
|
-
* float32x4_t results_f32x4;
|
|
53
|
-
* nk_dot_f16x4_finalize_neonhalf(&state_first, &state_second, &state_third, &state_fourth, depth, &results_f32x4);
|
|
54
|
-
* @endcode
|
|
55
|
-
*/
|
|
56
|
-
#ifndef NK_DOT_NEONHALF_H
|
|
57
|
-
#define NK_DOT_NEONHALF_H
|
|
58
|
-
|
|
59
|
-
#if NK_TARGET_ARM_
|
|
60
|
-
#if NK_TARGET_NEONHALF
|
|
61
|
-
|
|
62
|
-
#include "numkong/types.h"
|
|
63
|
-
#include "numkong/cast/serial.h" // `nk_partial_load_b16x4_serial_`
|
|
64
|
-
|
|
65
|
-
#if defined(__cplusplus)
|
|
66
|
-
extern "C" {
|
|
67
|
-
#endif
|
|
68
|
-
|
|
69
|
-
#if defined(__clang__)
|
|
70
|
-
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
71
|
-
#elif defined(__GNUC__)
|
|
72
|
-
#pragma GCC push_options
|
|
73
|
-
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
74
|
-
#endif
|
|
75
|
-
|
|
76
|
-
NK_PUBLIC void nk_dot_f16_neonhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
|
|
77
|
-
nk_f32_t *result) {
|
|
78
|
-
float32x4_t a_f32x4, b_f32x4;
|
|
79
|
-
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
80
|
-
nk_dot_f16_neonhalf_cycle:
|
|
81
|
-
if (count_scalars < 4) {
|
|
82
|
-
nk_b64_vec_t a_vec, b_vec;
|
|
83
|
-
nk_partial_load_b16x4_serial_(a_scalars, &a_vec, count_scalars);
|
|
84
|
-
nk_partial_load_b16x4_serial_(b_scalars, &b_vec, count_scalars);
|
|
85
|
-
a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
|
|
86
|
-
b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
|
|
87
|
-
count_scalars = 0;
|
|
88
|
-
}
|
|
89
|
-
else {
|
|
90
|
-
a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a_scalars));
|
|
91
|
-
b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b_scalars));
|
|
92
|
-
a_scalars += 4, b_scalars += 4, count_scalars -= 4;
|
|
93
|
-
}
|
|
94
|
-
sum_f32x4 = vfmaq_f32(sum_f32x4, a_f32x4, b_f32x4);
|
|
95
|
-
if (count_scalars) goto nk_dot_f16_neonhalf_cycle;
|
|
96
|
-
*result = vaddvq_f32(sum_f32x4);
|
|
97
|
-
}
|
|
98
|
-
|
|
99
|
-
NK_PUBLIC void nk_dot_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
100
|
-
nk_f32c_t *result) {
|
|
101
|
-
float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
|
|
102
|
-
float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
|
|
103
|
-
while (count_pairs >= 4) {
|
|
104
|
-
// Unpack the input arrays into real and imaginary parts.
|
|
105
|
-
// MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed
|
|
106
|
-
// integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards.
|
|
107
|
-
int16x4x2_t a_i16x4x2 = vld2_s16((short *)a_pairs);
|
|
108
|
-
int16x4x2_t b_i16x4x2 = vld2_s16((short *)b_pairs);
|
|
109
|
-
float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
|
|
110
|
-
float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
|
|
111
|
-
float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
|
|
112
|
-
float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
|
|
113
|
-
sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
|
|
114
|
-
sum_real_f32x4 = vfmsq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
|
|
115
|
-
sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
|
|
116
|
-
sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
|
|
117
|
-
count_pairs -= 4, a_pairs += 4, b_pairs += 4;
|
|
118
|
-
}
|
|
119
|
-
// Reduce horizontal sums and aggregate with the tail:
|
|
120
|
-
nk_f32c_t tail_result;
|
|
121
|
-
nk_dot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
|
|
122
|
-
result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
|
|
123
|
-
result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
|
|
124
|
-
}
|
|
125
|
-
|
|
126
|
-
NK_PUBLIC void nk_vdot_f16c_neonhalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_size_t count_pairs,
|
|
127
|
-
nk_f32c_t *result) {
|
|
128
|
-
float32x4_t sum_real_f32x4 = vdupq_n_f32(0);
|
|
129
|
-
float32x4_t sum_imag_f32x4 = vdupq_n_f32(0);
|
|
130
|
-
while (count_pairs >= 4) {
|
|
131
|
-
// Unpack the input arrays into real and imaginary parts.
|
|
132
|
-
// MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed
|
|
133
|
-
// integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards.
|
|
134
|
-
int16x4x2_t a_i16x4x2 = vld2_s16((short *)a_pairs);
|
|
135
|
-
int16x4x2_t b_i16x4x2 = vld2_s16((short *)b_pairs);
|
|
136
|
-
float32x4_t a_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[0]));
|
|
137
|
-
float32x4_t a_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(a_i16x4x2.val[1]));
|
|
138
|
-
float32x4_t b_real_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[0]));
|
|
139
|
-
float32x4_t b_imag_f32x4 = vcvt_f32_f16(vreinterpret_f16_s16(b_i16x4x2.val[1]));
|
|
140
|
-
sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_real_f32x4, b_real_f32x4);
|
|
141
|
-
sum_real_f32x4 = vfmaq_f32(sum_real_f32x4, a_imag_f32x4, b_imag_f32x4);
|
|
142
|
-
sum_imag_f32x4 = vfmaq_f32(sum_imag_f32x4, a_real_f32x4, b_imag_f32x4);
|
|
143
|
-
sum_imag_f32x4 = vfmsq_f32(sum_imag_f32x4, a_imag_f32x4, b_real_f32x4);
|
|
144
|
-
count_pairs -= 4, a_pairs += 4, b_pairs += 4;
|
|
145
|
-
}
|
|
146
|
-
// Reduce horizontal sums and aggregate with the tail:
|
|
147
|
-
nk_f32c_t tail_result;
|
|
148
|
-
nk_vdot_f16c_serial(a_pairs, b_pairs, count_pairs, &tail_result);
|
|
149
|
-
result->real = tail_result.real + vaddvq_f32(sum_real_f32x4);
|
|
150
|
-
result->imag = tail_result.imag + vaddvq_f32(sum_imag_f32x4);
|
|
151
|
-
}
|
|
152
|
-
|
|
153
|
-
/**
|
|
154
|
-
* @brief Running state for 64-bit dot accumulation over f16 scalars on NEON with FP16 extension.
|
|
155
|
-
*
|
|
156
|
-
* Processes 4 f16 values at a time (64 bits), converting directly to f32 without
|
|
157
|
-
* the overhead of vget_low/vget_high operations on 128-bit vectors.
|
|
158
|
-
*/
|
|
159
|
-
typedef struct nk_dot_f16x4_state_neonhalf_t {
|
|
160
|
-
float32x4_t sum_f32x4;
|
|
161
|
-
} nk_dot_f16x4_state_neonhalf_t;
|
|
162
|
-
|
|
163
|
-
/**
 *  @brief Resets the f32 accumulator of an f16x4 dot-product state to all zeros.
 *  @param[out] state State whose `sum_f32x4` lanes are zeroed.
 */
NK_INTERNAL void nk_dot_f16x4_init_neonhalf(nk_dot_f16x4_state_neonhalf_t *state) {
    float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
    state->sum_f32x4 = zeros_f32x4;
}
|
|
164
|
-
|
|
165
|
-
/**
 *  @brief Folds one 64-bit chunk (4 f16 lanes) of each operand into the running f32 sums.
 *
 *  @param[in,out] state Accumulator updated with the lane-wise products of `a` and `b`.
 *  @param[in] a First operand; its `u16x4` view is reinterpreted as four f16 values.
 *  @param[in] b Second operand; its `u16x4` view is reinterpreted as four f16 values.
 *  @param[in] depth_offset Unused by this kernel.
 *  @param[in] active_dimensions Unused by this kernel.
 */
NK_INTERNAL void nk_dot_f16x4_update_neonhalf(nk_dot_f16x4_state_neonhalf_t *state, nk_b64_vec_t a, nk_b64_vec_t b,
                                              nk_size_t depth_offset, nk_size_t active_dimensions) {
    nk_unused_(depth_offset);
    nk_unused_(active_dimensions);
    // 64 bits hold exactly 4 f16s, so each operand widens to f32 in one conversion, no low/high split
    float32x4_t a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a.u16x4));
    float32x4_t b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b.u16x4));
    state->sum_f32x4 = vfmaq_f32(state->sum_f32x4, a_f32x4, b_f32x4);
}
|
|
174
|
-
|
|
175
|
-
NK_INTERNAL void nk_dot_f16x4_finalize_neonhalf( //
|
|
176
|
-
nk_dot_f16x4_state_neonhalf_t const *state_a, nk_dot_f16x4_state_neonhalf_t const *state_b, //
|
|
177
|
-
nk_dot_f16x4_state_neonhalf_t const *state_c, nk_dot_f16x4_state_neonhalf_t const *state_d, //
|
|
178
|
-
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
179
|
-
nk_unused_(total_dimensions);
|
|
180
|
-
result->f32s[0] = vaddvq_f32(state_a->sum_f32x4);
|
|
181
|
-
result->f32s[1] = vaddvq_f32(state_b->sum_f32x4);
|
|
182
|
-
result->f32s[2] = vaddvq_f32(state_c->sum_f32x4);
|
|
183
|
-
result->f32s[3] = vaddvq_f32(state_d->sum_f32x4);
|
|
184
|
-
}
|
|
185
|
-
|
|
186
|
-
#if defined(__clang__)
|
|
187
|
-
#pragma clang attribute pop
|
|
188
|
-
#elif defined(__GNUC__)
|
|
189
|
-
#pragma GCC pop_options
|
|
190
|
-
#endif
|
|
191
|
-
|
|
192
|
-
#if defined(__cplusplus)
|
|
193
|
-
} // extern "C"
|
|
194
|
-
#endif
|
|
195
|
-
|
|
196
|
-
#endif // NK_TARGET_NEONHALF
|
|
197
|
-
#endif // NK_TARGET_ARM_
|
|
198
|
-
#endif // NK_DOT_NEONHALF_H
|
|
@@ -1,57 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @brief SIMD-accelerated Batched Dot Products for NEON FP16.
|
|
3
|
-
* @file include/numkong/dots/neonhalf.h
|
|
4
|
-
* @author Ash Vardanian
|
|
5
|
-
* @date December 27, 2025
|
|
6
|
-
*
|
|
7
|
-
* @sa include/numkong/dots.h
|
|
8
|
-
*/
|
|
9
|
-
#ifndef NK_DOTS_NEONHALF_H
|
|
10
|
-
#define NK_DOTS_NEONHALF_H
|
|
11
|
-
|
|
12
|
-
#if NK_TARGET_ARM_
|
|
13
|
-
#if NK_TARGET_NEONHALF
|
|
14
|
-
|
|
15
|
-
#include "numkong/dot/neonhalf.h"
|
|
16
|
-
|
|
17
|
-
#if defined(__cplusplus)
|
|
18
|
-
extern "C" {
|
|
19
|
-
#endif
|
|
20
|
-
|
|
21
|
-
#if defined(__clang__)
|
|
22
|
-
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
23
|
-
#elif defined(__GNUC__)
|
|
24
|
-
#pragma GCC push_options
|
|
25
|
-
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
26
|
-
#endif
|
|
27
|
-
|
|
28
|
-
/* F16 GEMM: depth_simd_dimensions=4 (4 f16s = 8 bytes = 64-bit input for direct f32 conversion) */
|
|
29
|
-
nk_define_cross_pack_size_(dots, f16, neonhalf, f16, f16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/4,
|
|
30
|
-
/*dimensions_per_value=*/1)
|
|
31
|
-
nk_define_cross_pack_(dots, f16, neonhalf, f16, f16, nk_assign_from_to_, /*norm_value_type=*/f32,
|
|
32
|
-
nk_dots_reduce_sumsq_f16_, /*depth_simd_dimensions=*/4,
|
|
33
|
-
/*dimensions_per_value=*/1)
|
|
34
|
-
nk_define_cross_symmetric_(dots, f16, neonhalf, f16, f32, nk_b64_vec_t, nk_dot_f16x4_state_neonhalf_t, nk_b128_vec_t,
|
|
35
|
-
nk_dot_f16x4_init_neonhalf, nk_load_b64_neon_, nk_partial_load_b16x4_serial_,
|
|
36
|
-
nk_dot_f16x4_update_neonhalf, nk_dot_f16x4_finalize_neonhalf, nk_store_b128_neon_,
|
|
37
|
-
nk_partial_store_b32x4_serial_,
|
|
38
|
-
/*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
|
|
39
|
-
nk_define_cross_packed_(dots, f16, neonhalf, f16, f16, f32, nk_b64_vec_t, nk_dot_f16x4_state_neonhalf_t, nk_b128_vec_t,
|
|
40
|
-
nk_dot_f16x4_init_neonhalf, nk_load_b64_neon_, nk_partial_load_b16x4_serial_, nk_load_b64_neon_,
|
|
41
|
-
nk_partial_load_b16x4_serial_, nk_dot_f16x4_update_neonhalf, nk_dot_f16x4_finalize_neonhalf,
|
|
42
|
-
nk_store_b128_neon_, nk_partial_store_b32x4_serial_,
|
|
43
|
-
/*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
|
|
44
|
-
|
|
45
|
-
#if defined(__clang__)
|
|
46
|
-
#pragma clang attribute pop
|
|
47
|
-
#elif defined(__GNUC__)
|
|
48
|
-
#pragma GCC pop_options
|
|
49
|
-
#endif
|
|
50
|
-
|
|
51
|
-
#if defined(__cplusplus)
|
|
52
|
-
} // extern "C"
|
|
53
|
-
#endif
|
|
54
|
-
|
|
55
|
-
#endif // NK_TARGET_NEONHALF
|
|
56
|
-
#endif // NK_TARGET_ARM_
|
|
57
|
-
#endif // NK_DOTS_NEONHALF_H
|