numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -1,157 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @brief NEON FP16 implementations for the redesigned reduction API (moments + minmax).
|
|
3
|
-
* @file include/numkong/reduce/neonhalf.h
|
|
4
|
-
* @author Ash Vardanian
|
|
5
|
-
* @date February 13, 2026
|
|
6
|
-
*
|
|
7
|
-
* @sa include/numkong/reduce.h
|
|
8
|
-
*
|
|
9
|
-
* @section reduce_neonhalf_new_design Design Notes
|
|
10
|
-
*
|
|
11
|
-
* Moments (sum + sum-of-squares) accumulate in f32 via vcvt_f32_f16 widening, giving
|
|
12
|
-
* full f32 precision. The contiguous path processes 8 f16 elements per iteration, widening
|
|
13
|
-
* to two f32x4 halves and using vfmaq_f32 for fused multiply-accumulate of squares.
|
|
14
|
-
*
|
|
15
|
-
* Minmax tracks min/max values as native f16x8 with u16x8 iteration counters (same width
|
|
16
|
-
* as f16). The u16 counters wrap at 65536, so the dispatcher splits arrays larger than
|
|
17
|
-
* 65536 * 8 = 524288 elements via recursive halving.
|
|
18
|
-
*/
|
|
19
|
-
#ifndef NK_REDUCE_NEONHALF_H
|
|
20
|
-
#define NK_REDUCE_NEONHALF_H
|
|
21
|
-
|
|
22
|
-
#if NK_TARGET_ARM_
|
|
23
|
-
#if NK_TARGET_NEONHALF
|
|
24
|
-
|
|
25
|
-
#include "numkong/types.h"
|
|
26
|
-
#include "numkong/cast/neon.h"
|
|
27
|
-
#include "numkong/cast/serial.h"
|
|
28
|
-
#include "numkong/reduce/serial.h"
|
|
29
|
-
|
|
30
|
-
#if defined(__cplusplus)
|
|
31
|
-
extern "C" {
|
|
32
|
-
#endif
|
|
33
|
-
|
|
34
|
-
#if defined(__clang__)
|
|
35
|
-
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
36
|
-
#elif defined(__GNUC__)
|
|
37
|
-
#pragma GCC push_options
|
|
38
|
-
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
39
|
-
#endif
|
|
40
|
-
|
|
41
|
-
NK_INTERNAL void nk_reduce_moments_f16_neonhalf_contiguous_( //
|
|
42
|
-
nk_f16_t const *data_ptr, nk_size_t count, //
|
|
43
|
-
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
44
|
-
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
45
|
-
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
46
|
-
nk_size_t idx = 0;
|
|
47
|
-
|
|
48
|
-
for (; idx + 8 <= count; idx += 8) {
|
|
49
|
-
float16x8_t data_f16x8 = vld1q_f16((nk_f16_for_arm_simd_t const *)(data_ptr + idx));
|
|
50
|
-
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
|
|
51
|
-
float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
|
|
52
|
-
sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
|
|
53
|
-
sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
|
|
54
|
-
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
|
|
55
|
-
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
|
|
56
|
-
}
|
|
57
|
-
|
|
58
|
-
// Scalar tail
|
|
59
|
-
nk_f32_t sum = vaddvq_f32(sum_f32x4);
|
|
60
|
-
nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
|
|
61
|
-
for (; idx < count; ++idx) {
|
|
62
|
-
nk_f32_t value_f32;
|
|
63
|
-
nk_f16_to_f32_serial(data_ptr + idx, &value_f32);
|
|
64
|
-
sum += value_f32, sumsq += value_f32 * value_f32;
|
|
65
|
-
}
|
|
66
|
-
*sum_ptr = sum, *sumsq_ptr = sumsq;
|
|
67
|
-
}
|
|
68
|
-
|
|
69
|
-
NK_INTERNAL void nk_reduce_moments_f16_neonhalf_strided_( //
|
|
70
|
-
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
71
|
-
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
72
|
-
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
73
|
-
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
74
|
-
nk_size_t idx = 0;
|
|
75
|
-
|
|
76
|
-
if (stride_elements == 2) {
|
|
77
|
-
for (; idx + 8 <= count; idx += 8) {
|
|
78
|
-
uint16x8x2_t loaded_u16x8x2 = vld2q_u16((uint16_t const *)(data_ptr + idx * 2));
|
|
79
|
-
float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x2.val[0]);
|
|
80
|
-
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
|
|
81
|
-
float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
|
|
82
|
-
sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
|
|
83
|
-
sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
|
|
84
|
-
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
|
|
85
|
-
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
|
|
86
|
-
}
|
|
87
|
-
}
|
|
88
|
-
else if (stride_elements == 3) {
|
|
89
|
-
for (; idx + 8 <= count; idx += 8) {
|
|
90
|
-
uint16x8x3_t loaded_u16x8x3 = vld3q_u16((uint16_t const *)(data_ptr + idx * 3));
|
|
91
|
-
float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x3.val[0]);
|
|
92
|
-
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
|
|
93
|
-
float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
|
|
94
|
-
sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
|
|
95
|
-
sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
|
|
96
|
-
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
|
|
97
|
-
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
|
|
98
|
-
}
|
|
99
|
-
}
|
|
100
|
-
else if (stride_elements == 4) {
|
|
101
|
-
for (; idx + 8 <= count; idx += 8) {
|
|
102
|
-
uint16x8x4_t loaded_u16x8x4 = vld4q_u16((uint16_t const *)(data_ptr + idx * 4));
|
|
103
|
-
float16x8_t data_f16x8 = vreinterpretq_f16_u16(loaded_u16x8x4.val[0]);
|
|
104
|
-
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(data_f16x8));
|
|
105
|
-
float32x4_t high_f32x4 = vcvt_f32_f16(vget_high_f16(data_f16x8));
|
|
106
|
-
sum_f32x4 = vaddq_f32(sum_f32x4, low_f32x4);
|
|
107
|
-
sum_f32x4 = vaddq_f32(sum_f32x4, high_f32x4);
|
|
108
|
-
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, low_f32x4, low_f32x4);
|
|
109
|
-
sumsq_f32x4 = vfmaq_f32(sumsq_f32x4, high_f32x4, high_f32x4);
|
|
110
|
-
}
|
|
111
|
-
}
|
|
112
|
-
|
|
113
|
-
// Scalar tail for remaining elements
|
|
114
|
-
nk_f32_t sum = vaddvq_f32(sum_f32x4);
|
|
115
|
-
nk_f32_t sumsq = vaddvq_f32(sumsq_f32x4);
|
|
116
|
-
for (; idx < count; ++idx) {
|
|
117
|
-
nk_f32_t value_f32;
|
|
118
|
-
nk_f16_to_f32_serial((nk_f16_t const *)(data_ptr + idx * stride_elements), &value_f32);
|
|
119
|
-
sum += value_f32, sumsq += value_f32 * value_f32;
|
|
120
|
-
}
|
|
121
|
-
*sum_ptr = sum, *sumsq_ptr = sumsq;
|
|
122
|
-
}
|
|
123
|
-
|
|
124
|
-
NK_PUBLIC void nk_reduce_moments_f16_neonhalf( //
|
|
125
|
-
nk_f16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
126
|
-
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
127
|
-
nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
|
|
128
|
-
int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
|
|
129
|
-
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
130
|
-
else if (!aligned) nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
131
|
-
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
|
|
132
|
-
nk_size_t left_count = count / 2;
|
|
133
|
-
nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
|
|
134
|
-
nk_reduce_moments_f16_neonhalf(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
|
|
135
|
-
nk_reduce_moments_f16_neonhalf(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
136
|
-
&right_sum_value, &right_sumsq_value);
|
|
137
|
-
*sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
|
|
138
|
-
}
|
|
139
|
-
else if (stride_elements == 1) nk_reduce_moments_f16_neonhalf_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
140
|
-
else if (stride_elements <= 4)
|
|
141
|
-
nk_reduce_moments_f16_neonhalf_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
|
|
142
|
-
else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
143
|
-
}
|
|
144
|
-
|
|
145
|
-
#if defined(__clang__)
|
|
146
|
-
#pragma clang attribute pop
|
|
147
|
-
#elif defined(__GNUC__)
|
|
148
|
-
#pragma GCC pop_options
|
|
149
|
-
#endif
|
|
150
|
-
|
|
151
|
-
#if defined(__cplusplus)
|
|
152
|
-
} // extern "C"
|
|
153
|
-
#endif
|
|
154
|
-
|
|
155
|
-
#endif // NK_TARGET_NEONHALF
|
|
156
|
-
#endif // NK_TARGET_ARM_
|
|
157
|
-
#endif // NK_REDUCE_NEONHALF_H
|
|
@@ -1,118 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @brief SIMD-accelerated Spatial Similarity Measures for NEON FP16.
|
|
3
|
-
* @file include/numkong/spatial/neonhalf.h
|
|
4
|
-
* @author Ash Vardanian
|
|
5
|
-
* @date December 27, 2025
|
|
6
|
-
*
|
|
7
|
-
* @sa include/numkong/spatial.h
|
|
8
|
-
*
|
|
9
|
-
* @section spatial_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
|
|
10
|
-
*
|
|
11
|
-
* Intrinsic Instruction Latency Throughput
|
|
12
|
-
* A76 M4+/V1+/Oryon
|
|
13
|
-
* vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
|
|
14
|
-
* vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
|
|
15
|
-
* vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
|
|
16
|
-
* vsubq_f16 FSUB (V.8H, V.8H, V.8H) 2cy 2/cy 4/cy
|
|
17
|
-
* vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
|
|
18
|
-
*
|
|
19
|
-
* The ARMv8.2-FP16 extension enables native half-precision arithmetic, doubling the element count
|
|
20
|
-
* per vector register (8x F16 vs 4x F32). For spatial distance computations like L2 and angular
|
|
21
|
-
* distance, this halves memory bandwidth requirements.
|
|
22
|
-
*
|
|
23
|
-
* Inputs are widened from F16 to F32 for accumulation via FCVTL to preserve numerical precision
|
|
24
|
-
* during the squared difference summation. The subtraction and FMA operations use F32 precision
|
|
25
|
-
* in the accumulator to avoid catastrophic cancellation in distance computations.
|
|
26
|
-
*/
|
|
27
|
-
#ifndef NK_SPATIAL_NEONHALF_H
|
|
28
|
-
#define NK_SPATIAL_NEONHALF_H
|
|
29
|
-
|
|
30
|
-
#if NK_TARGET_ARM_
|
|
31
|
-
#if NK_TARGET_NEONHALF
|
|
32
|
-
|
|
33
|
-
#include "numkong/types.h"
|
|
34
|
-
#include "numkong/cast/serial.h" // `nk_partial_load_b16x4_serial_`
|
|
35
|
-
#include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
|
|
36
|
-
|
|
37
|
-
#if defined(__cplusplus)
|
|
38
|
-
extern "C" {
|
|
39
|
-
#endif
|
|
40
|
-
|
|
41
|
-
#if defined(__clang__)
|
|
42
|
-
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
43
|
-
#elif defined(__GNUC__)
|
|
44
|
-
#pragma GCC push_options
|
|
45
|
-
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
46
|
-
#endif
|
|
47
|
-
|
|
48
|
-
NK_PUBLIC void nk_sqeuclidean_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
49
|
-
float32x4_t a_f32x4, b_f32x4;
|
|
50
|
-
float32x4_t distance_sq_f32x4 = vdupq_n_f32(0);
|
|
51
|
-
|
|
52
|
-
nk_sqeuclidean_f16_neonhalf_cycle:
|
|
53
|
-
if (n < 4) {
|
|
54
|
-
nk_b64_vec_t a_vec, b_vec;
|
|
55
|
-
nk_partial_load_b16x4_serial_(a, &a_vec, n);
|
|
56
|
-
nk_partial_load_b16x4_serial_(b, &b_vec, n);
|
|
57
|
-
a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
|
|
58
|
-
b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
|
|
59
|
-
n = 0;
|
|
60
|
-
}
|
|
61
|
-
else {
|
|
62
|
-
a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
|
|
63
|
-
b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
|
|
64
|
-
n -= 4, a += 4, b += 4;
|
|
65
|
-
}
|
|
66
|
-
float32x4_t diff_f32x4 = vsubq_f32(a_f32x4, b_f32x4);
|
|
67
|
-
distance_sq_f32x4 = vfmaq_f32(distance_sq_f32x4, diff_f32x4, diff_f32x4);
|
|
68
|
-
if (n) goto nk_sqeuclidean_f16_neonhalf_cycle;
|
|
69
|
-
|
|
70
|
-
*result = vaddvq_f32(distance_sq_f32x4);
|
|
71
|
-
}
|
|
72
|
-
NK_PUBLIC void nk_euclidean_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
73
|
-
nk_sqeuclidean_f16_neonhalf(a, b, n, result);
|
|
74
|
-
*result = nk_f32_sqrt_neon(*result);
|
|
75
|
-
}
|
|
76
|
-
|
|
77
|
-
NK_PUBLIC void nk_angular_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
78
|
-
float32x4_t dot_product_f32x4 = vdupq_n_f32(0), a_norm_sq_f32x4 = vdupq_n_f32(0), b_norm_sq_f32x4 = vdupq_n_f32(0);
|
|
79
|
-
float32x4_t a_f32x4, b_f32x4;
|
|
80
|
-
|
|
81
|
-
nk_angular_f16_neonhalf_cycle:
|
|
82
|
-
if (n < 4) {
|
|
83
|
-
nk_b64_vec_t a_vec, b_vec;
|
|
84
|
-
nk_partial_load_b16x4_serial_(a, &a_vec, n);
|
|
85
|
-
nk_partial_load_b16x4_serial_(b, &b_vec, n);
|
|
86
|
-
a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
|
|
87
|
-
b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
|
|
88
|
-
n = 0;
|
|
89
|
-
}
|
|
90
|
-
else {
|
|
91
|
-
a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
|
|
92
|
-
b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
|
|
93
|
-
n -= 4, a += 4, b += 4;
|
|
94
|
-
}
|
|
95
|
-
dot_product_f32x4 = vfmaq_f32(dot_product_f32x4, a_f32x4, b_f32x4);
|
|
96
|
-
a_norm_sq_f32x4 = vfmaq_f32(a_norm_sq_f32x4, a_f32x4, a_f32x4);
|
|
97
|
-
b_norm_sq_f32x4 = vfmaq_f32(b_norm_sq_f32x4, b_f32x4, b_f32x4);
|
|
98
|
-
if (n) goto nk_angular_f16_neonhalf_cycle;
|
|
99
|
-
|
|
100
|
-
nk_f32_t dot_product_f32 = vaddvq_f32(dot_product_f32x4);
|
|
101
|
-
nk_f32_t a_norm_sq_f32 = vaddvq_f32(a_norm_sq_f32x4);
|
|
102
|
-
nk_f32_t b_norm_sq_f32 = vaddvq_f32(b_norm_sq_f32x4);
|
|
103
|
-
*result = nk_angular_normalize_f32_neon_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
104
|
-
}
|
|
105
|
-
|
|
106
|
-
#if defined(__clang__)
|
|
107
|
-
#pragma clang attribute pop
|
|
108
|
-
#elif defined(__GNUC__)
|
|
109
|
-
#pragma GCC pop_options
|
|
110
|
-
#endif
|
|
111
|
-
|
|
112
|
-
#if defined(__cplusplus)
|
|
113
|
-
} // extern "C"
|
|
114
|
-
#endif
|
|
115
|
-
|
|
116
|
-
#endif // NK_TARGET_NEONHALF
|
|
117
|
-
#endif // NK_TARGET_ARM_
|
|
118
|
-
#endif // NK_SPATIAL_NEONHALF_H
|
|
@@ -1,343 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @brief SIMD-accelerated Spatial Similarity Measures for Sapphire Rapids.
|
|
3
|
-
* @file include/numkong/spatial/sapphire.h
|
|
4
|
-
* @author Ash Vardanian
|
|
5
|
-
* @date December 27, 2025
|
|
6
|
-
*
|
|
7
|
-
* @sa include/numkong/spatial.h
|
|
8
|
-
*
|
|
9
|
-
* Sapphire Rapids adds native FP16 support via AVX-512 FP16 extension.
|
|
10
|
-
* For e4m3 L2 distance, we can leverage F16 for the subtraction step:
|
|
11
|
-
* - e4m3 differences fit in F16 (max |a−b| = 896 < 65504)
|
|
12
|
-
* - But squared differences overflow F16 (896² = 802816 > 65504)
|
|
13
|
-
* - So: subtract in F16, convert to F32, then square and accumulate
|
|
14
|
-
*
|
|
15
|
-
* For e2m3/e3m2 L2 distance, squared differences fit in FP16:
|
|
16
|
-
* - E2M3: max |a−b| = 15, max (a−b)² = 225 < 65504, flush cadence = 4 (conservative for uniformity)
|
|
17
|
-
* - E3M2: max |a−b| = 56, max (a−b)² = 3136 < 65504, flush cadence = 4
|
|
18
|
-
* So the entire sub+square+accumulate stays in FP16 with periodic F32 flush.
|
|
19
|
-
*
|
|
20
|
-
* @section spatial_sapphire_instructions Relevant Instructions
|
|
21
|
-
*
|
|
22
|
-
* Intrinsic Instruction Sapphire Genoa
|
|
23
|
-
* _mm256_sub_ph VSUBPH (YMM, YMM, YMM) 4cy @ p05 3cy @ p01
|
|
24
|
-
* _mm512_cvtph_ps VCVTPH2PS (ZMM, YMM) 5cy @ p05 5cy @ p01
|
|
25
|
-
* _mm512_fmadd_ps VFMADD (ZMM, ZMM, ZMM) 4cy @ p05 4cy @ p01
|
|
26
|
-
* _mm512_reduce_add_ps (pseudo: VHADDPS chain) ~8cy ~8cy
|
|
27
|
-
* _mm_maskz_loadu_epi8 VMOVDQU8 (XMM {K}, M128) 7cy @ p23 7cy @ p23
|
|
28
|
-
*/
|
|
29
|
-
#ifndef NK_SPATIAL_SAPPHIRE_H
|
|
30
|
-
#define NK_SPATIAL_SAPPHIRE_H
|
|
31
|
-
|
|
32
|
-
#if NK_TARGET_X86_
|
|
33
|
-
#if NK_TARGET_SAPPHIRE
|
|
34
|
-
|
|
35
|
-
#include "numkong/types.h"
|
|
36
|
-
#include "numkong/cast/sapphire.h" // `nk_e4m3x16_to_f16x16_sapphire_`
|
|
37
|
-
#include "numkong/dot/sapphire.h" // `nk_e2m3x32_to_f16x32_sapphire_`, `nk_flush_f16_to_f32_sapphire_`
|
|
38
|
-
#include "numkong/spatial/haswell.h" // `nk_angular_normalize_f32_haswell_`, `nk_f32_sqrt_haswell`
|
|
39
|
-
|
|
40
|
-
#if defined(__cplusplus)
|
|
41
|
-
extern "C" {
|
|
42
|
-
#endif
|
|
43
|
-
|
|
44
|
-
#if defined(__clang__)
|
|
45
|
-
#pragma clang attribute push( \
|
|
46
|
-
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,f16c,fma,bmi,bmi2"))), \
|
|
47
|
-
apply_to = function)
|
|
48
|
-
#elif defined(__GNUC__)
|
|
49
|
-
#pragma GCC push_options
|
|
50
|
-
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "f16c", "fma", "bmi", "bmi2")
|
|
51
|
-
#endif
|
|
52
|
-
|
|
53
|
-
NK_PUBLIC void nk_sqeuclidean_e4m3_sapphire(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars,
|
|
54
|
-
nk_size_t count_scalars, nk_f32_t *result) {
|
|
55
|
-
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
56
|
-
|
|
57
|
-
while (count_scalars > 0) {
|
|
58
|
-
nk_size_t const n = count_scalars < 16 ? count_scalars : 16;
|
|
59
|
-
__mmask16 const mask = (__mmask16)_bzhi_u32(0xFFFF, n);
|
|
60
|
-
__m128i a_e4m3x16 = _mm_maskz_loadu_epi8(mask, a_scalars);
|
|
61
|
-
__m128i b_e4m3x16 = _mm_maskz_loadu_epi8(mask, b_scalars);
|
|
62
|
-
|
|
63
|
-
// Convert e4m3 → f16
|
|
64
|
-
__m256h a_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(a_e4m3x16);
|
|
65
|
-
__m256h b_f16x16 = nk_e4m3x16_to_f16x16_sapphire_(b_e4m3x16);
|
|
66
|
-
|
|
67
|
-
// Subtract in F16 − differences fit (max 896 < 65504)
|
|
68
|
-
__m256h diff_f16x16 = _mm256_sub_ph(a_f16x16, b_f16x16);
|
|
69
|
-
|
|
70
|
-
// Convert to F32 before squaring (896² = 802816 overflows F16!)
|
|
71
|
-
__m512 diff_f32x16 = _mm512_cvtph_ps(_mm256_castph_si256(diff_f16x16));
|
|
72
|
-
|
|
73
|
-
// Square and accumulate in F32
|
|
74
|
-
sum_f32x16 = _mm512_fmadd_ps(diff_f32x16, diff_f32x16, sum_f32x16);
|
|
75
|
-
a_scalars += n, b_scalars += n, count_scalars -= n;
|
|
76
|
-
}
|
|
77
|
-
|
|
78
|
-
*result = _mm512_reduce_add_ps(sum_f32x16);
|
|
79
|
-
}
|
|
80
|
-
|
|
81
|
-
NK_PUBLIC void nk_euclidean_e4m3_sapphire(nk_e4m3_t const *a_scalars, nk_e4m3_t const *b_scalars,
|
|
82
|
-
nk_size_t count_scalars, nk_f32_t *result) {
|
|
83
|
-
nk_sqeuclidean_e4m3_sapphire(a_scalars, b_scalars, count_scalars, result);
|
|
84
|
-
*result = nk_f32_sqrt_haswell(*result);
|
|
85
|
-
}
|
|
86
|
-
|
|
87
|
-
// Squared Euclidean (L2²) distance between two e2m3 vectors.
// Strategy: accumulate squared F16 differences into a native-F16 accumulator,
// and periodically "flush" it into a wider F32 accumulator via
// `nk_flush_f16_to_f32_sapphire_` so long inputs don't lose precision or
// overflow in F16. NOTE(review): the 128-element flush period presumably keeps
// per-lane F16 partial sums safely below the 65504 F16 maximum for the e2m3
// value range — confirm against the e2m3 format limits.
NK_PUBLIC void nk_sqeuclidean_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
                                            nk_size_t count_scalars, nk_f32_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();

    // Main loop: 4-way unrolled, 128 elements per flush
    while (count_scalars >= 128) {
        __m512h acc_f16x32 = _mm512_setzero_ph();
        __m512h a_f16x32, b_f16x32, diff_f16x32;
        // Iteration 1: load 32 e2m3 bytes, widen to f16, square the difference.
        a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
        b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
        diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        // Iteration 2
        a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
        b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
        diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        // Iteration 3
        a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
        b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
        diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        // Iteration 4
        a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
        b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
        diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        // Flush to F32: widen the F16 partial sums into the F32 accumulator.
        sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
        a_scalars += 128, b_scalars += 128, count_scalars -= 128;
    }

    // Tail: remaining 0–127 elements, 32 at a time via masked loads
    __m512h acc_f16x32 = _mm512_setzero_ph();
    while (count_scalars > 0) {
        nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
        // Mask with the lowest `n` bits set; out-of-range lanes load as zero
        // and contribute nothing to the sum.
        __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
        __m512h a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
        __m512h b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
        __m512h diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        a_scalars += n, b_scalars += n, count_scalars -= n;
    }
    // Final flush (a no-op contribution if the tail loop never ran).
    sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);

    *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
}
|
|
135
|
-
|
|
136
|
-
// Squared Euclidean (L2²) distance between two e3m2 vectors.
// Mirrors `nk_sqeuclidean_e2m3_sapphire`: F16 accumulation within a
// 128-element window, flushed to F32 via `nk_flush_f16_to_f32_sapphire_`.
// NOTE(review): the flush period presumably bounds the per-lane F16 partial
// sums below F16's 65504 maximum for e3m2 magnitudes — confirm against the
// e3m2 format limits.
NK_PUBLIC void nk_sqeuclidean_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
                                            nk_size_t count_scalars, nk_f32_t *result) {
    __m512 sum_f32x16 = _mm512_setzero_ps();

    // Main loop: 4-way unrolled, 128 elements per flush
    while (count_scalars >= 128) {
        __m512h acc_f16x32 = _mm512_setzero_ph();
        __m512h a_f16x32, b_f16x32, diff_f16x32;
        // Iteration 1: load 32 e3m2 bytes, widen to f16, square the difference.
        a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
        b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
        diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        // Iteration 2
        a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
        b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
        diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        // Iteration 3
        a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
        b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
        diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        // Iteration 4
        a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
        b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
        diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        // Flush to F32: widen the F16 partial sums into the F32 accumulator.
        sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);
        a_scalars += 128, b_scalars += 128, count_scalars -= 128;
    }

    // Tail: remaining 0–127 elements, 32 at a time via masked loads
    __m512h acc_f16x32 = _mm512_setzero_ph();
    while (count_scalars > 0) {
        nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
        // Mask with the lowest `n` bits set; masked-off lanes load as zero.
        __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
        __m512h a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
        __m512h b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
        __m512h diff_f16x32 = _mm512_sub_ph(a_f16x32, b_f16x32);
        acc_f16x32 = _mm512_fmadd_ph(diff_f16x32, diff_f16x32, acc_f16x32);
        a_scalars += n, b_scalars += n, count_scalars -= n;
    }
    // Final flush (contributes zero if the tail loop never ran).
    sum_f32x16 = nk_flush_f16_to_f32_sapphire_(acc_f16x32, sum_f32x16);

    *result = nk_reduce_add_f32x16_skylake_(sum_f32x16);
}
|
|
184
|
-
|
|
185
|
-
NK_PUBLIC void nk_euclidean_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
|
|
186
|
-
nk_size_t count_scalars, nk_f32_t *result) {
|
|
187
|
-
nk_sqeuclidean_e2m3_sapphire(a_scalars, b_scalars, count_scalars, result);
|
|
188
|
-
*result = nk_f32_sqrt_haswell(*result);
|
|
189
|
-
}
|
|
190
|
-
|
|
191
|
-
NK_PUBLIC void nk_euclidean_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
|
|
192
|
-
nk_size_t count_scalars, nk_f32_t *result) {
|
|
193
|
-
nk_sqeuclidean_e3m2_sapphire(a_scalars, b_scalars, count_scalars, result);
|
|
194
|
-
*result = nk_f32_sqrt_haswell(*result);
|
|
195
|
-
}
|
|
196
|
-
|
|
197
|
-
// Angular (cosine-style) distance between two e2m3 vectors.
// Maintains three running sums — the dot product a·b and the squared norms
// a·a and b·b — in F16 within a 128-element window, flushing each to its own
// F32 accumulator. The final scalar reductions feed
// `nk_angular_normalize_f32_haswell_`, which combines them into the result.
NK_PUBLIC void nk_angular_e2m3_sapphire(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
                                        nk_f32_t *result) {
    __m512 sum_dot_f32x16 = _mm512_setzero_ps();
    __m512 sum_a_f32x16 = _mm512_setzero_ps();
    __m512 sum_b_f32x16 = _mm512_setzero_ps();

    // Main loop: 4-way unrolled, 128 elements per flush
    while (count_scalars >= 128) {
        __m512h dot_acc = _mm512_setzero_ph();
        __m512h a_norm_acc = _mm512_setzero_ph();
        __m512h b_norm_acc = _mm512_setzero_ph();
        __m512h a_f16x32, b_f16x32;
        // Iteration 1: each converted pair feeds all three accumulators.
        a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
        b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        // Iteration 2
        a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
        b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        // Iteration 3
        a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
        b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        // Iteration 4
        a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
        b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        // Flush to F32: widen each F16 partial sum into its F32 accumulator.
        sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
        sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
        sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
        a_scalars += 128, b_scalars += 128, count_scalars -= 128;
    }

    // Tail: remaining 0–127 elements, 32 at a time via masked loads
    __m512h dot_acc = _mm512_setzero_ph();
    __m512h a_norm_acc = _mm512_setzero_ph();
    __m512h b_norm_acc = _mm512_setzero_ph();
    while (count_scalars > 0) {
        nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
        // Mask with the lowest `n` bits set; masked-off lanes load as zero.
        __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
        __m512h a_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
        __m512h b_f16x32 = nk_e2m3x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        a_scalars += n, b_scalars += n, count_scalars -= n;
    }
    // Final flushes (contribute zero if the tail loop never ran).
    sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
    sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
    sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);

    // Horizontal reductions, then normalize dot by the two norms.
    nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(sum_dot_f32x16);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_a_f32x16);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_b_f32x16);
    *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
