numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
package/include/numkong/cast.h
CHANGED
|
@@ -150,6 +150,20 @@ NK_PUBLIC void nk_f32_to_f16_sapphire(nk_f32_t const *src, nk_f16_t *dest);
|
|
|
150
150
|
NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type);
|
|
151
151
|
#endif // NK_TARGET_RVV
|
|
152
152
|
|
|
153
|
+
#if NK_TARGET_POWERVSX
// Power VSX backend: the generic `nk_cast`, `nk_f16_to_f32` and `nk_f32_to_f16` dispatchers
// route to these when NK_TARGET_POWERVSX is enabled and no higher-priority x86 target is.
/** @copydoc nk_cast */
NK_PUBLIC void nk_cast_powervsx(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type);
/** @copydoc nk_f16_to_f32 */
NK_PUBLIC void nk_f16_to_f32_powervsx(nk_f16_t const *src, nk_f32_t *dest);
/** @copydoc nk_f32_to_f16 */
NK_PUBLIC void nk_f32_to_f16_powervsx(nk_f32_t const *src, nk_f16_t *dest);
#endif // NK_TARGET_POWERVSX

#if NK_TARGET_V128RELAXED
// WebAssembly relaxed-SIMD backend: only the bulk `nk_cast` kernel is declared here;
// no scalar f16 <-> f32 variants are provided for this target in this header.
/** @copydoc nk_cast */
NK_PUBLIC void nk_cast_v128relaxed(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type);
#endif // NK_TARGET_V128RELAXED
|
|
166
|
+
|
|
153
167
|
#if defined(__cplusplus)
|
|
154
168
|
} // extern "C"
|
|
155
169
|
#endif
|
|
@@ -161,6 +175,8 @@ NK_PUBLIC void nk_cast_rvv(void const *from, nk_dtype_t from_type, nk_size_t n,
|
|
|
161
175
|
#include "numkong/cast/icelake.h"
|
|
162
176
|
#include "numkong/cast/sapphire.h"
|
|
163
177
|
#include "numkong/cast/rvv.h"
|
|
178
|
+
#include "numkong/cast/powervsx.h"
|
|
179
|
+
#include "numkong/cast/loongsonasx.h"
|
|
164
180
|
|
|
165
181
|
#if defined(__cplusplus)
|
|
166
182
|
extern "C" {
|
|
@@ -177,10 +193,14 @@ NK_PUBLIC void nk_cast(void const *from, nk_dtype_t from_type, nk_size_t n, void
|
|
|
177
193
|
nk_cast_skylake(from, from_type, n, to, to_type);
|
|
178
194
|
#elif NK_TARGET_HASWELL
|
|
179
195
|
nk_cast_haswell(from, from_type, n, to, to_type);
|
|
196
|
+
#elif NK_TARGET_POWERVSX
|
|
197
|
+
nk_cast_powervsx(from, from_type, n, to, to_type);
|
|
180
198
|
#elif NK_TARGET_RVV
|
|
181
199
|
nk_cast_rvv(from, from_type, n, to, to_type);
|
|
182
200
|
#elif NK_TARGET_NEON
|
|
183
201
|
nk_cast_neon(from, from_type, n, to, to_type);
|
|
202
|
+
#elif NK_TARGET_V128RELAXED
|
|
203
|
+
nk_cast_v128relaxed(from, from_type, n, to, to_type);
|
|
184
204
|
#else
|
|
185
205
|
nk_cast_serial(from, from_type, n, to, to_type);
|
|
186
206
|
#endif
|
|
@@ -191,6 +211,8 @@ NK_PUBLIC void nk_f16_to_f32(nk_f16_t const *src, nk_f32_t *dest) {
|
|
|
191
211
|
nk_f16_to_f32_sapphire(src, dest);
|
|
192
212
|
#elif NK_TARGET_HASWELL
|
|
193
213
|
nk_f16_to_f32_haswell(src, dest);
|
|
214
|
+
#elif NK_TARGET_POWERVSX
|
|
215
|
+
nk_f16_to_f32_powervsx(src, dest);
|
|
194
216
|
#elif NK_TARGET_NEON
|
|
195
217
|
nk_f16_to_f32_neon(src, dest);
|
|
196
218
|
#else
|
|
@@ -203,6 +225,8 @@ NK_PUBLIC void nk_f32_to_f16(nk_f32_t const *src, nk_f16_t *dest) {
|
|
|
203
225
|
nk_f32_to_f16_sapphire(src, dest);
|
|
204
226
|
#elif NK_TARGET_HASWELL
|
|
205
227
|
nk_f32_to_f16_haswell(src, dest);
|
|
228
|
+
#elif NK_TARGET_POWERVSX
|
|
229
|
+
nk_f32_to_f16_powervsx(src, dest);
|
|
206
230
|
#elif NK_TARGET_NEON
|
|
207
231
|
nk_f32_to_f16_neon(src, dest);
|
|
208
232
|
#else
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief C++ wrappers for SIMD-accelerated type casting.
|
|
3
|
+
* @file include/numkong/cast.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 20, 2026
|
|
6
|
+
*/
|
|
7
|
+
#ifndef NK_CAST_HPP
|
|
8
|
+
#define NK_CAST_HPP
|
|
9
|
+
|
|
10
|
+
#include <cstddef> // `std::size_t`
|
|
11
|
+
|
|
12
|
+
#include "numkong/cast.h"
|
|
13
|
+
|
|
14
|
+
#include "numkong/types.hpp"
|
|
15
|
+
#include "numkong/vector.hpp"
|
|
16
|
+
|
|
17
|
+
namespace ashvardanian::numkong {
|
|
18
|
+
|
|
19
|
+
/**
|
|
20
|
+
* @brief Elementwise type-cast from one numeric type to another.
|
|
21
|
+
* @param[in] from Input array of `n` elements.
|
|
22
|
+
* @param[in] n Number of elements.
|
|
23
|
+
* @param[out] to Output array of `n` elements.
|
|
24
|
+
*
|
|
25
|
+
* @tparam from_type_ Source element type.
|
|
26
|
+
* @tparam to_type_ Destination element type.
|
|
27
|
+
* @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`.
|
|
28
|
+
*/
|
|
29
|
+
template <numeric_dtype from_type_, numeric_dtype to_type_, allow_simd_t allow_simd_ = prefer_simd_k>
|
|
30
|
+
void cast(from_type_ const *from, std::size_t n, to_type_ *to) noexcept {
|
|
31
|
+
if constexpr (allow_simd_ == prefer_simd_k) nk_cast(from, from_type_::dtype(), n, to, to_type_::dtype());
|
|
32
|
+
else nk_cast_serial(from, from_type_::dtype(), n, to, to_type_::dtype());
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
/** @brief Elementwise type-cast between vector views. Sizes must match. */
|
|
36
|
+
template <numeric_dtype from_type_, numeric_dtype to_type_, allow_simd_t allow_simd_ = prefer_simd_k>
|
|
37
|
+
void cast(vector_view<from_type_> from, vector_span<to_type_> to) noexcept {
|
|
38
|
+
std::size_t n = from.size() < to.size() ? from.size() : to.size();
|
|
39
|
+
cast<from_type_, to_type_, allow_simd_>(from.data(), n, to.data());
|
|
40
|
+
}
|
|
41
|
+
|
|
42
|
+
} // namespace ashvardanian::numkong
|
|
43
|
+
|
|
44
|
+
#endif // NK_CAST_HPP
|
|
@@ -6,21 +6,21 @@ These operations are central to Gaussian process inference, metric learning, and
|
|
|
6
6
|
|
|
7
7
|
The bilinear form for real vectors is:
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
$$
|
|
10
10
|
\text{bilinear}(a, b, C) = a^T C b = \sum_{i=0}^{n-1} \sum_{j=0}^{n-1} a_i \cdot c_{ij} \cdot b_j
|
|
11
|
-
|
|
11
|
+
$$
|
|
12
12
|
|
|
13
13
|
The Mahalanobis distance is:
|
|
14
14
|
|
|
15
|
-
|
|
15
|
+
$$
|
|
16
16
|
\text{mahalanobis}(a, b, C) = \sqrt{(a - b)^T C (a - b)}
|
|
17
|
-
|
|
17
|
+
$$
|
|
18
18
|
|
|
19
19
|
For complex vectors, the bilinear form uses the conjugate transpose:
|
|
20
20
|
|
|
21
|
-
|
|
21
|
+
$$
|
|
22
22
|
\text{bilinear}(a, b, C) = a^H C b = \sum_{i=0}^{n-1} \sum_{j=0}^{n-1} \bar{a_i} \cdot c_{ij} \cdot b_j
|
|
23
|
-
|
|
23
|
+
$$
|
|
24
24
|
|
|
25
25
|
Reformulating as Python pseudocode:
|
|
26
26
|
|
|
@@ -72,8 +72,8 @@ This nested structure gives $O(n)$ cache-friendly sequential access to the $n \t
|
|
|
72
72
|
|
|
73
73
|
`nk_bilinear_f32_smef64`, `nk_bilinear_f64_smef64`, `nk_bilinear_f32c_smef64`, `nk_bilinear_f64c_smef64`, `nk_mahalanobis_f32_smef64`, `nk_mahalanobis_f64_smef64` use the Scalable Matrix Extension to compute the bilinear form as an outer-product accumulation.
|
|
74
74
|
Each `FMOPA` instruction performs a rank-1 update $a_i \cdot b^T$ into the SME ZA tile array, and the matrix $C$ is streamed row-by-row and multiplied into the accumulator.
|
|
75
|
-
This
|
|
76
|
-
For dimensions that align to the tile size, this approach
|
|
75
|
+
This differs from the row-major dot-product approach: it reformulates $a^T C b$ as a matrix-multiply problem, so that SME's 2D tile registers can exploit the matrix engine's throughput.
|
|
76
|
+
For dimensions that align to the tile size, this approach has high throughput; dimensions that do not align fall back to NEON for cleanup of the residual elements.
|
|
77
77
|
|
|
78
78
|
### Complex Bilinear Decomposition
|
|
79
79
|
|
|
@@ -201,23 +201,23 @@ Measured with Wasmtime v42 (Cranelift backend).
|
|
|
201
201
|
|
|
202
202
|
#### WASM
|
|
203
203
|
|
|
204
|
-
Measured with Wasmtime
|
|
204
|
+
Measured with Wasmtime v43 (Cranelift backend).
|
|
205
205
|
|
|
206
206
|
| Kernel | 256² | 1024² | 4096² |
|
|
207
207
|
| :------------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
208
208
|
| __f64c__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
209
|
-
| `nk_bilinear_f64c_serial` |
|
|
209
|
+
| `nk_bilinear_f64c_serial` | 0.445 gso/s, ? ulp | 0.445 gso/s, ? ulp | 0.445 gso/s, ? ulp |
|
|
210
210
|
| __f32c__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
211
|
-
| `nk_bilinear_f32c_serial` |
|
|
211
|
+
| `nk_bilinear_f32c_serial` | 2.83 gso/s, ? ulp | 2.83 gso/s, ? ulp | 2.84 gso/s, ? ulp |
|
|
212
212
|
| __bf16c__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
213
|
-
| `nk_bilinear_bf16c_serial` |
|
|
213
|
+
| `nk_bilinear_bf16c_serial` | 3.05 gso/s, ? ulp | 3.02 gso/s, ? ulp | 3.03 gso/s, ? ulp |
|
|
214
214
|
| __f16c__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
215
|
-
| `nk_bilinear_f16c_serial` |
|
|
215
|
+
| `nk_bilinear_f16c_serial` | 0.984 gso/s, ? ulp | 0.992 gso/s, ? ulp | 0.995 gso/s, ? ulp |
|
|
216
216
|
| __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
217
|
-
| `nk_bilinear_f64_serial` |
|
|
217
|
+
| `nk_bilinear_f64_serial` | 0.998 gso/s, ? ulp | 0.999 gso/s, ? ulp | 0.999 gso/s, ? ulp |
|
|
218
218
|
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
219
|
-
| `nk_bilinear_f32_serial` |
|
|
219
|
+
| `nk_bilinear_f32_serial` | 5.00 gso/s, ? ulp | 3.73 gso/s, ? ulp | 3.49 gso/s, ? ulp |
|
|
220
220
|
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
221
|
-
| `nk_bilinear_bf16_serial` |
|
|
221
|
+
| `nk_bilinear_bf16_serial` | 4.84 gso/s, ? ulp | 3.83 gso/s, ? ulp | 3.60 gso/s, ? ulp |
|
|
222
222
|
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
223
|
-
| `nk_bilinear_f16_serial` |
|
|
223
|
+
| `nk_bilinear_f16_serial` | 1.90 gso/s, ? ulp | 1.75 gso/s, ? ulp | 1.93 gso/s, ? ulp |
|
|
@@ -11,13 +11,12 @@
|
|
|
11
11
|
*
|
|
12
12
|
* @section neon_curved_instructions Key NEON Instructions
|
|
13
13
|
*
|
|
14
|
-
* Intrinsic
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
20
|
-
* vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy 1/cy 1/cy
|
|
14
|
+
* Intrinsic Instruction A76 M5
|
|
15
|
+
* vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy @ 2p 3cy @ 4p
|
|
16
|
+
* vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy @ 2p 3cy @ 4p
|
|
17
|
+
* vaddvq_f64 FADDP (V.2D to scalar) 3cy @ 1p 3cy @ 2p
|
|
18
|
+
* vld1_f32 LD1 ({Vt.2S}, [Xn]) 4cy @ 2p 4cy @ 3p
|
|
19
|
+
* vld2_f32 LD2 ({Vt.2S, Vt2.2S}, [Xn]) 4cy @ 1p 4cy @ 1p
|
|
21
20
|
*
|
|
22
21
|
* For f32 bilinear and Mahalanobis, we upcast to f64 for accumulation to preserve
|
|
23
22
|
* precision and avoid catastrophic cancellation in large-magnitude sums.
|
|
@@ -190,6 +189,131 @@ NK_PUBLIC void nk_bilinear_f32c_neon(nk_f32c_t const *a_pairs, nk_f32c_t const *
|
|
|
190
189
|
results->imag = outer_sum_imag_f64;
|
|
191
190
|
}
|
|
192
191
|
|
|
192
|
+
/**
 *  @brief Bilinear form aᵀ·C·b for half-precision inputs, accumulated in f32 with NEON.
 *  @param[in] a First vector of `n` f16 values.
 *  @param[in] b Second vector of `n` f16 values.
 *  @param[in] c Row-major `n`×`n` f16 matrix (row `i` starts at `c + i * n`).
 *  @param[in] n Dimensionality.
 *  @param[out] result Scalar f32 result.
 */
NK_PUBLIC void nk_bilinear_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
                                    nk_f32_t *result) {
    nk_size_t const n_vectorized = n - n % 8; // columns covered by full 8-lane f16 loads
    nk_f32_t total = 0;
    for (nk_size_t i = 0; i != n; ++i) {
        nk_f16_t const *c_i = c + i * n;
        nk_f32_t a_i;
        nk_f16_to_f32_serial(a + i, &a_i);

        // Row dot-product C[i, :] · b: widen each 8-lane f16 load into two f32x4 halves.
        float32x4_t row_dot_f32x4 = vdupq_n_f32(0);
        nk_size_t j = 0;
        while (j < n_vectorized) {
            float16x8_t b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(b + j)));
            float16x8_t c_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(c_i + j)));
            row_dot_f32x4 = vfmaq_f32(row_dot_f32x4, vcvt_f32_f16(vget_low_f16(c_f16x8)),
                                      vcvt_f32_f16(vget_low_f16(b_f16x8)));
            row_dot_f32x4 = vfmaq_f32(row_dot_f32x4, vcvt_high_f32_f16(c_f16x8), vcvt_high_f32_f16(b_f16x8));
            j += 8;
        }
        nk_f32_t row_dot = vaddvq_f32(row_dot_f32x4);

        // Scalar tail for the remaining n % 8 columns.
        for (; j != n; ++j) {
            nk_f32_t b_j, c_ij;
            nk_f16_to_f32_serial(b + j, &b_j);
            nk_f16_to_f32_serial(c_i + j, &c_ij);
            row_dot += c_ij * b_j;
        }
        total += a_i * row_dot;
    }
    *result = total;
}
|
|
222
|
+
|
|
223
|
+
/**
 *  @brief Mahalanobis-style distance sqrt((a−b)ᵀ·C·(a−b)) for f16 inputs, accumulated in f32 with NEON.
 *  @param[in] a First vector of `n` f16 values.
 *  @param[in] b Second vector of `n` f16 values.
 *  @param[in] c Row-major `n`×`n` f16 matrix (row `i` starts at `c + i * n`).
 *  @param[in] n Dimensionality.
 *  @param[out] result Square root of the quadratic form, with negative values clamped to zero.
 */
NK_PUBLIC void nk_mahalanobis_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
                                       nk_f32_t *result) {
    nk_size_t const n_vectorized = n - n % 8; // columns covered by full 8-lane f16 loads
    nk_f32_t quadratic = 0;
    for (nk_size_t i = 0; i != n; ++i) {
        nk_f16_t const *c_i = c + i * n;
        nk_f32_t a_i, b_i;
        nk_f16_to_f32_serial(a + i, &a_i);
        nk_f16_to_f32_serial(b + i, &b_i);
        nk_f32_t delta_i = a_i - b_i;

        // Row dot-product C[i, :] · (a − b), with the difference recomputed per column block.
        float32x4_t row_dot_f32x4 = vdupq_n_f32(0);
        nk_size_t j = 0;
        while (j < n_vectorized) {
            float16x8_t a_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(a + j)));
            float16x8_t b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(b + j)));
            float16x8_t c_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)(c_i + j)));
            float32x4_t delta_low_f32x4 =
                vsubq_f32(vcvt_f32_f16(vget_low_f16(a_f16x8)), vcvt_f32_f16(vget_low_f16(b_f16x8)));
            float32x4_t delta_high_f32x4 = vsubq_f32(vcvt_high_f32_f16(a_f16x8), vcvt_high_f32_f16(b_f16x8));
            row_dot_f32x4 = vfmaq_f32(row_dot_f32x4, vcvt_f32_f16(vget_low_f16(c_f16x8)), delta_low_f32x4);
            row_dot_f32x4 = vfmaq_f32(row_dot_f32x4, vcvt_high_f32_f16(c_f16x8), delta_high_f32x4);
            j += 8;
        }
        nk_f32_t row_dot = vaddvq_f32(row_dot_f32x4);

        // Scalar tail for the remaining n % 8 columns.
        for (; j != n; ++j) {
            nk_f32_t a_j, b_j, c_ij;
            nk_f16_to_f32_serial(a + j, &a_j);
            nk_f16_to_f32_serial(b + j, &b_j);
            nk_f16_to_f32_serial(c_i + j, &c_ij);
            row_dot += c_ij * (a_j - b_j);
        }
        quadratic += delta_i * row_dot;
    }
    // The sum can dip below zero (rounding, or a C that is not positive semi-definite); clamp before sqrt.
    *result = nk_f32_sqrt_neon(quadratic > 0 ? quadratic : 0);
}
|
|
262
|
+
|
|
263
|
+
/**
 *  @brief Complex bilinear form for interleaved f16 complex inputs, accumulated in f32 with NEON.
 *
 *  Computes Σᵢ aᵢ · (Σⱼ C[i, j] · bⱼ) using plain complex multiplication; no conjugation
 *  of `a` is applied by this kernel.
 *
 *  @param[in] a_pairs First vector of `n` complex f16 values.
 *  @param[in] b_pairs Second vector of `n` complex f16 values.
 *  @param[in] c_pairs Row-major `n`×`n` complex f16 matrix (row `i` starts at `c_pairs + i * n`).
 *  @param[in] n Dimensionality.
 *  @param[out] results Complex f32 result.
 */
