numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -10,15 +10,14 @@
|
|
|
10
10
|
*
|
|
11
11
|
* ARM NEON instructions for distance computations:
|
|
12
12
|
*
|
|
13
|
-
* Intrinsic
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
20
|
-
*
|
|
21
|
-
* vrecpeq_f32 FRECPE (V.4S, V.4S) 2cy 2/cy 2/cy
|
|
13
|
+
* Intrinsic Instruction A76 M5
|
|
14
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
15
|
+
* vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
16
|
+
* vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
17
|
+
* vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
18
|
+
* vrsqrteq_f32 FRSQRTE (V.4S, V.4S) 2cy @ 2p 3cy @ 1p
|
|
19
|
+
* vsqrtq_f32 FSQRT (V.4S, V.4S) 12cy @ 1p 9cy @ 1p
|
|
20
|
+
* vrecpeq_f32 FRECPE (V.4S, V.4S) 2cy @ 2p 3cy @ 1p
|
|
22
21
|
*
|
|
23
22
|
* FRSQRTE provides ~8-bit precision; two Newton-Raphson iterations via vrsqrtsq_f32 achieve
|
|
24
23
|
* ~23-bit precision, sufficient for f32. This is much faster than FSQRT (0.25/cy).
|
|
@@ -55,10 +54,10 @@ extern "C" {
|
|
|
55
54
|
* Much faster than `vsqrtq_f32` (2 cy vs 9-12 cy latency, 2/cy vs 0.25/cy throughput).
|
|
56
55
|
*/
|
|
57
56
|
NK_INTERNAL float32x4_t nk_rsqrt_f32x4_neon_(float32x4_t x) {
    // Seed with the hardware estimate (FRSQRTE), then refine it twice.
    // Each step computes e' = e * vrsqrts(x * e, e), where vrsqrts(p, q) = (3 - p * q) / 2.
    float32x4_t estimate_f32x4 = vrsqrteq_f32(x);
    for (int step = 0; step < 2; ++step) {
        float32x4_t correction_f32x4 = vrsqrtsq_f32(vmulq_f32(x, estimate_f32x4), estimate_f32x4);
        estimate_f32x4 = vmulq_f32(estimate_f32x4, correction_f32x4);
    }
    return estimate_f32x4;
}
|
|
63
62
|
|
|
64
63
|
/**
|
|
@@ -70,29 +69,29 @@ NK_INTERNAL float32x4_t nk_rsqrt_f32x4_neon_(float32x4_t x) {
|
|
|
70
69
|
* prefer `vsqrtq_f64` instead.
|
|
71
70
|
*/
|
|
72
71
|
NK_INTERNAL float64x2_t nk_rsqrt_f64x2_neon_(float64x2_t x) {
    // Seed with the hardware estimate (FRSQRTE), then refine it three times;
    // f64 needs one more step than f32 to approach full precision.
    // Each step computes e' = e * vrsqrts(x * e, e), where vrsqrts(p, q) = (3 - p * q) / 2.
    float64x2_t estimate_f64x2 = vrsqrteq_f64(x);
    for (int step = 0; step < 3; ++step) {
        float64x2_t correction_f64x2 = vrsqrtsq_f64(vmulq_f64(x, estimate_f64x2), estimate_f64x2);
        estimate_f64x2 = vmulq_f64(estimate_f64x2, correction_f64x2);
    }
    return estimate_f64x2;
}
|
|
79
78
|
|
|
80
79
|
NK_INTERNAL nk_f32_t nk_angular_normalize_f32_neon_(nk_f32_t ab, nk_f32_t a2, nk_f32_t b2) {
    // Converts a dot product and two squared norms into `1 - ab / sqrt(a2 * b2)`, clamped below at 0.
    // Both squared norms zero: distance is defined as 0.
    if (a2 == 0 && b2 == 0) return 0;
    // Zero dot product: distance is defined as 1.
    if (ab == 0) return 1;

    // Pack both squared norms into one 2-lane vector: lane 0 = a2, lane 1 = b2.
    float32x2_t norms_f32x2 = vset_lane_f32(b2, vdup_n_f32(a2), 1);

    // Unlike x86, Arm NEON manuals don't explicitly mention the accuracy of their `rsqrt` approximation.
    // Third-party research suggests that it's less accurate than SSE instructions, having an error of 1.5×2⁻¹².
    // One or two rounds of Newton-Raphson refinement are recommended to improve the accuracy.
    // https://github.com/lighttransport/embree-aarch64/issues/24
    // https://github.com/lighttransport/embree-aarch64/blob/3f75f8cb4e553d13dced941b5fefd4c826835a6b/common/math/math.h#L137-L145
    float32x2_t inverse_roots_f32x2 = vrsqrte_f32(norms_f32x2);

    // Two Newton-Raphson rounds: https://en.wikipedia.org/wiki/Newton%27s_method
    for (int step = 0; step < 2; ++step) {
        float32x2_t correction_f32x2 = vrsqrts_f32(vmul_f32(norms_f32x2, inverse_roots_f32x2), inverse_roots_f32x2);
        inverse_roots_f32x2 = vmul_f32(inverse_roots_f32x2, correction_f32x2);
    }

    nk_f32_t const inverse_norm_a = vget_lane_f32(inverse_roots_f32x2, 0);
    nk_f32_t const inverse_norm_b = vget_lane_f32(inverse_roots_f32x2, 1);
    nk_f32_t const distance = 1 - ab * inverse_norm_a * inverse_norm_b;
    // Clamp tiny negative rounding results to 0.
    return distance > 0 ? distance : 0;
}
|
|
@@ -101,25 +100,25 @@ NK_INTERNAL nk_f64_t nk_angular_normalize_f64_neon_(nk_f64_t ab, nk_f64_t a2, nk
|
|
|
101
100
|
if (a2 == 0 && b2 == 0) return 0;
|
|
102
101
|
if (ab == 0) return 1;
|
|
103
102
|
nk_f64_t squares_arr[2] = {a2, b2};
|
|
104
|
-
float64x2_t
|
|
103
|
+
float64x2_t squares_f64x2 = vld1q_f64(squares_arr);
|
|
105
104
|
|
|
106
105
|
// Unlike x86, Arm NEON manuals don't explicitly mention the accuracy of their `rsqrt` approximation.
|
|
107
106
|
// Third-party research suggests that it's less accurate than SSE instructions, having an error of 1.5×2⁻¹².
|
|
108
107
|
// One or two rounds of Newton-Raphson refinement are recommended to improve the accuracy.
|
|
109
108
|
// https://github.com/lighttransport/embree-aarch64/issues/24
|
|
110
109
|
// https://github.com/lighttransport/embree-aarch64/blob/3f75f8cb4e553d13dced941b5fefd4c826835a6b/common/math/math.h#L137-L145
|
|
111
|
-
float64x2_t rsqrts_f64x2 = vrsqrteq_f64(
|
|
110
|
+
float64x2_t rsqrts_f64x2 = vrsqrteq_f64(squares_f64x2);
|
|
112
111
|
// Perform three rounds of Newton-Raphson refinement for f64 precision (~48 bits):
|
|
113
112
|
// https://en.wikipedia.org/wiki/Newton%27s_method
|
|
114
|
-
rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(
|
|
115
|
-
rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(
|
|
116
|
-
rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(
|
|
113
|
+
rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares_f64x2, rsqrts_f64x2), rsqrts_f64x2));
|
|
114
|
+
rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares_f64x2, rsqrts_f64x2), rsqrts_f64x2));
|
|
115
|
+
rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares_f64x2, rsqrts_f64x2), rsqrts_f64x2));
|
|
117
116
|
vst1q_f64(squares_arr, rsqrts_f64x2);
|
|
118
117
|
nk_f64_t result = 1 - ab * squares_arr[0] * squares_arr[1];
|
|
119
118
|
return result > 0 ? result : 0;
|
|
120
119
|
}
|
|
121
120
|
|
|
122
|
-
#pragma region
|
|
121
|
+
#pragma region F32 and F64 Floats
|
|
123
122
|
|
|
124
123
|
NK_PUBLIC void nk_sqeuclidean_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
125
124
|
// Accumulate in f64 for numerical stability (2 f32s per iteration, avoids slow vget_low/high)
|
|
@@ -243,8 +242,8 @@ nk_angular_f64_neon_cycle:
|
|
|
243
242
|
nk_dot_stable_sum_f64x2_neon_(ab_sum_f64x2, ab_compensation_f64x2), vaddvq_f64(a2_f64x2), vaddvq_f64(b2_f64x2));
|
|
244
243
|
}
|
|
245
244
|
|
|
246
|
-
#pragma endregion
|
|
247
|
-
#pragma region
|
|
245
|
+
#pragma endregion F32 and F64 Floats
|
|
246
|
+
#pragma region F16 and BF16 Floats
|
|
248
247
|
|
|
249
248
|
NK_PUBLIC void nk_sqeuclidean_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
250
249
|
uint16x8_t a_u16x8, b_u16x8;
|
|
@@ -264,9 +263,9 @@ nk_sqeuclidean_bf16_neon_cycle:
|
|
|
264
263
|
a += 8, b += 8, n -= 8;
|
|
265
264
|
}
|
|
266
265
|
float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
|
|
267
|
-
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(
|
|
266
|
+
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(a_u16x8, 16));
|
|
268
267
|
float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
|
|
269
|
-
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(
|
|
268
|
+
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(b_u16x8, 16));
|
|
270
269
|
float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
|
|
271
270
|
float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
|
|
272
271
|
sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
|
|
@@ -300,9 +299,9 @@ nk_angular_bf16_neon_cycle:
|
|
|
300
299
|
a += 8, b += 8, n -= 8;
|
|
301
300
|
}
|
|
302
301
|
float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
|
|
303
|
-
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(
|
|
302
|
+
float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(a_u16x8, 16));
|
|
304
303
|
float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
|
|
305
|
-
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(
|
|
304
|
+
float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(b_u16x8, 16));
|
|
306
305
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
|
|
307
306
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
|
|
308
307
|
a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
|
|
@@ -316,6 +315,80 @@ nk_angular_bf16_neon_cycle:
|
|
|
316
315
|
*result = nk_angular_normalize_f32_neon_(ab, a2, b2);
|
|
317
316
|
}
|
|
318
317
|
|
|
318
|
+
NK_PUBLIC void nk_sqeuclidean_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Squared Euclidean distance between two f16 vectors, accumulated in f32.
    float32x4_t accumulator_f32x4 = vdupq_n_f32(0);

    // Each pass consumes 8 halves. A remainder of fewer than 8 (including n == 0) is handled
    // in one final pass through `nk_partial_load_b16x8_serial_`. This assumes the unused lanes
    // are zero-filled so they add nothing to the sum (TODO confirm against that helper).
    do {
        uint16x8_t a_bits_u16x8, b_bits_u16x8;
        if (n >= 8) {
            a_bits_u16x8 = vld1q_u16((nk_u16_t const *)a);
            b_bits_u16x8 = vld1q_u16((nk_u16_t const *)b);
            a += 8, b += 8, n -= 8;
        }
        else {
            nk_b128_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b16x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b16x8_serial_(b, &b_tail_vec, n);
            a_bits_u16x8 = a_tail_vec.u16x8;
            b_bits_u16x8 = b_tail_vec.u16x8;
            n = 0;
        }

        // Reinterpret the raw bits as f16, then widen each half of the vector to f32.
        float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_bits_u16x8);
        float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_bits_u16x8);
        float32x4_t delta_low_f32x4 = vsubq_f32(vcvt_f32_f16(vget_low_f16(a_f16x8)), vcvt_f32_f16(vget_low_f16(b_f16x8)));
        float32x4_t delta_high_f32x4 = vsubq_f32(vcvt_high_f32_f16(a_f16x8), vcvt_high_f32_f16(b_f16x8));
        accumulator_f32x4 = vfmaq_f32(accumulator_f32x4, delta_low_f32x4, delta_low_f32x4);
        accumulator_f32x4 = vfmaq_f32(accumulator_f32x4, delta_high_f32x4, delta_high_f32x4);
    } while (n != 0);

    *result = vaddvq_f32(accumulator_f32x4);
}
|
|
348
|
+
|
|
349
|
+
NK_PUBLIC void nk_euclidean_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Euclidean distance: square root of the squared distance computed by the NEON f16 kernel.
    nk_f32_t squared_distance;
    nk_sqeuclidean_f16_neon(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon(squared_distance);
}
|
|
353
|
+
|
|
354
|
+
NK_PUBLIC void nk_angular_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    // Angular distance between two f16 vectors. Accumulates the dot product and both squared norms
    // in f32, then finishes with `nk_angular_normalize_f32_neon_`.
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_a_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_b_f32x4 = vdupq_n_f32(0);

    // Each pass consumes 8 halves. A remainder of fewer than 8 (including n == 0) is handled
    // in one final pass through `nk_partial_load_b16x8_serial_`. This assumes the unused lanes
    // are zero-filled so they add nothing (TODO confirm against that helper).
    do {
        uint16x8_t a_bits_u16x8, b_bits_u16x8;
        if (n >= 8) {
            a_bits_u16x8 = vld1q_u16((nk_u16_t const *)a);
            b_bits_u16x8 = vld1q_u16((nk_u16_t const *)b);
            a += 8, b += 8, n -= 8;
        }
        else {
            nk_b128_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b16x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b16x8_serial_(b, &b_tail_vec, n);
            a_bits_u16x8 = a_tail_vec.u16x8;
            b_bits_u16x8 = b_tail_vec.u16x8;
            n = 0;
        }

        // Reinterpret the raw bits as f16, then widen each half of the vector to f32.
        float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_bits_u16x8);
        float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_bits_u16x8);
        float32x4_t a_lo_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
        float32x4_t a_hi_f32x4 = vcvt_high_f32_f16(a_f16x8);
        float32x4_t b_lo_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
        float32x4_t b_hi_f32x4 = vcvt_high_f32_f16(b_f16x8);

        // Each accumulator takes the low half first, then the high half.
        dot_f32x4 = vfmaq_f32(dot_f32x4, a_lo_f32x4, b_lo_f32x4);
        dot_f32x4 = vfmaq_f32(dot_f32x4, a_hi_f32x4, b_hi_f32x4);
        norm_a_f32x4 = vfmaq_f32(norm_a_f32x4, a_lo_f32x4, a_lo_f32x4);
        norm_a_f32x4 = vfmaq_f32(norm_a_f32x4, a_hi_f32x4, a_hi_f32x4);
        norm_b_f32x4 = vfmaq_f32(norm_b_f32x4, b_lo_f32x4, b_lo_f32x4);
        norm_b_f32x4 = vfmaq_f32(norm_b_f32x4, b_hi_f32x4, b_hi_f32x4);
    } while (n != 0);

    nk_f32_t const dot = vaddvq_f32(dot_f32x4);
    nk_f32_t const norm_a_squared = vaddvq_f32(norm_a_f32x4);
    nk_f32_t const norm_b_squared = vaddvq_f32(norm_b_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, norm_a_squared, norm_b_squared);
}
|
|
391
|
+
|
|
319
392
|
NK_PUBLIC void nk_sqeuclidean_e2m3_neon(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
320
393
|
float16x8_t a_f16x8, b_f16x8;
|
|
321
394
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
@@ -334,9 +407,9 @@ nk_sqeuclidean_e2m3_neon_cycle:
|
|
|
334
407
|
a += 8, b += 8, n -= 8;
|
|
335
408
|
}
|
|
336
409
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
337
|
-
float32x4_t a_high_f32x4 =
|
|
410
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
338
411
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
339
|
-
float32x4_t b_high_f32x4 =
|
|
412
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
340
413
|
float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
|
|
341
414
|
float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
|
|
342
415
|
sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
|
|
@@ -370,9 +443,9 @@ nk_angular_e2m3_neon_cycle:
|
|
|
370
443
|
a += 8, b += 8, n -= 8;
|
|
371
444
|
}
|
|
372
445
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
373
|
-
float32x4_t a_high_f32x4 =
|
|
446
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
374
447
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
375
|
-
float32x4_t b_high_f32x4 =
|
|
448
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
376
449
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
|
|
377
450
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
|
|
378
451
|
a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
|
|
@@ -404,9 +477,9 @@ nk_sqeuclidean_e3m2_neon_cycle:
|
|
|
404
477
|
a += 8, b += 8, n -= 8;
|
|
405
478
|
}
|
|
406
479
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
407
|
-
float32x4_t a_high_f32x4 =
|
|
480
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
408
481
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
409
|
-
float32x4_t b_high_f32x4 =
|
|
482
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
410
483
|
float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
|
|
411
484
|
float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
|
|
412
485
|
sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
|
|
@@ -440,9 +513,9 @@ nk_angular_e3m2_neon_cycle:
|
|
|
440
513
|
a += 8, b += 8, n -= 8;
|
|
441
514
|
}
|
|
442
515
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
443
|
-
float32x4_t a_high_f32x4 =
|
|
516
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
444
517
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
445
|
-
float32x4_t b_high_f32x4 =
|
|
518
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
446
519
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
|
|
447
520
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
|
|
448
521
|
a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
|
|
@@ -474,9 +547,9 @@ nk_sqeuclidean_e4m3_neon_cycle:
|
|
|
474
547
|
a += 8, b += 8, n -= 8;
|
|
475
548
|
}
|
|
476
549
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
477
|
-
float32x4_t a_high_f32x4 =
|
|
550
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
478
551
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
479
|
-
float32x4_t b_high_f32x4 =
|
|
552
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
480
553
|
float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
|
|
481
554
|
float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
|
|
482
555
|
sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
|
|
@@ -510,9 +583,9 @@ nk_angular_e4m3_neon_cycle:
|
|
|
510
583
|
a += 8, b += 8, n -= 8;
|
|
511
584
|
}
|
|
512
585
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
513
|
-
float32x4_t a_high_f32x4 =
|
|
586
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
514
587
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
515
|
-
float32x4_t b_high_f32x4 =
|
|
588
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
516
589
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
|
|
517
590
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
|
|
518
591
|
a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
|
|
@@ -544,9 +617,9 @@ nk_sqeuclidean_e5m2_neon_cycle:
|
|
|
544
617
|
a += 8, b += 8, n -= 8;
|
|
545
618
|
}
|
|
546
619
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
547
|
-
float32x4_t a_high_f32x4 =
|
|
620
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
548
621
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
549
|
-
float32x4_t b_high_f32x4 =
|
|
622
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
550
623
|
float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
|
|
551
624
|
float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
|
|
552
625
|
sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
|
|
@@ -580,9 +653,9 @@ nk_angular_e5m2_neon_cycle:
|
|
|
580
653
|
a += 8, b += 8, n -= 8;
|
|
581
654
|
}
|
|
582
655
|
float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
583
|
-
float32x4_t a_high_f32x4 =
|
|
656
|
+
float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
584
657
|
float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
585
|
-
float32x4_t b_high_f32x4 =
|
|
658
|
+
float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
586
659
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
|
|
587
660
|
ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
|
|
588
661
|
a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
|
|
@@ -767,7 +840,7 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_neon_(nk_b128_vec_t dots, nk_
|
|
|
767
840
|
} // extern "C"
|
|
768
841
|
#endif
|
|
769
842
|
|
|
770
|
-
#pragma endregion
|
|
843
|
+
#pragma endregion F16 and BF16 Floats
|
|
771
844
|
#endif // NK_TARGET_NEON
|
|
772
845
|
#endif // NK_TARGET_ARM_
|
|
773
846
|
#endif // NK_SPATIAL_NEON_H
|
|
@@ -8,15 +8,14 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section spatial_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
*
|
|
13
|
-
*
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
* vaddvq_f64 FADDP (V.2D) 3cy 1/cy 2/cy
|
|
11
|
+
* Intrinsic Instruction A76 M5
|
|
12
|
+
* vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy @ 2p 2cy @ 1p
|
|
13
|
+
* vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
14
|
+
* vld1q_bf16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
|
|
15
|
+
* vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
|
|
16
|
+
* vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy @ 2p 3cy @ 4p
|
|
17
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 5cy @ 1p 8cy @ 1p
|
|
18
|
+
* vaddvq_f64 FADDP (V.2D) 3cy @ 1p 3cy @ 2p
|
|
20
19
|
*
|
|
21
20
|
* The ARMv8.6-BF16 extension provides BFDOT for accelerated dot products on BF16 data, useful for
|
|
22
21
|
* angular distance (cosine similarity) computations. BF16's larger exponent range (matching FP32)
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for NEON FP8DOT4.
|
|
3
|
+
* @file include/numkong/spatial/neonfp8.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* For L2 distance, we use the identity: (a−b)² = a² + b² − 2 × a × b,
|
|
10
|
+
* computing all three terms via FP8DOT4 without FP8 subtraction.
|
|
11
|
+
* Angular distance uses three DOT4 accumulators (a·b, ‖a‖², ‖b‖²) in parallel.
|
|
12
|
+
*/
|
|
13
|
+
#ifndef NK_SPATIAL_NEONFP8_H
|
|
14
|
+
#define NK_SPATIAL_NEONFP8_H
|
|
15
|
+
|
|
16
|
+
#if NK_TARGET_ARM_
|
|
17
|
+
#if NK_TARGET_NEONFP8
|
|
18
|
+
|
|
19
|
+
#include "numkong/types.h"
|
|
20
|
+
#include "numkong/dot/neonfp8.h" // `nk_e2m3x16_to_e4m3x16_neonfp8_`, `nk_e3m2x16_to_e5m2x16_neonfp8_`
|
|
21
|
+
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`, `nk_angular_normalize_f32_neon_`
|
|
22
|
+
|
|
23
|
+
#if defined(__cplusplus)
|
|
24
|
+
extern "C" {
|
|
25
|
+
#endif
|
|
26
|
+
|
|
27
|
+
#if defined(__clang__)
|
|
28
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8-a+simd+fp8dot4"))), apply_to = function)
|
|
29
|
+
#elif defined(__GNUC__)
|
|
30
|
+
#pragma GCC push_options
|
|
31
|
+
#pragma GCC target("arch=armv8-a+simd+fp8dot4")
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
/**
 *  @brief Squared Euclidean distance between two E4M3 vectors via FP8DOT4.
 *
 *  Expands (a−b)² = a² − 2ab + b² and accumulates the three dot products with
 *  `vdotq_f32_mf8_fpm`, 16 bytes per step. The final partial chunk (or an empty
 *  input) is read with `nk_partial_load_b8x16_serial_`.
 *
 *  @param[in]  a       First input vector of `n` E4M3 values.
 *  @param[in]  b       Second input vector of `n` E4M3 values.
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Squared distance, clamped to be non-negative (NaN propagates).
 */
NK_PUBLIC void nk_sqeuclidean_e4m3_neonfp8(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    mfloat8x16_t a_mf8x16, b_mf8x16;
    float32x4_t a2_f32x4 = vdupq_n_f32(0), ab_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
nk_sqeuclidean_e4m3_neonfp8_cycle:
    if (n < 16) {
        // Tail chunk; assumes the serial partial loader zero-fills unused lanes — TODO confirm.
        nk_b128_vec_t a_vec, b_vec;
        nk_partial_load_b8x16_serial_(a, &a_vec, n);
        nk_partial_load_b8x16_serial_(b, &b_vec, n);
        a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
        b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
        n = 0;
    }
    else {
        a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a));
        b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b));
        a += 16, b += 16, n -= 16;
    }
    a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E4M3_);
    ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
    b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E4M3_);
    if (n) goto nk_sqeuclidean_e4m3_neonfp8_cycle;
    // The expanded identity cancels when `a` and `b` are close, so rounding in the three
    // independent sums can leave a tiny negative value. A squared distance must be >= 0.
    // `x < 0` is false for NaN, so NaN inputs still propagate.
    nk_f32_t const distance_sq = vaddvq_f32(a2_f32x4) - 2 * vaddvq_f32(ab_f32x4) + vaddvq_f32(b2_f32x4);
    *result = distance_sq < 0 ? 0.0f : distance_sq;
}
|
|
57
|
+
|
|
58
|
+
/**
 *  @brief Euclidean distance between two E4M3 vectors: the square root of
 *         `nk_sqeuclidean_e4m3_neonfp8`.
 *
 *  @param[in]  a       First input vector of `n` E4M3 values.
 *  @param[in]  b       Second input vector of `n` E4M3 values.
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Non-negative distance (NaN propagates from NaN inputs).
 */
NK_PUBLIC void nk_euclidean_e4m3_neonfp8(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t distance_sq;
    nk_sqeuclidean_e4m3_neonfp8(a, b, n, &distance_sq);
    // The squared distance is built from a²−2ab+b² and may round slightly below zero.
    // Clamp before the square root so the result is 0 rather than NaN.
    *result = nk_f32_sqrt_neon(distance_sq < 0 ? 0.0f : distance_sq);
}
|
|
62
|
+
|
|
63
|
+
/**
 *  @brief Angular (cosine) distance between two E4M3 vectors via FP8DOT4.
 *
 *  Keeps three running FP8DOT4 accumulators (a·b, ‖a‖², ‖b‖²), reduces each horizontally,
 *  and passes the three sums to `nk_angular_normalize_f32_neon_`.
 */
NK_PUBLIC void nk_angular_e4m3_neonfp8(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_a_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_b_f32x4 = vdupq_n_f32(0);
    // Runs at least once: an empty input still takes the partial-load path.
    do {
        mfloat8x16_t lhs_mf8x16, rhs_mf8x16;
        if (n >= 16) {
            lhs_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a));
            rhs_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b));
            a += 16, b += 16, n -= 16;
        }
        else {
            nk_b128_vec_t lhs_vec, rhs_vec;
            nk_partial_load_b8x16_serial_(a, &lhs_vec, n);
            nk_partial_load_b8x16_serial_(b, &rhs_vec, n);
            lhs_mf8x16 = vreinterpretq_mf8_u8(lhs_vec.u8x16);
            rhs_mf8x16 = vreinterpretq_mf8_u8(rhs_vec.u8x16);
            n = 0;
        }
        dot_f32x4 = vdotq_f32_mf8_fpm(dot_f32x4, lhs_mf8x16, rhs_mf8x16, NK_FPM_E4M3_);
        norm_a_f32x4 = vdotq_f32_mf8_fpm(norm_a_f32x4, lhs_mf8x16, lhs_mf8x16, NK_FPM_E4M3_);
        norm_b_f32x4 = vdotq_f32_mf8_fpm(norm_b_f32x4, rhs_mf8x16, rhs_mf8x16, NK_FPM_E4M3_);
    } while (n);
    float32_t const dot = vaddvq_f32(dot_f32x4);
    float32_t const norm_a_sq = vaddvq_f32(norm_a_f32x4);
    float32_t const norm_b_sq = vaddvq_f32(norm_b_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, norm_a_sq, norm_b_sq);
}
|
|
86
|
+
|
|
87
|
+
/**
 *  @brief Squared Euclidean distance between two E5M2 vectors via FP8DOT4.
 *
 *  Expands (a−b)² = a² − 2ab + b² and accumulates the three dot products with
 *  `vdotq_f32_mf8_fpm`, 16 bytes per step. The final partial chunk (or an empty
 *  input) is read with `nk_partial_load_b8x16_serial_`.
 *
 *  @param[in]  a       First input vector of `n` E5M2 values.
 *  @param[in]  b       Second input vector of `n` E5M2 values.
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Squared distance, clamped to be non-negative (NaN propagates).
 */
NK_PUBLIC void nk_sqeuclidean_e5m2_neonfp8(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    mfloat8x16_t a_mf8x16, b_mf8x16;
    float32x4_t a2_f32x4 = vdupq_n_f32(0), ab_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
nk_sqeuclidean_e5m2_neonfp8_cycle:
    if (n < 16) {
        // Tail chunk; assumes the serial partial loader zero-fills unused lanes — TODO confirm.
        nk_b128_vec_t a_vec, b_vec;
        nk_partial_load_b8x16_serial_(a, &a_vec, n);
        nk_partial_load_b8x16_serial_(b, &b_vec, n);
        a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
        b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
        n = 0;
    }
    else {
        a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a));
        b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b));
        a += 16, b += 16, n -= 16;
    }
    a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E5M2_);
    ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
    b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E5M2_);
    if (n) goto nk_sqeuclidean_e5m2_neonfp8_cycle;
    // The expanded identity cancels when `a` and `b` are close, so rounding in the three
    // independent sums can leave a tiny negative value. A squared distance must be >= 0.
    // `x < 0` is false for NaN, so NaN inputs still propagate.
    nk_f32_t const distance_sq = vaddvq_f32(a2_f32x4) - 2 * vaddvq_f32(ab_f32x4) + vaddvq_f32(b2_f32x4);
    *result = distance_sq < 0 ? 0.0f : distance_sq;
}
|
|
110
|
+
|
|
111
|
+
/**
 *  @brief Euclidean distance between two E5M2 vectors: the square root of
 *         `nk_sqeuclidean_e5m2_neonfp8`.
 *
 *  @param[in]  a       First input vector of `n` E5M2 values.
 *  @param[in]  b       Second input vector of `n` E5M2 values.
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Non-negative distance (NaN propagates from NaN inputs).
 */
NK_PUBLIC void nk_euclidean_e5m2_neonfp8(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t distance_sq;
    nk_sqeuclidean_e5m2_neonfp8(a, b, n, &distance_sq);
    // The squared distance is built from a²−2ab+b² and may round slightly below zero.
    // Clamp before the square root so the result is 0 rather than NaN.
    *result = nk_f32_sqrt_neon(distance_sq < 0 ? 0.0f : distance_sq);
}
|
|
115
|
+
|
|
116
|
+
/**
 *  @brief Angular (cosine) distance between two E5M2 vectors via FP8DOT4.
 *
 *  Keeps three running FP8DOT4 accumulators (a·b, ‖a‖², ‖b‖²), reduces each horizontally,
 *  and passes the three sums to `nk_angular_normalize_f32_neon_`.
 */
NK_PUBLIC void nk_angular_e5m2_neonfp8(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_a_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_b_f32x4 = vdupq_n_f32(0);
    // Runs at least once: an empty input still takes the partial-load path.
    do {
        mfloat8x16_t lhs_mf8x16, rhs_mf8x16;
        if (n >= 16) {
            lhs_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a));
            rhs_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b));
            a += 16, b += 16, n -= 16;
        }
        else {
            nk_b128_vec_t lhs_vec, rhs_vec;
            nk_partial_load_b8x16_serial_(a, &lhs_vec, n);
            nk_partial_load_b8x16_serial_(b, &rhs_vec, n);
            lhs_mf8x16 = vreinterpretq_mf8_u8(lhs_vec.u8x16);
            rhs_mf8x16 = vreinterpretq_mf8_u8(rhs_vec.u8x16);
            n = 0;
        }
        dot_f32x4 = vdotq_f32_mf8_fpm(dot_f32x4, lhs_mf8x16, rhs_mf8x16, NK_FPM_E5M2_);
        norm_a_f32x4 = vdotq_f32_mf8_fpm(norm_a_f32x4, lhs_mf8x16, lhs_mf8x16, NK_FPM_E5M2_);
        norm_b_f32x4 = vdotq_f32_mf8_fpm(norm_b_f32x4, rhs_mf8x16, rhs_mf8x16, NK_FPM_E5M2_);
    } while (n);
    float32_t const dot = vaddvq_f32(dot_f32x4);
    float32_t const norm_a_sq = vaddvq_f32(norm_a_f32x4);
    float32_t const norm_b_sq = vaddvq_f32(norm_b_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, norm_a_sq, norm_b_sq);
}
|
|
139
|
+
|
|
140
|
+
/**
 *  @brief Squared Euclidean distance between two E2M3 vectors via FP8DOT4.
 *
 *  Each 16-byte chunk is widened to E4M3 with `nk_e2m3x16_to_e4m3x16_neonfp8_`. The kernel
 *  then expands (a−b)² = a² − 2ab + b² and accumulates the three dot products with
 *  `vdotq_f32_mf8_fpm` in E4M3 mode. The final partial chunk (or an empty input) is read
 *  with `nk_partial_load_b8x16_serial_`.
 *
 *  @param[in]  a       First input vector of `n` E2M3 values (one per byte).
 *  @param[in]  b       Second input vector of `n` E2M3 values (one per byte).
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Squared distance, clamped to be non-negative (NaN propagates).
 */
NK_PUBLIC void nk_sqeuclidean_e2m3_neonfp8(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    mfloat8x16_t a_mf8x16, b_mf8x16;
    float32x4_t a2_f32x4 = vdupq_n_f32(0), ab_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
nk_sqeuclidean_e2m3_neonfp8_cycle:
    if (n < 16) {
        // Tail chunk; assumes the serial partial loader zero-fills unused lanes — TODO confirm.
        nk_b128_vec_t a_vec, b_vec;
        nk_partial_load_b8x16_serial_(a, &a_vec, n);
        nk_partial_load_b8x16_serial_(b, &b_vec, n);
        a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(a_vec.u8x16));
        b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(b_vec.u8x16));
        n = 0;
    }
    else {
        a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)a)));
        b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)b)));
        a += 16, b += 16, n -= 16;
    }
    a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E4M3_);
    ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
    b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E4M3_);
    if (n) goto nk_sqeuclidean_e2m3_neonfp8_cycle;
    // The expanded identity cancels when `a` and `b` are close, so rounding in the three
    // independent sums can leave a tiny negative value. A squared distance must be >= 0.
    // `x < 0` is false for NaN, so NaN inputs still propagate.
    nk_f32_t const distance_sq = vaddvq_f32(a2_f32x4) - 2 * vaddvq_f32(ab_f32x4) + vaddvq_f32(b2_f32x4);
    *result = distance_sq < 0 ? 0.0f : distance_sq;
}
|
|
163
|
+
|
|
164
|
+
/**
 *  @brief Euclidean distance between two E2M3 vectors: the square root of
 *         `nk_sqeuclidean_e2m3_neonfp8`.
 *
 *  @param[in]  a       First input vector of `n` E2M3 values.
 *  @param[in]  b       Second input vector of `n` E2M3 values.
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Non-negative distance (NaN propagates from NaN inputs).
 */
NK_PUBLIC void nk_euclidean_e2m3_neonfp8(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t distance_sq;
    nk_sqeuclidean_e2m3_neonfp8(a, b, n, &distance_sq);
    // The squared distance is built from a²−2ab+b² and may round slightly below zero.
    // Clamp before the square root so the result is 0 rather than NaN.
    *result = nk_f32_sqrt_neon(distance_sq < 0 ? 0.0f : distance_sq);
}
|
|
168
|
+
|
|
169
|
+
/**
 *  @brief Angular (cosine) distance between two E2M3 vectors via FP8DOT4.
 *
 *  Each chunk is widened to E4M3 with `nk_e2m3x16_to_e4m3x16_neonfp8_`. Three running
 *  FP8DOT4 accumulators (a·b, ‖a‖², ‖b‖²) are kept in E4M3 mode, reduced horizontally,
 *  and passed to `nk_angular_normalize_f32_neon_`.
 */
NK_PUBLIC void nk_angular_e2m3_neonfp8(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_a_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_b_f32x4 = vdupq_n_f32(0);
    // Runs at least once: an empty input still takes the partial-load path.
    do {
        mfloat8x16_t lhs_mf8x16, rhs_mf8x16;
        if (n >= 16) {
            lhs_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)a)));
            rhs_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)b)));
            a += 16, b += 16, n -= 16;
        }
        else {
            nk_b128_vec_t lhs_vec, rhs_vec;
            nk_partial_load_b8x16_serial_(a, &lhs_vec, n);
            nk_partial_load_b8x16_serial_(b, &rhs_vec, n);
            lhs_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(lhs_vec.u8x16));
            rhs_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(rhs_vec.u8x16));
            n = 0;
        }
        dot_f32x4 = vdotq_f32_mf8_fpm(dot_f32x4, lhs_mf8x16, rhs_mf8x16, NK_FPM_E4M3_);
        norm_a_f32x4 = vdotq_f32_mf8_fpm(norm_a_f32x4, lhs_mf8x16, lhs_mf8x16, NK_FPM_E4M3_);
        norm_b_f32x4 = vdotq_f32_mf8_fpm(norm_b_f32x4, rhs_mf8x16, rhs_mf8x16, NK_FPM_E4M3_);
    } while (n);
    float32_t const dot = vaddvq_f32(dot_f32x4);
    float32_t const norm_a_sq = vaddvq_f32(norm_a_f32x4);
    float32_t const norm_b_sq = vaddvq_f32(norm_b_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, norm_a_sq, norm_b_sq);
}
|
|
192
|
+
|
|
193
|
+
/**
 *  @brief Squared Euclidean distance between two E3M2 vectors via FP8DOT4.
 *
 *  Each 16-byte chunk is widened to E5M2 with `nk_e3m2x16_to_e5m2x16_neonfp8_`. The kernel
 *  then expands (a−b)² = a² − 2ab + b² and accumulates the three dot products with
 *  `vdotq_f32_mf8_fpm` in E5M2 mode. The final partial chunk (or an empty input) is read
 *  with `nk_partial_load_b8x16_serial_`.
 *
 *  @param[in]  a       First input vector of `n` E3M2 values (one per byte).
 *  @param[in]  b       Second input vector of `n` E3M2 values (one per byte).
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Squared distance, clamped to be non-negative (NaN propagates).
 */
NK_PUBLIC void nk_sqeuclidean_e3m2_neonfp8(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    mfloat8x16_t a_mf8x16, b_mf8x16;
    float32x4_t a2_f32x4 = vdupq_n_f32(0), ab_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
nk_sqeuclidean_e3m2_neonfp8_cycle:
    if (n < 16) {
        // Tail chunk; assumes the serial partial loader zero-fills unused lanes — TODO confirm.
        nk_b128_vec_t a_vec, b_vec;
        nk_partial_load_b8x16_serial_(a, &a_vec, n);
        nk_partial_load_b8x16_serial_(b, &b_vec, n);
        a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(a_vec.u8x16));
        b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(b_vec.u8x16));
        n = 0;
    }
    else {
        a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)a)));
        b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)b)));
        a += 16, b += 16, n -= 16;
    }
    a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E5M2_);
    ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
    b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E5M2_);
    if (n) goto nk_sqeuclidean_e3m2_neonfp8_cycle;
    // The expanded identity cancels when `a` and `b` are close, so rounding in the three
    // independent sums can leave a tiny negative value. A squared distance must be >= 0.
    // `x < 0` is false for NaN, so NaN inputs still propagate.
    nk_f32_t const distance_sq = vaddvq_f32(a2_f32x4) - 2 * vaddvq_f32(ab_f32x4) + vaddvq_f32(b2_f32x4);
    *result = distance_sq < 0 ? 0.0f : distance_sq;
}
|
|
216
|
+
|
|
217
|
+
/**
 *  @brief Euclidean distance between two E3M2 vectors: the square root of
 *         `nk_sqeuclidean_e3m2_neonfp8`.
 *
 *  @param[in]  a       First input vector of `n` E3M2 values.
 *  @param[in]  b       Second input vector of `n` E3M2 values.
 *  @param[in]  n       Number of elements in each vector.
 *  @param[out] result  Non-negative distance (NaN propagates from NaN inputs).
 */
NK_PUBLIC void nk_euclidean_e3m2_neonfp8(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t distance_sq;
    nk_sqeuclidean_e3m2_neonfp8(a, b, n, &distance_sq);
    // The squared distance is built from a²−2ab+b² and may round slightly below zero.
    // Clamp before the square root so the result is 0 rather than NaN.
    *result = nk_f32_sqrt_neon(distance_sq < 0 ? 0.0f : distance_sq);
}
|
|
221
|
+
|
|
222
|
+
/**
 *  @brief Angular (cosine) distance between two E3M2 vectors via FP8DOT4.
 *
 *  Each chunk is widened to E5M2 with `nk_e3m2x16_to_e5m2x16_neonfp8_`. Three running
 *  FP8DOT4 accumulators (a·b, ‖a‖², ‖b‖²) are kept in E5M2 mode, reduced horizontally,
 *  and passed to `nk_angular_normalize_f32_neon_`.
 */
NK_PUBLIC void nk_angular_e3m2_neonfp8(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_a_f32x4 = vdupq_n_f32(0);
    float32x4_t norm_b_f32x4 = vdupq_n_f32(0);
    // Runs at least once: an empty input still takes the partial-load path.
    do {
        mfloat8x16_t lhs_mf8x16, rhs_mf8x16;
        if (n >= 16) {
            lhs_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)a)));
            rhs_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)b)));
            a += 16, b += 16, n -= 16;
        }
        else {
            nk_b128_vec_t lhs_vec, rhs_vec;
            nk_partial_load_b8x16_serial_(a, &lhs_vec, n);
            nk_partial_load_b8x16_serial_(b, &rhs_vec, n);
            lhs_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(lhs_vec.u8x16));
            rhs_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(rhs_vec.u8x16));
            n = 0;
        }
        dot_f32x4 = vdotq_f32_mf8_fpm(dot_f32x4, lhs_mf8x16, rhs_mf8x16, NK_FPM_E5M2_);
        norm_a_f32x4 = vdotq_f32_mf8_fpm(norm_a_f32x4, lhs_mf8x16, lhs_mf8x16, NK_FPM_E5M2_);
        norm_b_f32x4 = vdotq_f32_mf8_fpm(norm_b_f32x4, rhs_mf8x16, rhs_mf8x16, NK_FPM_E5M2_);
    } while (n);
    float32_t const dot = vaddvq_f32(dot_f32x4);
    float32_t const norm_a_sq = vaddvq_f32(norm_a_f32x4);
    float32_t const norm_b_sq = vaddvq_f32(norm_b_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, norm_a_sq, norm_b_sq);
}
|
|
245
|
+
|
|
246
|
+
#if defined(__clang__)
|
|
247
|
+
#pragma clang attribute pop
|
|
248
|
+
#elif defined(__GNUC__)
|
|
249
|
+
#pragma GCC pop_options
|
|
250
|
+
#endif
|
|
251
|
+
|
|
252
|
+
#if defined(__cplusplus)
|
|
253
|
+
} // extern "C"
|
|
254
|
+
#endif
|
|
255
|
+
|
|
256
|
+
#endif // NK_TARGET_NEONFP8
|
|
257
|
+
#endif // NK_TARGET_ARM_
|
|
258
|
+
#endif // NK_SPATIAL_NEONFP8_H
|