263
|
-
|
|
264
|
-
// Angular (cosine-style) distance between two e3m2 vectors.
// Identical structure to `nk_angular_e2m3_sapphire`, only with the e3m2 → f16
// conversion helper: three F16 accumulators (a·b, a·a, b·b) flushed to F32
// every 128 elements, reduced, and combined by
// `nk_angular_normalize_f32_haswell_`.
NK_PUBLIC void nk_angular_e3m2_sapphire(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
                                        nk_f32_t *result) {
    __m512 sum_dot_f32x16 = _mm512_setzero_ps();
    __m512 sum_a_f32x16 = _mm512_setzero_ps();
    __m512 sum_b_f32x16 = _mm512_setzero_ps();

    // Main loop: 4-way unrolled, 128 elements per flush
    while (count_scalars >= 128) {
        __m512h dot_acc = _mm512_setzero_ph();
        __m512h a_norm_acc = _mm512_setzero_ph();
        __m512h b_norm_acc = _mm512_setzero_ph();
        __m512h a_f16x32, b_f16x32;
        // Iteration 1: each converted pair feeds all three accumulators.
        a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars));
        b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        // Iteration 2
        a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 32));
        b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 32));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        // Iteration 3
        a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 64));
        b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 64));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        // Iteration 4
        a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(a_scalars + 96));
        b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_loadu_epi8(b_scalars + 96));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        // Flush to F32: widen each F16 partial sum into its F32 accumulator.
        sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
        sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
        sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);
        a_scalars += 128, b_scalars += 128, count_scalars -= 128;
    }

    // Tail: remaining 0–127 elements, 32 at a time via masked loads
    __m512h dot_acc = _mm512_setzero_ph();
    __m512h a_norm_acc = _mm512_setzero_ph();
    __m512h b_norm_acc = _mm512_setzero_ph();
    while (count_scalars > 0) {
        nk_size_t const n = count_scalars < 32 ? count_scalars : 32;
        // Mask with the lowest `n` bits set; masked-off lanes load as zero.
        __mmask32 const mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
        __m512h a_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, a_scalars));
        __m512h b_f16x32 = nk_e3m2x32_to_f16x32_sapphire_(_mm256_maskz_loadu_epi8(mask, b_scalars));
        dot_acc = _mm512_fmadd_ph(a_f16x32, b_f16x32, dot_acc);
        a_norm_acc = _mm512_fmadd_ph(a_f16x32, a_f16x32, a_norm_acc);
        b_norm_acc = _mm512_fmadd_ph(b_f16x32, b_f16x32, b_norm_acc);
        a_scalars += n, b_scalars += n, count_scalars -= n;
    }
    // Final flushes (contribute zero if the tail loop never ran).
    sum_dot_f32x16 = nk_flush_f16_to_f32_sapphire_(dot_acc, sum_dot_f32x16);
    sum_a_f32x16 = nk_flush_f16_to_f32_sapphire_(a_norm_acc, sum_a_f32x16);
    sum_b_f32x16 = nk_flush_f16_to_f32_sapphire_(b_norm_acc, sum_b_f32x16);

    // Horizontal reductions, then normalize dot by the two norms.
    nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(sum_dot_f32x16);
    nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_a_f32x16);
    nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(sum_b_f32x16);
    *result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
}
|
|
330
|
-
|
|
331
|
-
#if defined(__clang__)
|
|
332
|
-
#pragma clang attribute pop
|
|
333
|
-
#elif defined(__GNUC__)
|
|
334
|
-
#pragma GCC pop_options
|
|
335
|
-
#endif
|
|
336
|
-
|
|
337
|
-
#if defined(__cplusplus)
|
|
338
|
-
} // extern "C"
|
|
339
|
-
#endif
|
|
340
|
-
|
|
341
|
-
#endif // NK_TARGET_SAPPHIRE
|
|
342
|
-
#endif // NK_TARGET_X86_
|
|
343
|
-
#endif // NK_SPATIAL_SAPPHIRE_H
|
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* @brief Batched Spatial Distances for NEON FP16 (Half-Precision).
|
|
3
|
-
* @file include/numkong/spatials/neonhalf.h
|
|
4
|
-
* @author Ash Vardanian
|
|
5
|
-
* @date February 23, 2026
|
|
6
|
-
*
|
|
7
|
-
* @sa include/numkong/spatials.h
|
|
8
|
-
*/
|
|
9
|
-
#ifndef NK_SPATIALS_NEONHALF_H
|
|
10
|
-
#define NK_SPATIALS_NEONHALF_H
|
|
11
|
-
|
|
12
|
-
#if NK_TARGET_ARM_
|
|
13
|
-
#if NK_TARGET_NEONHALF
|
|
14
|
-
|
|
15
|
-
#include "numkong/spatial/neon.h"
|
|
16
|
-
#include "numkong/dots/neonhalf.h"
|
|
17
|
-
|
|
18
|
-
#if defined(__cplusplus)
|
|
19
|
-
extern "C" {
|
|
20
|
-
#endif
|
|
21
|
-
|
|
22
|
-
#if defined(__clang__)
|
|
23
|
-
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
|
|
24
|
-
#elif defined(__GNUC__)
|
|
25
|
-
#pragma GCC push_options
|
|
26
|
-
#pragma GCC target("arch=armv8.2-a+simd+fp16")
|
|
27
|
-
#endif
|
|
28
|
-
|
|
29
|
-
// Instantiate batched angular and Euclidean distance kernels for F16 inputs
// on NEON FP16, for both "packed" (two distinct operand batches) and
// "symmetric" (one batch against itself) layouts, accumulating in F32.
// NOTE(review): argument roles (value/accumulator types, dot kernel,
// normalizer, reduction, load/store helpers, and the trailing flag) follow the
// `nk_define_cross_normalized_*_` macro signatures defined elsewhere in the
// project — confirm against those macro definitions before modifying.
nk_define_cross_normalized_packed_(angular, f16, neonhalf, f16, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
                                   nk_dots_packed_f16_neonhalf, nk_angular_through_f32_from_dot_neon_,
                                   nk_dots_reduce_sumsq_f16_, nk_load_b128_neon_, nk_partial_load_b32x4_serial_,
                                   nk_store_b128_neon_, nk_partial_store_b32x4_serial_, 1)
nk_define_cross_normalized_packed_(euclidean, f16, neonhalf, f16, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
                                   nk_dots_packed_f16_neonhalf, nk_euclidean_through_f32_from_dot_neon_,
                                   nk_dots_reduce_sumsq_f16_, nk_load_b128_neon_, nk_partial_load_b32x4_serial_,
                                   nk_store_b128_neon_, nk_partial_store_b32x4_serial_, 1)
nk_define_cross_normalized_symmetric_(angular, f16, neonhalf, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
                                      nk_dots_symmetric_f16_neonhalf, nk_angular_through_f32_from_dot_neon_,
                                      nk_dots_reduce_sumsq_f16_, nk_load_b128_neon_, nk_partial_load_b32x4_serial_,
                                      nk_store_b128_neon_, nk_partial_store_b32x4_serial_, 1)
nk_define_cross_normalized_symmetric_(euclidean, f16, neonhalf, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
                                      nk_dots_symmetric_f16_neonhalf, nk_euclidean_through_f32_from_dot_neon_,
                                      nk_dots_reduce_sumsq_f16_, nk_load_b128_neon_, nk_partial_load_b32x4_serial_,
                                      nk_store_b128_neon_, nk_partial_store_b32x4_serial_, 1)
|
|
45
|
-
|
|
46
|
-
#if defined(__clang__)
|
|
47
|
-
#pragma clang attribute pop
|
|
48
|
-
#elif defined(__GNUC__)
|
|
49
|
-
#pragma GCC pop_options
|
|
50
|
-
#endif
|
|
51
|
-
|
|
52
|
-
#if defined(__cplusplus)
|
|
53
|
-
} // extern "C"
|
|
54
|
-
#endif
|
|
55
|
-
|
|
56
|
-
#endif // NK_TARGET_NEONHALF
|
|
57
|
-
#endif // NK_TARGET_ARM_
|
|
58
|
-
#endif // NK_SPATIALS_NEONHALF_H
|