numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for SVE.
|
|
3
|
+
* @file include/numkong/spatial/sve.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_sve_instructions ARM SVE Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* svld1_f32 LD1W (Z.S, P/Z, [Xn]) 4-6cy 2/cy
|
|
13
|
+
* svsub_f32_x FSUB (Z.S, P/M, Z.S, Z.S) 3cy 2/cy
|
|
14
|
+
* svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy 2/cy
|
|
15
|
+
* svaddv_f32 FADDV (S, P, Z.S) 6cy 1/cy
|
|
16
|
+
* svdupq_n_f32 DUP (Z.S, #imm) 1cy 2/cy
|
|
17
|
+
* svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy 1/cy
|
|
18
|
+
* svptrue_b32 PTRUE (P.S, pattern) 1cy 2/cy
|
|
19
|
+
* svcntw CNTW (Xd) 1cy 2/cy
|
|
20
|
+
* svld1_f64 LD1D (Z.D, P/Z, [Xn]) 4-6cy 2/cy
|
|
21
|
+
* svsub_f64_x FSUB (Z.D, P/M, Z.D, Z.D) 3cy 2/cy
|
|
22
|
+
* svmla_f64_x FMLA (Z.D, P/M, Z.D, Z.D) 4cy 2/cy
|
|
23
|
+
* svaddv_f64 FADDV (D, P, Z.D) 6cy 1/cy
|
|
24
|
+
*
|
|
25
|
+
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
26
|
+
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
27
|
+
* process more elements per iteration with identical latencies.
|
|
28
|
+
*
|
|
29
|
+
* Spatial operations like L2 distance and angular similarity benefit from SVE's fused
|
|
30
|
+
* multiply-add instructions. The FADDV reduction dominates the critical path for short vectors.
|
|
31
|
+
*/
|
|
32
|
+
#ifndef NK_SPATIAL_SVE_H
|
|
33
|
+
#define NK_SPATIAL_SVE_H
|
|
34
|
+
|
|
35
|
+
#if NK_TARGET_ARM_
|
|
36
|
+
#if NK_TARGET_SVE
|
|
37
|
+
|
|
38
|
+
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/spatial/neon.h" // `nk_f64_sqrt_neon`
|
|
40
|
+
#include "numkong/dot/sve.h" // `nk_dot_stable_sum_f64_sve_`
|
|
41
|
+
|
|
42
|
+
#if defined(__cplusplus)
|
|
43
|
+
extern "C" {
|
|
44
|
+
#endif
|
|
45
|
+
|
|
46
|
+
#if defined(__clang__)
|
|
47
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve"))), apply_to = function)
|
|
48
|
+
#elif defined(__GNUC__)
|
|
49
|
+
#pragma GCC push_options
|
|
50
|
+
#pragma GCC target("arch=armv8.2-a+sve")
|
|
51
|
+
#endif
|
|
52
|
+
|
|
53
|
+
/** @brief Reciprocal square root of an f32 SVE vector via estimate + 2 Newton-Raphson steps.
|
|
54
|
+
*
|
|
55
|
+
* Computes 1/sqrt(x) for each active lane. The initial estimate from `svrsqrte_f32`
|
|
56
|
+
* has ~8 bits of precision; each Newton-Raphson iteration via `svrsqrts_f32` roughly
|
|
57
|
+
* doubles the mantissa bits, giving ~23 bits (~full f32 precision) after 2 steps.
|
|
58
|
+
*
|
|
59
|
+
* Marked `__arm_streaming_compatible` so the helper is callable from both streaming
|
|
60
|
+
* (SME) and non-streaming (SVE) contexts without mode transitions.
|
|
61
|
+
*
|
|
62
|
+
* @param predicate Active-lane mask
|
|
63
|
+
* @param x Input vector (must be positive for meaningful results)
|
|
64
|
+
* @return Approximate 1/sqrt(x) with ~23-bit mantissa accuracy
|
|
65
|
+
*/
|
|
66
|
+
NK_INTERNAL svfloat32_t nk_rsqrt_f32x_sve_(svbool_t predicate, svfloat32_t x) NK_STREAMING_COMPATIBLE_ {
|
|
67
|
+
svfloat32_t r = svrsqrte_f32(x);
|
|
68
|
+
r = svmul_f32_x(predicate, r, svrsqrts_f32(svmul_f32_x(predicate, x, r), r));
|
|
69
|
+
r = svmul_f32_x(predicate, r, svrsqrts_f32(svmul_f32_x(predicate, x, r), r));
|
|
70
|
+
return r;
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
/** @brief Reciprocal square root of an f64 SVE vector via estimate + 3 Newton-Raphson steps.
|
|
74
|
+
*
|
|
75
|
+
* Computes 1/sqrt(x) for each active lane. The initial estimate from `svrsqrte_f64`
|
|
76
|
+
* has ~8 bits of precision; three Newton-Raphson iterations via `svrsqrts_f64` yield
|
|
77
|
+
* ~52-bit mantissa accuracy (full f64 precision).
|
|
78
|
+
*
|
|
79
|
+
* Marked `__arm_streaming_compatible` so the helper is callable from both streaming
|
|
80
|
+
* (SME) and non-streaming (SVE) contexts without mode transitions.
|
|
81
|
+
*
|
|
82
|
+
* @param predicate Active-lane mask
|
|
83
|
+
* @param x Input vector (must be positive for meaningful results)
|
|
84
|
+
* @return Approximate 1/sqrt(x) with ~52-bit mantissa accuracy
|
|
85
|
+
*/
|
|
86
|
+
NK_INTERNAL svfloat64_t nk_rsqrt_f64x_sve_(svbool_t predicate, svfloat64_t x) NK_STREAMING_COMPATIBLE_ {
|
|
87
|
+
svfloat64_t r = svrsqrte_f64(x);
|
|
88
|
+
r = svmul_f64_x(predicate, r, svrsqrts_f64(svmul_f64_x(predicate, x, r), r));
|
|
89
|
+
r = svmul_f64_x(predicate, r, svrsqrts_f64(svmul_f64_x(predicate, x, r), r));
|
|
90
|
+
r = svmul_f64_x(predicate, r, svrsqrts_f64(svmul_f64_x(predicate, x, r), r));
|
|
91
|
+
return r;
|
|
92
|
+
}
|
|
93
|
+
|
|
94
|
+
/**
 * @brief Squared Euclidean distance between two f32 vectors, accumulated in f64.
 *
 * Each iteration loads one full f32 vector and widens its lower and upper halves to f64.
 * SVE `FCVT Zd.D, Pg/M, Zn.S` converts only the f32 value held in the low half of each
 * 64-bit container, i.e. the even-numbered f32 lanes. Each half is therefore first
 * interleaved with itself (`svzip1` / `svzip2`) so that contiguous elements land in those
 * even lanes. Converting the raw load directly would read only a[i], a[i+2], ...
 *
 * @param a First input vector of @p n f32 values.
 * @param b Second input vector of @p n f32 values.
 * @param n Number of elements in each vector.
 * @param result Output: sum over i of (a[i] - b[i])^2, as f64.
 */
NK_PUBLIC void nk_sqeuclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t const f32_lanes = svcntw(); // elements consumed per iteration
    nk_size_t const f64_lanes = svcntd(); // elements per widened half
    svfloat64_t dist_sq_low_f64x = svdup_n_f64(0.0);
    svfloat64_t dist_sq_high_f64x = svdup_n_f64(0.0);
    for (nk_size_t i = 0; i < n; i += f32_lanes) {
        svbool_t predicate_f32x = svwhilelt_b32_u64(i, n);
        svbool_t predicate_low_f64x = svwhilelt_b64_u64(i, n);
        svbool_t predicate_high_f64x = svwhilelt_b64_u64(i + f64_lanes, n);
        svfloat32_t a_f32x = svld1_f32(predicate_f32x, a + i);
        svfloat32_t b_f32x = svld1_f32(predicate_f32x, b + i);
        // zip1 moves elements [i, i + f64_lanes) into even lanes, zip2 does the same for the upper half.
        svfloat64_t a_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, svzip1_f32(a_f32x, a_f32x));
        svfloat64_t b_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, svzip1_f32(b_f32x, b_f32x));
        svfloat64_t a_high_f64x = svcvt_f64_f32_x(predicate_high_f64x, svzip2_f32(a_f32x, a_f32x));
        svfloat64_t b_high_f64x = svcvt_f64_f32_x(predicate_high_f64x, svzip2_f32(b_f32x, b_f32x));
        svfloat64_t diff_low_f64x = svsub_f64_x(predicate_low_f64x, a_low_f64x, b_low_f64x);
        svfloat64_t diff_high_f64x = svsub_f64_x(predicate_high_f64x, a_high_f64x, b_high_f64x);
        // `_m` keeps inactive accumulator lanes intact; `_x` would leave them undefined before the reduction.
        dist_sq_low_f64x = svmla_f64_m(predicate_low_f64x, dist_sq_low_f64x, diff_low_f64x, diff_low_f64x);
        dist_sq_high_f64x = svmla_f64_m(predicate_high_f64x, dist_sq_high_f64x, diff_high_f64x, diff_high_f64x);
    }
    svbool_t const predicate_all_f64x = svptrue_b64();
    *result = svaddv_f64(predicate_all_f64x, svadd_f64_x(predicate_all_f64x, dist_sq_low_f64x, dist_sq_high_f64x));
}
|
|
109
|
+
|
|
110
|
+
/**
 * @brief Euclidean (L2) distance between two f32 vectors, computed in f64.
 *
 * Delegates to `nk_sqeuclidean_f32_sve` and takes the square root of its result.
 */
NK_PUBLIC void nk_euclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t squared_distance;
    nk_sqeuclidean_f32_sve(a, b, n, &squared_distance);
    *result = nk_f64_sqrt_neon(squared_distance);
}
|
|
114
|
+
|
|
115
|
+
/**
 * @brief Angular distance between two f32 vectors, with dot products accumulated in f64.
 *
 * Accumulates a.b, a.a and b.b in f64, then hands the three sums to
 * `nk_angular_normalize_f64_neon_`. SVE `FCVT Zd.D, Pg/M, Zn.S` widens only the
 * even-numbered f32 lanes. Each half of the loaded f32 vector is therefore interleaved
 * with itself (`svzip1` / `svzip2`) before conversion, so that every element is used
 * exactly once.
 *
 * @param a First input vector of @p n f32 values.
 * @param b Second input vector of @p n f32 values.
 * @param n Number of elements in each vector.
 * @param result Output: normalized angular distance, as f64.
 */
NK_PUBLIC void nk_angular_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t const f32_lanes = svcntw(); // elements consumed per iteration
    nk_size_t const f64_lanes = svcntd(); // elements per widened half
    svfloat64_t ab_f64x = svdup_n_f64(0.0);
    svfloat64_t a2_f64x = svdup_n_f64(0.0);
    svfloat64_t b2_f64x = svdup_n_f64(0.0);
    for (nk_size_t i = 0; i < n; i += f32_lanes) {
        svbool_t predicate_f32x = svwhilelt_b32_u64(i, n);
        svbool_t predicate_low_f64x = svwhilelt_b64_u64(i, n);
        svbool_t predicate_high_f64x = svwhilelt_b64_u64(i + f64_lanes, n);
        svfloat32_t a_f32x = svld1_f32(predicate_f32x, a + i);
        svfloat32_t b_f32x = svld1_f32(predicate_f32x, b + i);
        // zip1 moves elements [i, i + f64_lanes) into even lanes, zip2 does the same for the upper half.
        svfloat64_t a_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, svzip1_f32(a_f32x, a_f32x));
        svfloat64_t b_low_f64x = svcvt_f64_f32_x(predicate_low_f64x, svzip1_f32(b_f32x, b_f32x));
        svfloat64_t a_high_f64x = svcvt_f64_f32_x(predicate_high_f64x, svzip2_f32(a_f32x, a_f32x));
        svfloat64_t b_high_f64x = svcvt_f64_f32_x(predicate_high_f64x, svzip2_f32(b_f32x, b_f32x));
        // `_m` keeps inactive accumulator lanes intact; `_x` would leave them undefined before the reduction.
        ab_f64x = svmla_f64_m(predicate_low_f64x, ab_f64x, a_low_f64x, b_low_f64x);
        ab_f64x = svmla_f64_m(predicate_high_f64x, ab_f64x, a_high_f64x, b_high_f64x);
        a2_f64x = svmla_f64_m(predicate_low_f64x, a2_f64x, a_low_f64x, a_low_f64x);
        a2_f64x = svmla_f64_m(predicate_high_f64x, a2_f64x, a_high_f64x, a_high_f64x);
        b2_f64x = svmla_f64_m(predicate_low_f64x, b2_f64x, b_low_f64x, b_low_f64x);
        b2_f64x = svmla_f64_m(predicate_high_f64x, b2_f64x, b_high_f64x, b_high_f64x);
    }

    nk_f64_t ab_f64 = svaddv_f64(svptrue_b64(), ab_f64x);
    nk_f64_t a2_f64 = svaddv_f64(svptrue_b64(), a2_f64x);
    nk_f64_t b2_f64 = svaddv_f64(svptrue_b64(), b2_f64x);
    *result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
}
|
|
136
|
+
|
|
137
|
+
/**
 * @brief Squared Euclidean distance between two f64 vectors using Neumaier compensated summation.
 *
 * Per lane, keeps a running sum and a compensation term. The two are combined at the end
 * by `nk_dot_stable_sum_f64_sve_`. Accumulator updates use `_m` predication so that lanes
 * past the tail keep their previous (well-defined) values. With `_x` they would be
 * undefined and still enter the all-lanes reduction.
 *
 * @param a First input vector of @p n f64 values.
 * @param b Second input vector of @p n f64 values.
 * @param n Number of elements in each vector.
 * @param result Output: sum over i of (a[i] - b[i])^2.
 */
NK_PUBLIC void nk_sqeuclidean_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t i = 0;
    svfloat64_t sum_f64x = svdup_n_f64(0.0);
    svfloat64_t compensation_f64x = svdup_n_f64(0.0);
    svbool_t const predicate_all_f64x = svptrue_b64();
    do {
        svbool_t predicate_f64x = svwhilelt_b64_u64(i, n);
        svfloat64_t a_f64x = svld1_f64(predicate_f64x, a + i);
        svfloat64_t b_f64x = svld1_f64(predicate_f64x, b + i);
        svfloat64_t diff_f64x = svsub_f64_x(predicate_f64x, a_f64x, b_f64x);
        svfloat64_t diff_sq_f64x = svmul_f64_x(predicate_f64x, diff_f64x, diff_f64x);
        // Neumaier: t = sum + x, with inactive lanes keeping `sum`.
        svfloat64_t t_f64x = svadd_f64_m(predicate_f64x, sum_f64x, diff_sq_f64x);
        svfloat64_t abs_sum_f64x = svabs_f64_x(predicate_f64x, sum_f64x);
        // diff_sq is already non-negative (it is a square), so it needs no svabs.
        svbool_t sum_ge_x_f64x = svcmpge_f64(predicate_f64x, abs_sum_f64x, diff_sq_f64x);
        // When |sum| >= |x|: comp += (sum - t) + x; otherwise: comp += (x - t) + sum.
        svfloat64_t comp_sum_large_f64x = svadd_f64_x(predicate_f64x, svsub_f64_x(predicate_f64x, sum_f64x, t_f64x),
                                                      diff_sq_f64x);
        svfloat64_t comp_x_large_f64x = svadd_f64_x(predicate_f64x, svsub_f64_x(predicate_f64x, diff_sq_f64x, t_f64x),
                                                    sum_f64x);
        svfloat64_t comp_update_f64x = svsel_f64(sum_ge_x_f64x, comp_sum_large_f64x, comp_x_large_f64x);
        compensation_f64x = svadd_f64_m(predicate_f64x, compensation_f64x, comp_update_f64x);
        sum_f64x = t_f64x;
        i += svcntd();
    } while (i < n);
    *result = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, sum_f64x, compensation_f64x);
}
|
|
166
|
+
|
|
167
|
+
/**
 * @brief Euclidean (L2) distance between two f64 vectors.
 *
 * Delegates to `nk_sqeuclidean_f64_sve` and takes the square root of its result.
 */
NK_PUBLIC void nk_euclidean_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_f64_t squared_distance;
    nk_sqeuclidean_f64_sve(a, b, n, &squared_distance);
    *result = nk_f64_sqrt_neon(squared_distance);
}
|
|
171
|
+
|
|
172
|
+
/**
 * @brief Angular distance between two f64 vectors.
 *
 * The cross product a.b, which may suffer cancellation, uses Dot2 (Ogita-Rump-Oishi):
 * - TwoProd captures the exact rounding error of each product.
 * - TwoSum captures the rounding error of each addition.
 * - Both errors accumulate into a per-lane compensation term.
 *
 * The self products a.a and b.b are sums of non-negative terms and use plain FMA.
 * Accumulator updates use `_m` predication so that tail lanes stay well-defined for the
 * all-lanes reductions.
 *
 * @param a First input vector of @p n f64 values.
 * @param b Second input vector of @p n f64 values.
 * @param n Number of elements in each vector.
 * @param result Output: normalized angular distance.
 */
NK_PUBLIC void nk_angular_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
    nk_size_t i = 0;
    svfloat64_t ab_sum_f64x = svdup_n_f64(0.0);
    svfloat64_t ab_compensation_f64x = svdup_n_f64(0.0);
    svfloat64_t a2_f64x = svdup_n_f64(0.0);
    svfloat64_t b2_f64x = svdup_n_f64(0.0);
    svbool_t const predicate_all_f64x = svptrue_b64();
    do {
        svbool_t predicate_f64x = svwhilelt_b64_u64(i, n);
        svfloat64_t a_f64x = svld1_f64(predicate_f64x, a + i);
        svfloat64_t b_f64x = svld1_f64(predicate_f64x, b + i);
        // TwoProd: product = fl(a*b); error = fma(a, b, -product) = a*b - product (exact).
        // `svnmls(op1, op2, op3)` (FNMLS) computes op2*op3 - op1 with one rounding, which is already
        // the error term. Negating it, as before, subtracted the error instead of compensating it.
        svfloat64_t product_f64x = svmul_f64_x(predicate_f64x, a_f64x, b_f64x);
        svfloat64_t product_error_f64x = svnmls_f64_x(predicate_f64x, product_f64x, a_f64x, b_f64x);
        // TwoSum: (tentative_sum, sum_error) = TwoSum(sum, product); inactive lanes keep `sum`.
        svfloat64_t tentative_sum_f64x = svadd_f64_m(predicate_f64x, ab_sum_f64x, product_f64x);
        svfloat64_t virtual_addend_f64x = svsub_f64_x(predicate_f64x, tentative_sum_f64x, ab_sum_f64x);
        svfloat64_t sum_error_f64x = svadd_f64_x(
            predicate_f64x,
            svsub_f64_x(predicate_f64x, ab_sum_f64x,
                        svsub_f64_x(predicate_f64x, tentative_sum_f64x, virtual_addend_f64x)),
            svsub_f64_x(predicate_f64x, product_f64x, virtual_addend_f64x));
        ab_sum_f64x = tentative_sum_f64x;
        ab_compensation_f64x = svadd_f64_m(predicate_f64x, ab_compensation_f64x,
                                           svadd_f64_x(predicate_f64x, sum_error_f64x, product_error_f64x));
        // Simple FMA for self-products (no cancellation).
        a2_f64x = svmla_f64_m(predicate_f64x, a2_f64x, a_f64x, a_f64x);
        b2_f64x = svmla_f64_m(predicate_f64x, b2_f64x, b_f64x, b_f64x);
        i += svcntd();
    } while (i < n);

    nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(predicate_all_f64x, ab_sum_f64x, ab_compensation_f64x);
    nk_f64_t a2_f64 = svaddv_f64(predicate_all_f64x, a2_f64x);
    nk_f64_t b2_f64 = svaddv_f64(predicate_all_f64x, b2_f64x);
    *result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
}
|
|
211
|
+
|
|
212
|
+
#if defined(__clang__)
|
|
213
|
+
#pragma clang attribute pop
|
|
214
|
+
#elif defined(__GNUC__)
|
|
215
|
+
#pragma GCC pop_options
|
|
216
|
+
#endif
|
|
217
|
+
|
|
218
|
+
#if defined(__cplusplus)
|
|
219
|
+
} // extern "C"
|
|
220
|
+
#endif
|
|
221
|
+
|
|
222
|
+
#endif // NK_TARGET_SVE
|
|
223
|
+
#endif // NK_TARGET_ARM_
|
|
224
|
+
#endif // NK_SPATIAL_SVE_H
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for SVE BF16.
|
|
3
|
+
* @file include/numkong/spatial/svebfdot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_svebfdot_instructions ARM SVE+BF16 Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* svld1_bf16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
|
|
13
|
+
* svld1_u16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
|
|
14
|
+
* svbfdot_f32 BFDOT (Z.S, Z.H, Z.H) 4cy 2/cy
|
|
15
|
+
* svmla_f32_x FMLA (Z.S, P/M, Z.S, Z.S) 4cy 2/cy
|
|
16
|
+
* svsub_f32_x FSUB (Z.S, P/M, Z.S, Z.S) 3cy 2/cy
|
|
17
|
+
* svaddv_f32 FADDV (S, P, Z.S) 6cy 1/cy
|
|
18
|
+
* svunpklo_u32 UUNPKLO (Z.S, Z.H) 2cy 2/cy
|
|
19
|
+
* svunpkhi_u32 UUNPKHI (Z.S, Z.H) 2cy 2/cy
|
|
20
|
+
* svlsl_n_u32_x LSL (Z.S, P/M, Z.S, #imm) 2cy 2/cy
|
|
21
|
+
* svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
|
|
22
|
+
* svwhilelt_b32 WHILELT (P.S, Xn, Xm) 2cy 1/cy
|
|
23
|
+
* svcnth CNTH (Xd) 1cy 2/cy
|
|
24
|
+
*
|
|
25
|
+
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
26
|
+
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
27
|
+
* process more elements per iteration with identical latencies.
|
|
28
|
+
*
|
|
29
|
+
* The BFDOT instruction fuses two BF16 multiplications with FP32 accumulation, providing
|
|
30
|
+
* efficient BF16 dot products without explicit conversion overhead.
|
|
31
|
+
*/
|
|
32
|
+
#ifndef NK_SPATIAL_SVEBFDOT_H
|
|
33
|
+
#define NK_SPATIAL_SVEBFDOT_H
|
|
34
|
+
|
|
35
|
+
#if NK_TARGET_ARM_
|
|
36
|
+
#if NK_TARGET_SVEBFDOT
|
|
37
|
+
|
|
38
|
+
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
40
|
+
|
|
41
|
+
#if defined(__cplusplus)
|
|
42
|
+
extern "C" {
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
#if defined(__clang__)
|
|
46
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+bf16"))), apply_to = function)
|
|
47
|
+
#elif defined(__GNUC__)
|
|
48
|
+
#pragma GCC push_options
|
|
49
|
+
#pragma GCC target("arch=armv8.2-a+sve+bf16")
|
|
50
|
+
#endif
|
|
51
|
+
|
|
52
|
+
NK_PUBLIC void nk_sqeuclidean_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t const *b_enum, nk_size_t n,
|
|
53
|
+
nk_f32_t *result) {
|
|
54
|
+
nk_size_t i = 0;
|
|
55
|
+
svfloat32_t d2_low_f32x = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
|
|
56
|
+
svfloat32_t d2_high_f32x = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
|
|
57
|
+
nk_u16_t const *a = (nk_u16_t const *)(a_enum);
|
|
58
|
+
nk_u16_t const *b = (nk_u16_t const *)(b_enum);
|
|
59
|
+
do {
|
|
60
|
+
svbool_t predicate_bf16x = svwhilelt_b16_u64(i, n);
|
|
61
|
+
svuint16_t a_u16x = svld1_u16(predicate_bf16x, a + i);
|
|
62
|
+
svuint16_t b_u16x = svld1_u16(predicate_bf16x, b + i);
|
|
63
|
+
|
|
64
|
+
// There is no `bf16` subtraction in SVE, so we need to convert to `u32` and shift.
|
|
65
|
+
svbool_t predicate_low_f32x = svwhilelt_b32_u64(i, n);
|
|
66
|
+
svbool_t predicate_high_f32x = svwhilelt_b32_u64(i + svcnth() / 2, n);
|
|
67
|
+
svfloat32_t a_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_low_f32x, svunpklo_u32(a_u16x), 16));
|
|
68
|
+
svfloat32_t a_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_high_f32x, svunpkhi_u32(a_u16x), 16));
|
|
69
|
+
svfloat32_t b_low_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_low_f32x, svunpklo_u32(b_u16x), 16));
|
|
70
|
+
svfloat32_t b_high_f32x = svreinterpret_f32_u32(svlsl_n_u32_x(predicate_high_f32x, svunpkhi_u32(b_u16x), 16));
|
|
71
|
+
|
|
72
|
+
svfloat32_t a_minus_b_low_f32x = svsub_f32_x(predicate_low_f32x, a_low_f32x, b_low_f32x);
|
|
73
|
+
svfloat32_t a_minus_b_high_f32x = svsub_f32_x(predicate_high_f32x, a_high_f32x, b_high_f32x);
|
|
74
|
+
d2_low_f32x = svmla_f32_x(predicate_bf16x, d2_low_f32x, a_minus_b_low_f32x, a_minus_b_low_f32x);
|
|
75
|
+
d2_high_f32x = svmla_f32_x(predicate_bf16x, d2_high_f32x, a_minus_b_high_f32x, a_minus_b_high_f32x);
|
|
76
|
+
i += svcnth();
|
|
77
|
+
} while (i < n);
|
|
78
|
+
nk_f32_t d2 = svaddv_f32(svptrue_b32(), d2_low_f32x) + svaddv_f32(svptrue_b32(), d2_high_f32x);
|
|
79
|
+
*result = d2;
|
|
80
|
+
}
|
|
81
|
+
/**
 * @brief Euclidean (L2) distance between two bf16 vectors, computed in f32.
 *
 * Delegates to `nk_sqeuclidean_bf16_svebfdot` and takes the square root of its result.
 */
NK_PUBLIC void nk_euclidean_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_bf16_svebfdot(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon(squared_distance);
}
|
|
85
|
+
|
|
86
|
+
NK_PUBLIC void nk_angular_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t const *b_enum, nk_size_t n,
|
|
87
|
+
nk_f32_t *result) {
|
|
88
|
+
nk_size_t i = 0;
|
|
89
|
+
svfloat32_t ab_f32x = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
|
|
90
|
+
svfloat32_t a2_f32x = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
|
|
91
|
+
svfloat32_t b2_f32x = svdupq_n_f32(0.f, 0.f, 0.f, 0.f);
|
|
92
|
+
nk_bf16_for_arm_simd_t const *a = (nk_bf16_for_arm_simd_t const *)(a_enum);
|
|
93
|
+
nk_bf16_for_arm_simd_t const *b = (nk_bf16_for_arm_simd_t const *)(b_enum);
|
|
94
|
+
do {
|
|
95
|
+
svbool_t predicate_bf16x = svwhilelt_b16_u64(i, n);
|
|
96
|
+
svbfloat16_t a_bf16x = svld1_bf16(predicate_bf16x, a + i);
|
|
97
|
+
svbfloat16_t b_bf16x = svld1_bf16(predicate_bf16x, b + i);
|
|
98
|
+
ab_f32x = svbfdot_f32(ab_f32x, a_bf16x, b_bf16x);
|
|
99
|
+
a2_f32x = svbfdot_f32(a2_f32x, a_bf16x, a_bf16x);
|
|
100
|
+
b2_f32x = svbfdot_f32(b2_f32x, b_bf16x, b_bf16x);
|
|
101
|
+
i += svcnth();
|
|
102
|
+
} while (i < n);
|
|
103
|
+
|
|
104
|
+
nk_f32_t ab = svaddv_f32(svptrue_b32(), ab_f32x);
|
|
105
|
+
nk_f32_t a2 = svaddv_f32(svptrue_b32(), a2_f32x);
|
|
106
|
+
nk_f32_t b2 = svaddv_f32(svptrue_b32(), b2_f32x);
|
|
107
|
+
*result = nk_angular_normalize_f32_neon_(ab, a2, b2);
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
#if defined(__clang__)
|
|
111
|
+
#pragma clang attribute pop
|
|
112
|
+
#elif defined(__GNUC__)
|
|
113
|
+
#pragma GCC pop_options
|
|
114
|
+
#endif
|
|
115
|
+
|
|
116
|
+
#if defined(__cplusplus)
|
|
117
|
+
} // extern "C"
|
|
118
|
+
#endif
|
|
119
|
+
|
|
120
|
+
#endif // NK_TARGET_SVEBFDOT
|
|
121
|
+
#endif // NK_TARGET_ARM_
|
|
122
|
+
#endif // NK_SPATIAL_SVEBFDOT_H
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Spatial Similarity Measures for SVE FP16.
|
|
3
|
+
* @file include/numkong/spatial/svehalf.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatial.h
|
|
8
|
+
*
|
|
9
|
+
* @section spatial_svehalf_instructions ARM SVE+FP16 Instructions
|
|
10
|
+
*
|
|
11
|
+
* Intrinsic Instruction Latency Throughput
|
|
12
|
+
* svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
|
|
13
|
+
* svsub_f16_x FSUB (Z.H, P/M, Z.H, Z.H) 3cy 2/cy
|
|
14
|
+
* svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy 2/cy
|
|
15
|
+
* svaddv_f16 FADDV (H, P, Z.H) 6cy 1/cy
|
|
16
|
+
* svdupq_n_f16 DUP (Z.H, #imm) 1cy 2/cy
|
|
17
|
+
* svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
|
|
18
|
+
* svptrue_b16 PTRUE (P.H, pattern) 1cy 2/cy
|
|
19
|
+
* svcnth CNTH (Xd) 1cy 2/cy
|
|
20
|
+
*
|
|
21
|
+
* SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
|
|
22
|
+
* and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
|
|
23
|
+
* process more elements per iteration with identical latencies.
|
|
24
|
+
*
|
|
25
|
+
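* FP16 inputs halve memory traffic relative to FP32, but the kernels below do their arithmetic
* in FP32: values are widened with `svcvt_f32_f16_x` (FCVT) and then use the `.S` forms of
* FSUB/FMLA/FADDV rather than the `.H` forms listed above, favoring accumulation accuracy over
* raw FP16 arithmetic throughput. This is well suited to embedding similarity in ML applications.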
|
|
27
|
+
*/
|
|
28
|
+
#ifndef NK_SPATIAL_SVEHALF_H
|
|
29
|
+
#define NK_SPATIAL_SVEHALF_H
|
|
30
|
+
|
|
31
|
+
#if NK_TARGET_ARM_
|
|
32
|
+
#if NK_TARGET_SVEHALF
|
|
33
|
+
|
|
34
|
+
#include "numkong/types.h"
|
|
35
|
+
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
36
|
+
|
|
37
|
+
#if defined(__cplusplus)
|
|
38
|
+
extern "C" {
|
|
39
|
+
#endif
|
|
40
|
+
|
|
41
|
+
#if defined(__clang__)
|
|
42
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+fp16"))), apply_to = function)
|
|
43
|
+
#elif defined(__GNUC__)
|
|
44
|
+
#pragma GCC push_options
|
|
45
|
+
#pragma GCC target("arch=armv8.2-a+sve+fp16")
|
|
46
|
+
#endif
|
|
47
|
+
|
|
48
|
+
NK_PUBLIC void nk_sqeuclidean_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const *b_enum, nk_size_t n,
|
|
49
|
+
nk_f32_t *result) {
|
|
50
|
+
nk_size_t i = 0;
|
|
51
|
+
svfloat32_t d2_f32x = svdup_n_f32(0.0f);
|
|
52
|
+
nk_f16_for_arm_simd_t const *a = (nk_f16_for_arm_simd_t const *)(a_enum);
|
|
53
|
+
nk_f16_for_arm_simd_t const *b = (nk_f16_for_arm_simd_t const *)(b_enum);
|
|
54
|
+
do {
|
|
55
|
+
svbool_t predicate_f32x = svwhilelt_b32_u64(i, n);
|
|
56
|
+
svfloat16_t a_f16x = svld1_f16(svwhilelt_b16_u64(i, n), a + i);
|
|
57
|
+
svfloat16_t b_f16x = svld1_f16(svwhilelt_b16_u64(i, n), b + i);
|
|
58
|
+
svfloat32_t a_f32x = svcvt_f32_f16_x(predicate_f32x, a_f16x);
|
|
59
|
+
svfloat32_t b_f32x = svcvt_f32_f16_x(predicate_f32x, b_f16x);
|
|
60
|
+
svfloat32_t diff_f32x = svsub_f32_x(predicate_f32x, a_f32x, b_f32x);
|
|
61
|
+
d2_f32x = svmla_f32_x(predicate_f32x, d2_f32x, diff_f32x, diff_f32x);
|
|
62
|
+
i += svcntw();
|
|
63
|
+
} while (i < n);
|
|
64
|
+
*result = svaddv_f32(svptrue_b32(), d2_f32x);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
/**
 * @brief Euclidean (L2) distance between two FP16 vectors.
 *
 * Computes the squared distance with `nk_sqeuclidean_f16_svehalf` into a local,
 * then stores its square root (via `nk_f32_sqrt_neon`) into @p result.
 *
 * @param[in] a First input vector of @p n FP16 values.
 * @param[in] b Second input vector of @p n FP16 values.
 * @param[in] n Number of elements in each vector.
 * @param[out] result Receives the Euclidean distance.
 */
NK_PUBLIC void nk_euclidean_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance = 0.0f;
    nk_sqeuclidean_f16_svehalf(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon(squared_distance);
}
|
|
71
|
+
|
|
72
|
+
/**
 * @brief Angular distance between two FP16 vectors, accumulated in FP32.
 *
 * Collects a·b, a·a and b·b in FP32 and hands the horizontal sums to
 * `nk_angular_normalize_f32_neon_`.
 *
 * Each iteration loads one full FP16 vector (`svcnth()` elements) and widens it into two FP32
 * halves. `svcvt_f32_f16` (FCVT) only reads the even `.H` lanes (the low 16 bits of each 32-bit
 * container), so each element is first duplicated into a lane pair with `ZIP1`/`ZIP2`.
 * Converting the loaded vector directly would drop the odd-indexed elements.
 *
 * @param[in] a_enum First input vector of @p n FP16 values.
 * @param[in] b_enum Second input vector of @p n FP16 values.
 * @param[in] n Number of elements in each vector.
 * @param[out] result Receives the normalized angular distance.
 */
NK_PUBLIC void nk_angular_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const *b_enum, nk_size_t n, nk_f32_t *result) {
    nk_size_t i = 0;
    svfloat32_t ab_low_f32x = svdup_n_f32(0.0f);
    svfloat32_t ab_high_f32x = svdup_n_f32(0.0f);
    svfloat32_t a2_low_f32x = svdup_n_f32(0.0f);
    svfloat32_t a2_high_f32x = svdup_n_f32(0.0f);
    svfloat32_t b2_low_f32x = svdup_n_f32(0.0f);
    svfloat32_t b2_high_f32x = svdup_n_f32(0.0f);
    nk_f16_for_arm_simd_t const *a = (nk_f16_for_arm_simd_t const *)(a_enum);
    nk_f16_for_arm_simd_t const *b = (nk_f16_for_arm_simd_t const *)(b_enum);
    do {
        svbool_t predicate_f16x = svwhilelt_b16_u64(i, n);
        // Lanes [i, i + cntw) land in the "low" FP32 vector, [i + cntw, i + cnth) in the "high" one.
        svbool_t predicate_low_f32x = svwhilelt_b32_u64(i, n);
        svbool_t predicate_high_f32x = svwhilelt_b32_u64(i + svcntw(), n);
        svfloat16_t a_f16x = svld1_f16(predicate_f16x, a + i);
        svfloat16_t b_f16x = svld1_f16(predicate_f16x, b + i);
        // Duplicate each FP16 element into a lane pair so FCVT's even-lane read sees all of them.
        svfloat32_t a_low_f32x = svcvt_f32_f16_x(predicate_low_f32x, svzip1_f16(a_f16x, a_f16x));
        svfloat32_t a_high_f32x = svcvt_f32_f16_x(predicate_high_f32x, svzip2_f16(a_f16x, a_f16x));
        svfloat32_t b_low_f32x = svcvt_f32_f16_x(predicate_low_f32x, svzip1_f16(b_f16x, b_f16x));
        svfloat32_t b_high_f32x = svcvt_f32_f16_x(predicate_high_f32x, svzip2_f16(b_f16x, b_f16x));
        // `_m` keeps inactive accumulator lanes unchanged; `_x` would leave them unspecified.
        ab_low_f32x = svmla_f32_m(predicate_low_f32x, ab_low_f32x, a_low_f32x, b_low_f32x);
        ab_high_f32x = svmla_f32_m(predicate_high_f32x, ab_high_f32x, a_high_f32x, b_high_f32x);
        a2_low_f32x = svmla_f32_m(predicate_low_f32x, a2_low_f32x, a_low_f32x, a_low_f32x);
        a2_high_f32x = svmla_f32_m(predicate_high_f32x, a2_high_f32x, a_high_f32x, a_high_f32x);
        b2_low_f32x = svmla_f32_m(predicate_low_f32x, b2_low_f32x, b_low_f32x, b_low_f32x);
        b2_high_f32x = svmla_f32_m(predicate_high_f32x, b2_high_f32x, b_high_f32x, b_high_f32x);
        i += svcnth();
    } while (i < n);

    svbool_t all_f32x = svptrue_b32();
    nk_f32_t ab_f32 = svaddv_f32(all_f32x, svadd_f32_x(all_f32x, ab_low_f32x, ab_high_f32x));
    nk_f32_t a2_f32 = svaddv_f32(all_f32x, svadd_f32_x(all_f32x, a2_low_f32x, a2_high_f32x));
    nk_f32_t b2_f32 = svaddv_f32(all_f32x, svadd_f32_x(all_f32x, b2_low_f32x, b2_high_f32x));
    *result = nk_angular_normalize_f32_neon_(ab_f32, a2_f32, b2_f32);
}
|
|
96
|
+
|
|
97
|
+
#if defined(__clang__)
|
|
98
|
+
#pragma clang attribute pop
|
|
99
|
+
#elif defined(__GNUC__)
|
|
100
|
+
#pragma GCC pop_options
|
|
101
|
+
#endif
|
|
102
|
+
|
|
103
|
+
#if defined(__cplusplus)
|
|
104
|
+
} // extern "C"
|
|
105
|
+
#endif
|
|
106
|
+
|
|
107
|
+
#endif // NK_TARGET_SVEHALF
|
|
108
|
+
#endif // NK_TARGET_ARM_
|
|
109
|
+
#endif // NK_SPATIAL_SVEHALF_H
|