NK_PUBLIC void nk_bilinear_f16c_neon(nk_f16c_t const *a_pairs, nk_f16c_t const *b_pairs, nk_f16c_t const *c_pairs,
                                     nk_size_t n, nk_f32c_t *results) {
    nk_size_t const n_vectorized = n - n % 8; // columns covered by full 8-pair loads
    nk_f32_t total_real = 0;
    nk_f32_t total_imag = 0;
    for (nk_size_t i = 0; i != n; ++i) {
        nk_f16c_t const *c_i = c_pairs + i * n;
        nk_f32_t a_re, a_im;
        nk_f16_to_f32_serial(&a_pairs[i].real, &a_re);
        nk_f16_to_f32_serial(&a_pairs[i].imag, &a_im);

        float32x4_t dot_re_f32x4 = vdupq_n_f32(0);
        float32x4_t dot_im_f32x4 = vdupq_n_f32(0);
        nk_size_t j = 0;
        while (j < n_vectorized) {
            // LD2 de-interleaves 8 complex values into separate real and imaginary lanes.
            uint16x8x2_t b_u16x8x2 = vld2q_u16((nk_u16_t const *)(b_pairs + j));
            uint16x8x2_t c_u16x8x2 = vld2q_u16((nk_u16_t const *)(c_i + j));
            float16x8_t b_re_f16x8 = vreinterpretq_f16_u16(b_u16x8x2.val[0]);
            float16x8_t b_im_f16x8 = vreinterpretq_f16_u16(b_u16x8x2.val[1]);
            float16x8_t c_re_f16x8 = vreinterpretq_f16_u16(c_u16x8x2.val[0]);
            float16x8_t c_im_f16x8 = vreinterpretq_f16_u16(c_u16x8x2.val[1]);

            float32x4_t b_re_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_re_f16x8));
            float32x4_t b_re_high_f32x4 = vcvt_high_f32_f16(b_re_f16x8);
            float32x4_t b_im_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_im_f16x8));
            float32x4_t b_im_high_f32x4 = vcvt_high_f32_f16(b_im_f16x8);
            float32x4_t c_re_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_re_f16x8));
            float32x4_t c_re_high_f32x4 = vcvt_high_f32_f16(c_re_f16x8);
            float32x4_t c_im_low_f32x4 = vcvt_f32_f16(vget_low_f16(c_im_f16x8));
            float32x4_t c_im_high_f32x4 = vcvt_high_f32_f16(c_im_f16x8);

            // Re(c·b) = c_re·b_re − c_im·b_im
            dot_re_f32x4 = vfmaq_f32(dot_re_f32x4, c_re_low_f32x4, b_re_low_f32x4);
            dot_re_f32x4 = vfmsq_f32(dot_re_f32x4, c_im_low_f32x4, b_im_low_f32x4);
            dot_re_f32x4 = vfmaq_f32(dot_re_f32x4, c_re_high_f32x4, b_re_high_f32x4);
            dot_re_f32x4 = vfmsq_f32(dot_re_f32x4, c_im_high_f32x4, b_im_high_f32x4);
            // Im(c·b) = c_re·b_im + c_im·b_re
            dot_im_f32x4 = vfmaq_f32(dot_im_f32x4, c_re_low_f32x4, b_im_low_f32x4);
            dot_im_f32x4 = vfmaq_f32(dot_im_f32x4, c_im_low_f32x4, b_re_low_f32x4);
            dot_im_f32x4 = vfmaq_f32(dot_im_f32x4, c_re_high_f32x4, b_im_high_f32x4);
            dot_im_f32x4 = vfmaq_f32(dot_im_f32x4, c_im_high_f32x4, b_re_high_f32x4);
            j += 8;
        }
        nk_f32_t dot_re = vaddvq_f32(dot_re_f32x4);
        nk_f32_t dot_im = vaddvq_f32(dot_im_f32x4);

        // Scalar tail for the remaining n % 8 complex columns.
        for (; j != n; ++j) {
            nk_f32_t b_re, b_im, c_re, c_im;
            nk_f16_to_f32_serial(&b_pairs[j].real, &b_re);
            nk_f16_to_f32_serial(&b_pairs[j].imag, &b_im);
            nk_f16_to_f32_serial(&c_i[j].real, &c_re);
            nk_f16_to_f32_serial(&c_i[j].imag, &c_im);
            dot_re += c_re * b_re - c_im * b_im;
            dot_im += c_re * b_im + c_im * b_re;
        }
        total_real += a_re * dot_re - a_im * dot_im;
        total_imag += a_re * dot_im + a_im * dot_re;
    }
    results->real = total_real;
    results->imag = total_imag;
}
|
|
316
|
+
|
|
193
317
|
#if defined(__clang__)
|
|
194
318
|
#pragma clang attribute pop
|
|
195
319
|
#elif defined(__GNUC__)
|
|
@@ -10,13 +10,12 @@
|
|
|
10
10
|
*
|
|
11
11
|
* @section curved_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
|
|
12
12
|
*
|
|
13
|
-
* Intrinsic
|
|
14
|
-
*
|
|
15
|
-
*
|
|
16
|
-
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
|
|
13
|
+
* Intrinsic Instruction A76 M5
|
|
14
|
+
* vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy @ 2p 2cy @ 1p
|
|
15
|
+
* vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
16
|
+
* vld1q_bf16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
|
|
17
|
+
* vaddvq_f32 FADDP+FADDP (V.4S) 5cy @ 1p 8cy @ 1p
|
|
18
|
+
* vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
|
|
20
19
|
*
|
|
21
20
|
* For bilinear forms, BFDOT enables efficient inner-product computation by processing 8 bf16
|
|
22
21
|
* pairs into 4 f32 results per instruction. For Mahalanobis distance, bf16 inputs are converted
|
|
@@ -36,10 +36,10 @@ extern "C" {
|
|
|
36
36
|
|
|
37
37
|
NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
38
38
|
nk_f64_t *result) {
|
|
39
|
-
nk_size_t
|
|
39
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
40
40
|
nk_f64_t outer_sum = 0;
|
|
41
41
|
for (nk_size_t i = 0; i < n; ++i) {
|
|
42
|
-
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
42
|
+
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
43
43
|
nk_f32_t const *c_row = c + i * n;
|
|
44
44
|
nk_size_t remaining = n;
|
|
45
45
|
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
@@ -50,7 +50,7 @@ NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_
|
|
|
50
50
|
}
|
|
51
51
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
52
52
|
nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
|
|
53
|
-
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1,
|
|
53
|
+
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, max_vector_length));
|
|
54
54
|
outer_sum += (nk_f64_t)a[i] * inner_val;
|
|
55
55
|
}
|
|
56
56
|
*result = outer_sum;
|
|
@@ -58,12 +58,12 @@ NK_PUBLIC void nk_bilinear_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_
|
|
|
58
58
|
|
|
59
59
|
NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
60
60
|
nk_f64_t *result) {
|
|
61
|
-
nk_size_t
|
|
61
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
62
62
|
vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
63
63
|
nk_f64_t outer_compensation = 0;
|
|
64
64
|
for (nk_size_t i = 0; i < n; ++i) {
|
|
65
|
-
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
66
|
-
vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
65
|
+
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
66
|
+
vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
67
67
|
nk_f64_t const *c_row = c + i * n;
|
|
68
68
|
nk_size_t remaining = n;
|
|
69
69
|
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
@@ -82,7 +82,7 @@ NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_
|
|
|
82
82
|
}
|
|
83
83
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
84
84
|
nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
|
|
85
|
-
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1,
|
|
85
|
+
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, max_vector_length));
|
|
86
86
|
nk_f64_t product_outer = a[i] * inner_val;
|
|
87
87
|
nk_f64_t old_sum = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1);
|
|
88
88
|
nk_f64_t new_sum = old_sum + product_outer;
|
|
@@ -96,14 +96,14 @@ NK_PUBLIC void nk_bilinear_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_
|
|
|
96
96
|
|
|
97
97
|
NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
98
98
|
nk_f32_t *result) {
|
|
99
|
-
nk_size_t
|
|
99
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
100
100
|
vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
101
101
|
for (nk_size_t i = 0; i < n; ++i) {
|
|
102
102
|
// Convert a[i] from f16 to f32
|
|
103
103
|
nk_f32_t a_i;
|
|
104
104
|
nk_f16_to_f32_serial(a + i, &a_i);
|
|
105
105
|
|
|
106
|
-
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
106
|
+
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
107
107
|
nk_f16_t const *c_row = c + i * n;
|
|
108
108
|
nk_size_t remaining = n;
|
|
109
109
|
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
@@ -117,7 +117,7 @@ NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_
|
|
|
117
117
|
}
|
|
118
118
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
119
119
|
nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
|
|
120
|
-
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1,
|
|
120
|
+
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, max_vector_length));
|
|
121
121
|
sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + a_i * inner_val, 1);
|
|
122
122
|
}
|
|
123
123
|
*result = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
|
|
@@ -125,14 +125,14 @@ NK_PUBLIC void nk_bilinear_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_
|
|
|
125
125
|
|
|
126
126
|
NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
127
127
|
nk_f32_t *result) {
|
|
128
|
-
nk_size_t
|
|
128
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
129
129
|
vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
130
130
|
for (nk_size_t i = 0; i < n; ++i) {
|
|
131
131
|
// Convert a[i] from bf16 to f32
|
|
132
132
|
nk_f32_t a_i;
|
|
133
133
|
nk_bf16_to_f32_serial(a + i, &a_i);
|
|
134
134
|
|
|
135
|
-
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
135
|
+
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
136
136
|
nk_bf16_t const *c_row = c + i * n;
|
|
137
137
|
nk_size_t remaining = n;
|
|
138
138
|
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
@@ -146,7 +146,7 @@ NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_b
|
|
|
146
146
|
}
|
|
147
147
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
148
148
|
nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
|
|
149
|
-
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1,
|
|
149
|
+
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, max_vector_length));
|
|
150
150
|
sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + a_i * inner_val, 1);
|
|
151
151
|
}
|
|
152
152
|
*result = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
|
|
@@ -154,11 +154,11 @@ NK_PUBLIC void nk_bilinear_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_b
|
|
|
154
154
|
|
|
155
155
|
NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t n,
|
|
156
156
|
nk_f64_t *result) {
|
|
157
|
-
nk_size_t
|
|
157
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
158
158
|
nk_f64_t outer_sum = 0;
|
|
159
159
|
for (nk_size_t i = 0; i < n; ++i) {
|
|
160
160
|
nk_f64_t diff_i = (nk_f64_t)a[i] - (nk_f64_t)b[i];
|
|
161
|
-
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
161
|
+
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
162
162
|
nk_f32_t const *c_row = c + i * n;
|
|
163
163
|
nk_size_t remaining = n;
|
|
164
164
|
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
@@ -173,7 +173,7 @@ NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f
|
|
|
173
173
|
}
|
|
174
174
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
175
175
|
nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
|
|
176
|
-
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1,
|
|
176
|
+
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, max_vector_length));
|
|
177
177
|
outer_sum += diff_i * inner_val;
|
|
178
178
|
}
|
|
179
179
|
*result = nk_f64_sqrt_rvv(outer_sum > 0 ? outer_sum : 0);
|
|
@@ -181,13 +181,13 @@ NK_PUBLIC void nk_mahalanobis_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_f
|
|
|
181
181
|
|
|
182
182
|
NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t n,
|
|
183
183
|
nk_f64_t *result) {
|
|
184
|
-
nk_size_t
|
|
184
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e64m4();
|
|
185
185
|
vfloat64m1_t sum_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
186
186
|
nk_f64_t outer_compensation = 0;
|
|
187
187
|
for (nk_size_t i = 0; i < n; ++i) {
|
|
188
188
|
nk_f64_t diff_i = a[i] - b[i];
|
|
189
|
-
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
190
|
-
vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0,
|
|
189
|
+
vfloat64m4_t inner_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
190
|
+
vfloat64m4_t compensation_f64m4 = __riscv_vfmv_v_f_f64m4(0.0, max_vector_length);
|
|
191
191
|
nk_f64_t const *c_row = c + i * n;
|
|
192
192
|
nk_size_t remaining = n;
|
|
193
193
|
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
@@ -209,7 +209,7 @@ NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f
|
|
|
209
209
|
}
|
|
210
210
|
vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
|
|
211
211
|
nk_f64_t inner_val = __riscv_vfmv_f_s_f64m1_f64(
|
|
212
|
-
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1,
|
|
212
|
+
__riscv_vfredusum_vs_f64m4_f64m1(inner_f64m4, zero_f64m1, max_vector_length));
|
|
213
213
|
nk_f64_t product_outer = diff_i * inner_val;
|
|
214
214
|
nk_f64_t old_sum = __riscv_vfmv_f_s_f64m1_f64(sum_f64m1);
|
|
215
215
|
nk_f64_t new_sum = old_sum + product_outer;
|
|
@@ -224,7 +224,7 @@ NK_PUBLIC void nk_mahalanobis_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_f
|
|
|
224
224
|
|
|
225
225
|
NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f16_t const *c, nk_size_t n,
|
|
226
226
|
nk_f32_t *result) {
|
|
227
|
-
nk_size_t
|
|
227
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
228
228
|
vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
229
229
|
for (nk_size_t i = 0; i < n; ++i) {
|
|
230
230
|
nk_f32_t a_i, b_i;
|
|
@@ -232,7 +232,7 @@ NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f
|
|
|
232
232
|
nk_f16_to_f32_serial(b + i, &b_i);
|
|
233
233
|
nk_f32_t diff_i = a_i - b_i;
|
|
234
234
|
|
|
235
|
-
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
235
|
+
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
236
236
|
nk_f16_t const *c_row = c + i * n;
|
|
237
237
|
nk_size_t remaining = n;
|
|
238
238
|
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
@@ -249,7 +249,7 @@ NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f
|
|
|
249
249
|
}
|
|
250
250
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
251
251
|
nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
|
|
252
|
-
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1,
|
|
252
|
+
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, max_vector_length));
|
|
253
253
|
sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + diff_i * inner_val, 1);
|
|
254
254
|
}
|
|
255
255
|
nk_f32_t quadratic_f16 = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
|
|
@@ -258,7 +258,7 @@ NK_PUBLIC void nk_mahalanobis_f16_rvv(nk_f16_t const *a, nk_f16_t const *b, nk_f
|
|
|
258
258
|
|
|
259
259
|
NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, nk_bf16_t const *c, nk_size_t n,
|
|
260
260
|
nk_f32_t *result) {
|
|
261
|
-
nk_size_t
|
|
261
|
+
nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
|
|
262
262
|
vfloat32m1_t sum_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
263
263
|
for (nk_size_t i = 0; i < n; ++i) {
|
|
264
264
|
nk_f32_t a_i, b_i;
|
|
@@ -266,7 +266,7 @@ NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
266
266
|
nk_bf16_to_f32_serial(b + i, &b_i);
|
|
267
267
|
nk_f32_t diff_i = a_i - b_i;
|
|
268
268
|
|
|
269
|
-
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f,
|
|
269
|
+
vfloat32m2_t inner_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
|
|
270
270
|
nk_bf16_t const *c_row = c + i * n;
|
|
271
271
|
nk_size_t remaining = n;
|
|
272
272
|
for (nk_size_t vector_length; remaining > 0; remaining -= vector_length, c_row += vector_length) {
|
|
@@ -283,7 +283,7 @@ NK_PUBLIC void nk_mahalanobis_bf16_rvv(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
283
283
|
}
|
|
284
284
|
vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
|
|
285
285
|
nk_f32_t inner_val = __riscv_vfmv_f_s_f32m1_f32(
|
|
286
|
-
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1,
|
|
286
|
+
__riscv_vfredusum_vs_f32m2_f32m1(inner_f32m2, zero_f32m1, max_vector_length));
|
|
287
287
|
sum_f32m1 = __riscv_vfmv_v_f_f32m1(__riscv_vfmv_f_s_f32m1_f32(sum_f32m1) + diff_i * inner_val, 1);
|
|
288
288
|
}
|
|
289
289
|
nk_f32_t quadratic_bf16 = __riscv_vfmv_f_s_f32m1_f32(sum_f32m1);
|