numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
package/include/numkong/mesh.h
CHANGED
|
@@ -82,17 +82,17 @@
|
|
|
82
82
|
*
|
|
83
83
|
* The SIMD kernels are dominated by FMA, permutes, and gathers:
|
|
84
84
|
*
|
|
85
|
-
* Intrinsic
|
|
86
|
-
* _mm256_fmadd_ps/pd
|
|
87
|
-
* _mm256_i32gather_ps
|
|
88
|
-
* _mm512_permutex2var_ps/pd
|
|
89
|
-
* _mm512_reduce_add_ps/pd
|
|
85
|
+
* Intrinsic Instruction Notes
|
|
86
|
+
* _mm256_fmadd_ps/pd VFMADD* FMA on FP ports (Haswell/Skylake: ports 0/1)
|
|
87
|
+
* _mm256_i32gather_ps VGATHERDPS High-latency; memory-bound
|
|
88
|
+
* _mm512_permutex2var_ps/pd VPERMT2* Shuffle-heavy; can bottleneck on shuffle ports
|
|
89
|
+
* _mm512_reduce_add_ps/pd (sequence) Implemented via shuffles + adds
|
|
90
90
|
*
|
|
91
91
|
* Gather-heavy tails are intentionally isolated to keep the steady-state loop on contiguous loads.
|
|
92
92
|
*
|
|
93
93
|
* @section references References
|
|
94
94
|
*
|
|
95
|
-
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
|
|
95
|
+
* - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
|
|
96
96
|
* - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
|
|
97
97
|
*
|
|
98
98
|
*/
|
|
@@ -245,6 +245,25 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
245
245
|
/** @copydoc nk_umeyama_f64 */
|
|
246
246
|
NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
247
247
|
nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result);
|
|
248
|
+
|
|
249
|
+
/** @copydoc nk_rmsd_f16 */
|
|
250
|
+
NK_PUBLIC void nk_rmsd_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
251
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
252
|
+
/** @copydoc nk_kabsch_f16 */
|
|
253
|
+
NK_PUBLIC void nk_kabsch_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
254
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
255
|
+
/** @copydoc nk_umeyama_f16 */
|
|
256
|
+
NK_PUBLIC void nk_umeyama_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
257
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
258
|
+
/** @copydoc nk_rmsd_bf16 */
|
|
259
|
+
NK_PUBLIC void nk_rmsd_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
260
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
261
|
+
/** @copydoc nk_kabsch_bf16 */
|
|
262
|
+
NK_PUBLIC void nk_kabsch_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
263
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
264
|
+
/** @copydoc nk_umeyama_bf16 */
|
|
265
|
+
NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
266
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
248
267
|
#endif // NK_TARGET_SKYLAKE
|
|
249
268
|
|
|
250
269
|
/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer.
|
|
@@ -313,21 +332,16 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
|
|
|
313
332
|
/** @copydoc nk_umeyama_f64 */
|
|
314
333
|
NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
315
334
|
nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result);
|
|
316
|
-
#endif // NK_TARGET_NEON
|
|
317
|
-
|
|
318
|
-
/* SIMD-powered backends for Arm NEON FP16 CPUs.
|
|
319
|
-
*/
|
|
320
|
-
#if NK_TARGET_NEONHALF
|
|
321
335
|
/** @copydoc nk_rmsd_f16 */
|
|
322
|
-
NK_PUBLIC void
|
|
323
|
-
|
|
336
|
+
NK_PUBLIC void nk_rmsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
337
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
324
338
|
/** @copydoc nk_kabsch_f16 */
|
|
325
|
-
NK_PUBLIC void
|
|
326
|
-
|
|
339
|
+
NK_PUBLIC void nk_kabsch_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
340
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
327
341
|
/** @copydoc nk_umeyama_f16 */
|
|
328
|
-
NK_PUBLIC void
|
|
329
|
-
|
|
330
|
-
#endif //
|
|
342
|
+
NK_PUBLIC void nk_umeyama_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
343
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
344
|
+
#endif // NK_TARGET_NEON
|
|
331
345
|
|
|
332
346
|
/* SIMD-powered backends for Arm NEON BF16 CPUs.
|
|
333
347
|
*/
|
|
@@ -406,22 +420,10 @@ NK_PUBLIC void nk_umeyama_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b,
|
|
|
406
420
|
#endif // NK_TARGET_V128RELAXED
|
|
407
421
|
|
|
408
422
|
/**
|
|
409
|
-
* @brief Returns the output dtype for
|
|
410
|
-
|
|
411
|
-
NK_INTERNAL nk_dtype_t nk_rmsd_output_dtype(nk_dtype_t dtype) {
|
|
412
|
-
switch (dtype) {
|
|
413
|
-
case nk_f64_k: return nk_f64_k;
|
|
414
|
-
case nk_f32_k: return nk_f64_k;
|
|
415
|
-
case nk_f16_k: return nk_f32_k;
|
|
416
|
-
case nk_bf16_k: return nk_f32_k;
|
|
417
|
-
default: return nk_dtype_unknown_k;
|
|
418
|
-
}
|
|
419
|
-
}
|
|
420
|
-
|
|
421
|
-
/**
|
|
422
|
-
* @brief Returns the output dtype for Kabsch alignment.
|
|
423
|
+
* @brief Returns the metric output dtype for mesh alignment operations.
|
|
424
|
+
* Matches the C++ `mesh_metric_t` alias in types.hpp.
|
|
423
425
|
*/
|
|
424
|
-
NK_INTERNAL nk_dtype_t
|
|
426
|
+
NK_INTERNAL nk_dtype_t nk_mesh_metric_dtype(nk_dtype_t dtype) {
|
|
425
427
|
switch (dtype) {
|
|
426
428
|
case nk_f64_k: return nk_f64_k;
|
|
427
429
|
case nk_f32_k: return nk_f64_k;
|
|
@@ -432,12 +434,13 @@ NK_INTERNAL nk_dtype_t nk_kabsch_output_dtype(nk_dtype_t dtype) {
|
|
|
432
434
|
}
|
|
433
435
|
|
|
434
436
|
/**
|
|
435
|
-
* @brief Returns the output dtype for
|
|
437
|
+
* @brief Returns the transform output dtype for mesh alignment operations.
|
|
438
|
+
* Matches the C++ `mesh_transform_t` alias in types.hpp.
|
|
436
439
|
*/
|
|
437
|
-
NK_INTERNAL nk_dtype_t
|
|
440
|
+
NK_INTERNAL nk_dtype_t nk_mesh_transform_dtype(nk_dtype_t dtype) {
|
|
438
441
|
switch (dtype) {
|
|
439
442
|
case nk_f64_k: return nk_f64_k;
|
|
440
|
-
case nk_f32_k: return
|
|
443
|
+
case nk_f32_k: return nk_f32_k;
|
|
441
444
|
case nk_f16_k: return nk_f32_k;
|
|
442
445
|
case nk_bf16_k: return nk_f32_k;
|
|
443
446
|
default: return nk_dtype_unknown_k;
|
|
@@ -450,7 +453,6 @@ NK_INTERNAL nk_dtype_t nk_umeyama_output_dtype(nk_dtype_t dtype) {
|
|
|
450
453
|
|
|
451
454
|
#include "numkong/mesh/serial.h"
|
|
452
455
|
#include "numkong/mesh/neon.h"
|
|
453
|
-
#include "numkong/mesh/neonhalf.h"
|
|
454
456
|
#include "numkong/mesh/neonbfdot.h"
|
|
455
457
|
#include "numkong/mesh/haswell.h"
|
|
456
458
|
#include "numkong/mesh/skylake.h"
|
|
@@ -499,10 +501,12 @@ NK_PUBLIC void nk_rmsd_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk
|
|
|
499
501
|
|
|
500
502
|
NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
501
503
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
502
|
-
#if
|
|
504
|
+
#if NK_TARGET_SKYLAKE
|
|
505
|
+
nk_rmsd_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
506
|
+
#elif NK_TARGET_HASWELL
|
|
503
507
|
nk_rmsd_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
504
|
-
#elif
|
|
505
|
-
|
|
508
|
+
#elif NK_TARGET_NEON
|
|
509
|
+
nk_rmsd_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
506
510
|
#elif NK_TARGET_RVV
|
|
507
511
|
nk_rmsd_f16_rvv(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
508
512
|
#else
|
|
@@ -512,7 +516,9 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
|
|
|
512
516
|
|
|
513
517
|
NK_PUBLIC void nk_rmsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
514
518
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
515
|
-
#if
|
|
519
|
+
#if NK_TARGET_SKYLAKE
|
|
520
|
+
nk_rmsd_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
521
|
+
#elif NK_TARGET_HASWELL
|
|
516
522
|
nk_rmsd_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
517
523
|
#elif NK_TARGET_NEONBFDOT
|
|
518
524
|
nk_rmsd_bf16_neonbfdot(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
@@ -559,10 +565,12 @@ NK_PUBLIC void nk_kabsch_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
|
|
|
559
565
|
|
|
560
566
|
NK_PUBLIC void nk_kabsch_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
561
567
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
562
|
-
#if
|
|
568
|
+
#if NK_TARGET_SKYLAKE
|
|
569
|
+
nk_kabsch_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
570
|
+
#elif NK_TARGET_HASWELL
|
|
563
571
|
nk_kabsch_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
564
|
-
#elif
|
|
565
|
-
|
|
572
|
+
#elif NK_TARGET_NEON
|
|
573
|
+
nk_kabsch_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
566
574
|
#elif NK_TARGET_RVV
|
|
567
575
|
nk_kabsch_f16_rvv(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
568
576
|
#else
|
|
@@ -572,7 +580,9 @@ NK_PUBLIC void nk_kabsch_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
|
572
580
|
|
|
573
581
|
NK_PUBLIC void nk_kabsch_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
574
582
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
575
|
-
#if
|
|
583
|
+
#if NK_TARGET_SKYLAKE
|
|
584
|
+
nk_kabsch_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
585
|
+
#elif NK_TARGET_HASWELL
|
|
576
586
|
nk_kabsch_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
577
587
|
#elif NK_TARGET_NEONBFDOT
|
|
578
588
|
nk_kabsch_bf16_neonbfdot(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
@@ -619,10 +629,12 @@ NK_PUBLIC void nk_umeyama_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
|
|
|
619
629
|
|
|
620
630
|
NK_PUBLIC void nk_umeyama_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
621
631
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
622
|
-
#if
|
|
632
|
+
#if NK_TARGET_SKYLAKE
|
|
633
|
+
nk_umeyama_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
634
|
+
#elif NK_TARGET_HASWELL
|
|
623
635
|
nk_umeyama_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
624
|
-
#elif
|
|
625
|
-
|
|
636
|
+
#elif NK_TARGET_NEON
|
|
637
|
+
nk_umeyama_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
626
638
|
#elif NK_TARGET_RVV
|
|
627
639
|
nk_umeyama_f16_rvv(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
628
640
|
#else
|
|
@@ -632,7 +644,9 @@ NK_PUBLIC void nk_umeyama_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
|
632
644
|
|
|
633
645
|
NK_PUBLIC void nk_umeyama_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
634
646
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
635
|
-
#if
|
|
647
|
+
#if NK_TARGET_SKYLAKE
|
|
648
|
+
nk_umeyama_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
649
|
+
#elif NK_TARGET_HASWELL
|
|
636
650
|
nk_umeyama_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
637
651
|
#elif NK_TARGET_NEONBFDOT
|
|
638
652
|
nk_umeyama_bf16_neonbfdot(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
package/include/numkong/mesh.hpp
CHANGED
|
@@ -17,7 +17,7 @@
|
|
|
17
17
|
|
|
18
18
|
namespace ashvardanian::numkong {
|
|
19
19
|
|
|
20
|
-
#pragma region
|
|
20
|
+
#pragma region SVD Helpers for Scalar Fallbacks
|
|
21
21
|
|
|
22
22
|
/** @brief 3x3 matrix determinant. */
|
|
23
23
|
template <typename scalar_type_>
|
|
@@ -313,9 +313,9 @@ void svd3x3_(scalar_type_ const *a, scalar_type_ *svd_u, scalar_type_ *svd_s, sc
|
|
|
313
313
|
svd_s[8] = s3_sq.sqrt();
|
|
314
314
|
}
|
|
315
315
|
|
|
316
|
-
#pragma endregion
|
|
316
|
+
#pragma endregion SVD Helpers for Scalar Fallbacks
|
|
317
317
|
|
|
318
|
-
#pragma region
|
|
318
|
+
#pragma region Mesh Alignment Kernels
|
|
319
319
|
|
|
320
320
|
/**
|
|
321
321
|
* @brief Root Mean Square Deviation between two 3D point clouds (no alignment)
|
|
@@ -755,7 +755,7 @@ void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type
|
|
|
755
755
|
}
|
|
756
756
|
}
|
|
757
757
|
|
|
758
|
-
#pragma endregion
|
|
758
|
+
#pragma endregion Mesh Alignment Kernels
|
|
759
759
|
|
|
760
760
|
} // namespace ashvardanian::numkong
|
|
761
761
|
|
|
@@ -62,9 +62,9 @@ NK_PUBLIC nk_dtype_t nk_kernel_output_dtype(nk_kernel_kind_t kind, nk_dtype_t in
|
|
|
62
62
|
case nk_kernel_vincenty_k: return nk_vincenty_output_dtype(input);
|
|
63
63
|
case nk_kernel_kld_k:
|
|
64
64
|
case nk_kernel_jsd_k: return nk_probability_output_dtype(input);
|
|
65
|
-
case nk_kernel_rmsd_k:
|
|
66
|
-
case nk_kernel_kabsch_k:
|
|
67
|
-
case nk_kernel_umeyama_k: return
|
|
65
|
+
case nk_kernel_rmsd_k:
|
|
66
|
+
case nk_kernel_kabsch_k:
|
|
67
|
+
case nk_kernel_umeyama_k: return nk_mesh_metric_dtype(input);
|
|
68
68
|
case nk_kernel_sparse_dot_k: return nk_sparse_dot_output_dtype(input);
|
|
69
69
|
case nk_kernel_maxsim_packed_k: return nk_maxsim_output_dtype(input);
|
|
70
70
|
default: return nk_dtype_unknown_k;
|
|
@@ -5,17 +5,21 @@ These are used in variational inference, topic modeling, and distribution compar
|
|
|
5
5
|
|
|
6
6
|
Kullback-Leibler divergence from $P$ to $Q$:
|
|
7
7
|
|
|
8
|
-
|
|
8
|
+
$$
|
|
9
9
|
\text{KLD}(P \| Q) = \sum_{i=0}^{n-1} P(i) \log_2 \frac{P(i)}{Q(i)}
|
|
10
|
-
|
|
10
|
+
$$
|
|
11
11
|
|
|
12
12
|
Jensen-Shannon distance is the square root of the symmetrized KLD through a mixture:
|
|
13
13
|
|
|
14
|
-
|
|
14
|
+
$$
|
|
15
|
+
\text{JSD}(P, Q) = \frac{1}{2} \text{KLD}(P \| M) + \frac{1}{2} \text{KLD}(Q \| M)
|
|
16
|
+
$$
|
|
15
17
|
|
|
16
18
|
where $M = \frac{P + Q}{2}$, yielding the distance:
|
|
17
19
|
|
|
18
|
-
$$
|
|
20
|
+
$$
|
|
21
|
+
d_{JS}(P, Q) = \sqrt{\text{JSD}(P, Q)}
|
|
22
|
+
$$
|
|
19
23
|
|
|
20
24
|
Unlike the raw divergence, $d_{JS}$ is a true metric satisfying the triangle inequality.
|
|
21
25
|
|
|
@@ -35,9 +39,9 @@ def jsd(p: np.ndarray, q: np.ndarray) -> float:
|
|
|
35
39
|
|
|
36
40
|
## Use Cases
|
|
37
41
|
|
|
38
|
-
__Kullback-Leibler divergence__ is
|
|
42
|
+
__Kullback-Leibler divergence__ is widely used in variational inference (ELBO objective), knowledge distillation between neural networks, information gain in decision trees, and measuring fit between a model and observed data.
|
|
39
43
|
|
|
40
|
-
__Jensen-Shannon distance__
|
|
44
|
+
__Jensen-Shannon distance__ is commonly used in microbiome community comparison (enterotyping), where its metric property enables clustering with standard algorithms. It also appears in distribution drift detection, topic model evaluation, and as the theoretical foundation of the original GAN objective — though in practice GAN training uses proxy losses rather than computing JSD directly.
|
|
41
45
|
|
|
42
46
|
## Input & Output Types
|
|
43
47
|
|
|
@@ -149,25 +153,25 @@ Measured with Wasmtime v42 (Cranelift backend).
|
|
|
149
153
|
| `nk_kld_f16_serial` | 0.118 gb/s, 1.04K ulp | 0.127 gb/s, 4.53K ulp | 0.111 gb/s, 18.3K ulp |
|
|
150
154
|
| `nk_jsd_f16_serial` | 0.0748 gb/s, 1.4 ulp | 0.0681 gb/s, 2.6 ulp | 0.0857 gb/s, 9.7 ulp |
|
|
151
155
|
|
|
152
|
-
### Apple
|
|
156
|
+
### Apple M5
|
|
153
157
|
|
|
154
158
|
#### Native
|
|
155
159
|
|
|
156
160
|
| Kernel | 256 | 1024 | 4096 |
|
|
157
161
|
| :-------------------- | -----------------------: | -----------------------: | -----------------------: |
|
|
158
162
|
| __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
159
|
-
| `nk_kld_f64_serial` |
|
|
160
|
-
| `nk_jsd_f64_serial` |
|
|
163
|
+
| `nk_kld_f64_serial` | 3.22 gb/s, 5.6K ulp | 3.36 gb/s, 25K ulp | 3.32 gb/s, 99K ulp |
|
|
164
|
+
| `nk_jsd_f64_serial` | 2.06 gb/s, 0.4 ulp | 2.17 gb/s, 0.4 ulp | 2.17 gb/s, 0.5 ulp |
|
|
161
165
|
| __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
162
|
-
| `nk_kld_f32_serial` |
|
|
163
|
-
| `nk_jsd_f32_serial` |
|
|
164
|
-
| `nk_kld_f32_neon` |
|
|
165
|
-
| `nk_jsd_f32_neon` |
|
|
166
|
+
| `nk_kld_f32_serial` | 9.26 gb/s, 1.0K ulp | 8.73 gb/s, 4.5K ulp | 9.10 gb/s, 18K ulp |
|
|
167
|
+
| `nk_jsd_f32_serial` | 2.08 gb/s, 0.4 ulp | 2.16 gb/s, 0.4 ulp | 2.13 gb/s, 4.6 ulp |
|
|
168
|
+
| `nk_kld_f32_neon` | 19.0 gb/s, 1.0K ulp | 17.4 gb/s, 4.5K ulp | 18.1 gb/s, 18K ulp |
|
|
169
|
+
| `nk_jsd_f32_neon` | 9.75 gb/s, 15 ulp | 9.32 gb/s, 14 ulp | 9.62 gb/s, 9.9 ulp |
|
|
166
170
|
| __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
167
|
-
| `nk_kld_bf16_serial` |
|
|
168
|
-
| `nk_jsd_bf16_serial` |
|
|
171
|
+
| `nk_kld_bf16_serial` | 4.58 gb/s, 1.0K ulp | 4.47 gb/s, 4.5K ulp | 4.65 gb/s, 18K ulp |
|
|
172
|
+
| `nk_jsd_bf16_serial` | 1.08 gb/s, 1.4 ulp | 1.07 gb/s, 2.9 ulp | 1.09 gb/s, 9.7 ulp |
|
|
169
173
|
| __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
|
|
170
|
-
| `nk_kld_f16_serial` |
|
|
171
|
-
| `nk_jsd_f16_serial` |
|
|
172
|
-
| `nk_kld_f16_neonhalf` |
|
|
173
|
-
| `nk_jsd_f16_neonhalf` |
|
|
174
|
+
| `nk_kld_f16_serial` | 4.63 gb/s, 1.0K ulp | 4.45 gb/s, 4.5K ulp | 4.55 gb/s, 18K ulp |
|
|
175
|
+
| `nk_jsd_f16_serial` | 1.03 gb/s, 1.4 ulp | 0.962 gb/s, 2.7 ulp | 0.976 gb/s, 8.7 ulp |
|
|
176
|
+
| `nk_kld_f16_neonhalf` | 10.2 gb/s, 1.0K ulp | 9.67 gb/s, 4.5K ulp | 9.99 gb/s, 18K ulp |
|
|
177
|
+
| `nk_jsd_f16_neonhalf` | 5.00 gb/s, 15 ulp | 4.79 gb/s, 14 ulp | 4.94 gb/s, 9.9 ulp |
|
|
@@ -57,8 +57,8 @@ NK_PUBLIC float32x4_t nk_log2_f32x4_neon_(float32x4_t x) {
|
|
|
57
57
|
NK_PUBLIC void nk_kld_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
58
58
|
nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
|
|
59
59
|
float32x4_t epsilon_f32x4 = vdupq_n_f32(epsilon);
|
|
60
|
-
float64x2_t
|
|
61
|
-
float64x2_t
|
|
60
|
+
float64x2_t sum_low_f64x2 = vdupq_n_f64(0.0);
|
|
61
|
+
float64x2_t sum_high_f64x2 = vdupq_n_f64(0.0);
|
|
62
62
|
float32x4_t a_f32x4, b_f32x4;
|
|
63
63
|
|
|
64
64
|
nk_kld_f32_neon_cycle:
|
|
@@ -79,20 +79,20 @@ nk_kld_f32_neon_cycle:
|
|
|
79
79
|
float32x4_t ratio_f32x4 = vdivq_f32(vaddq_f32(a_f32x4, epsilon_f32x4), vaddq_f32(b_f32x4, epsilon_f32x4));
|
|
80
80
|
float32x4_t log_ratio_f32x4 = nk_log2_f32x4_neon_(ratio_f32x4);
|
|
81
81
|
float32x4_t contribution_f32x4 = vmulq_f32(a_f32x4, log_ratio_f32x4);
|
|
82
|
-
|
|
83
|
-
|
|
82
|
+
sum_low_f64x2 = vaddq_f64(sum_low_f64x2, vcvt_f64_f32(vget_low_f32(contribution_f32x4)));
|
|
83
|
+
sum_high_f64x2 = vaddq_f64(sum_high_f64x2, vcvt_high_f64_f32(contribution_f32x4));
|
|
84
84
|
if (n != 0) goto nk_kld_f32_neon_cycle;
|
|
85
85
|
|
|
86
86
|
nk_f64_t log2_normalizer = 0.6931471805599453;
|
|
87
|
-
nk_f64_t sum = vaddvq_f64(vaddq_f64(
|
|
87
|
+
nk_f64_t sum = vaddvq_f64(vaddq_f64(sum_low_f64x2, sum_high_f64x2)) * log2_normalizer;
|
|
88
88
|
*result = sum;
|
|
89
89
|
}
|
|
90
90
|
|
|
91
91
|
NK_PUBLIC void nk_jsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
|
|
92
92
|
nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
|
|
93
93
|
float32x4_t epsilon_f32x4 = vdupq_n_f32(epsilon);
|
|
94
|
-
float64x2_t
|
|
95
|
-
float64x2_t
|
|
94
|
+
float64x2_t sum_low_f64x2 = vdupq_n_f64(0.0);
|
|
95
|
+
float64x2_t sum_high_f64x2 = vdupq_n_f64(0.0);
|
|
96
96
|
float32x4_t a_f32x4, b_f32x4;
|
|
97
97
|
|
|
98
98
|
nk_jsd_f32_neon_cycle:
|
|
@@ -118,12 +118,12 @@ nk_jsd_f32_neon_cycle:
|
|
|
118
118
|
float32x4_t contribution_a_f32x4 = vmulq_f32(a_f32x4, log_ratio_a_f32x4);
|
|
119
119
|
float32x4_t contribution_b_f32x4 = vmulq_f32(b_f32x4, log_ratio_b_f32x4);
|
|
120
120
|
float32x4_t contribution_f32x4 = vaddq_f32(contribution_a_f32x4, contribution_b_f32x4);
|
|
121
|
-
|
|
122
|
-
|
|
121
|
+
sum_low_f64x2 = vaddq_f64(sum_low_f64x2, vcvt_f64_f32(vget_low_f32(contribution_f32x4)));
|
|
122
|
+
sum_high_f64x2 = vaddq_f64(sum_high_f64x2, vcvt_high_f64_f32(contribution_f32x4));
|
|
123
123
|
if (n != 0) goto nk_jsd_f32_neon_cycle;
|
|
124
124
|
|
|
125
125
|
nk_f64_t log2_normalizer = 0.6931471805599453;
|
|
126
|
-
nk_f64_t sum = vaddvq_f64(vaddq_f64(
|
|
126
|
+
nk_f64_t sum = vaddvq_f64(vaddq_f64(sum_low_f64x2, sum_high_f64x2)) * log2_normalizer / 2.0;
|
|
127
127
|
*result = sum > 0 ? nk_f64_sqrt_neon(sum) : 0;
|
|
128
128
|
}
|
|
129
129
|
|
|
@@ -134,76 +134,106 @@ nk_jsd_f32_neon_cycle:
|
|
|
134
134
|
#endif
|
|
135
135
|
#endif // NK_TARGET_NEON
|
|
136
136
|
|
|
137
|
-
#if
|
|
137
|
+
#if NK_TARGET_NEON
|
|
138
138
|
#if defined(__clang__)
|
|
139
|
-
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd
|
|
139
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function)
|
|
140
140
|
#elif defined(__GNUC__)
|
|
141
141
|
#pragma GCC push_options
|
|
142
|
-
#pragma GCC target("arch=armv8.2-a+simd
|
|
142
|
+
#pragma GCC target("arch=armv8.2-a+simd")
|
|
143
143
|
#endif
|
|
144
144
|
|
|
145
|
-
NK_PUBLIC void
|
|
145
|
+
NK_PUBLIC void nk_kld_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
146
146
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
147
147
|
nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
|
|
148
148
|
float32x4_t epsilon_f32x4 = vdupq_n_f32(epsilon);
|
|
149
|
-
float32x4_t
|
|
149
|
+
float32x4_t a_low_f32x4, a_high_f32x4, b_low_f32x4, b_high_f32x4;
|
|
150
150
|
|
|
151
|
-
|
|
152
|
-
if (n <
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
151
|
+
nk_kld_f16_neon_cycle:
|
|
152
|
+
if (n < 8) {
|
|
153
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
154
|
+
nk_partial_load_b16x8_serial_(a, &a_vec, n);
|
|
155
|
+
nk_partial_load_b16x8_serial_(b, &b_vec, n);
|
|
156
|
+
float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_vec.u16x8);
|
|
157
|
+
float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_vec.u16x8);
|
|
158
|
+
a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
159
|
+
a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
160
|
+
b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
161
|
+
b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
158
162
|
n = 0;
|
|
159
163
|
}
|
|
160
164
|
else {
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
165
|
+
float16x8_t a_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)a));
|
|
166
|
+
float16x8_t b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)b));
|
|
167
|
+
a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
168
|
+
a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
169
|
+
b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
170
|
+
b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
171
|
+
n -= 8, a += 8, b += 8;
|
|
164
172
|
}
|
|
165
173
|
|
|
166
|
-
float32x4_t
|
|
167
|
-
|
|
168
|
-
float32x4_t
|
|
169
|
-
|
|
170
|
-
|
|
174
|
+
float32x4_t ratio_low_f32x4 = vdivq_f32(vaddq_f32(a_low_f32x4, epsilon_f32x4),
|
|
175
|
+
vaddq_f32(b_low_f32x4, epsilon_f32x4));
|
|
176
|
+
float32x4_t ratio_high_f32x4 = vdivq_f32(vaddq_f32(a_high_f32x4, epsilon_f32x4),
|
|
177
|
+
vaddq_f32(b_high_f32x4, epsilon_f32x4));
|
|
178
|
+
float32x4_t log_ratio_low_f32x4 = nk_log2_f32x4_neon_(ratio_low_f32x4);
|
|
179
|
+
float32x4_t log_ratio_high_f32x4 = nk_log2_f32x4_neon_(ratio_high_f32x4);
|
|
180
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, log_ratio_low_f32x4);
|
|
181
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, log_ratio_high_f32x4);
|
|
182
|
+
if (n) goto nk_kld_f16_neon_cycle;
|
|
171
183
|
|
|
172
184
|
nk_f32_t log2_normalizer = 0.693147181f;
|
|
173
185
|
nk_f32_t sum = vaddvq_f32(sum_f32x4) * log2_normalizer;
|
|
174
186
|
*result = sum;
|
|
175
187
|
}
|
|
176
188
|
|
|
177
|
-
NK_PUBLIC void
|
|
189
|
+
NK_PUBLIC void nk_jsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
178
190
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
179
191
|
nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
|
|
180
192
|
float32x4_t epsilon_f32x4 = vdupq_n_f32(epsilon);
|
|
181
|
-
float32x4_t
|
|
193
|
+
float32x4_t a_low_f32x4, a_high_f32x4, b_low_f32x4, b_high_f32x4;
|
|
182
194
|
|
|
183
|
-
|
|
184
|
-
if (n <
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
195
|
+
nk_jsd_f16_neon_cycle:
|
|
196
|
+
if (n < 8) {
|
|
197
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
198
|
+
nk_partial_load_b16x8_serial_(a, &a_vec, n);
|
|
199
|
+
nk_partial_load_b16x8_serial_(b, &b_vec, n);
|
|
200
|
+
float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_vec.u16x8);
|
|
201
|
+
float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_vec.u16x8);
|
|
202
|
+
a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
203
|
+
a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
204
|
+
b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
205
|
+
b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
190
206
|
n = 0;
|
|
191
207
|
}
|
|
192
208
|
else {
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
209
|
+
float16x8_t a_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)a));
|
|
210
|
+
float16x8_t b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)b));
|
|
211
|
+
a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
|
|
212
|
+
a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
|
|
213
|
+
b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
|
|
214
|
+
b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
|
|
215
|
+
n -= 8, a += 8, b += 8;
|
|
196
216
|
}
|
|
197
217
|
|
|
198
|
-
float32x4_t
|
|
199
|
-
float32x4_t
|
|
200
|
-
float32x4_t
|
|
201
|
-
|
|
202
|
-
float32x4_t
|
|
203
|
-
|
|
204
|
-
float32x4_t
|
|
205
|
-
|
|
206
|
-
|
|
218
|
+
float32x4_t mean_low_f32x4 = vmulq_n_f32(vaddq_f32(a_low_f32x4, b_low_f32x4), 0.5f);
|
|
219
|
+
float32x4_t mean_high_f32x4 = vmulq_n_f32(vaddq_f32(a_high_f32x4, b_high_f32x4), 0.5f);
|
|
220
|
+
float32x4_t ratio_a_low_f32x4 = vdivq_f32(vaddq_f32(a_low_f32x4, epsilon_f32x4),
|
|
221
|
+
vaddq_f32(mean_low_f32x4, epsilon_f32x4));
|
|
222
|
+
float32x4_t ratio_a_high_f32x4 = vdivq_f32(vaddq_f32(a_high_f32x4, epsilon_f32x4),
|
|
223
|
+
vaddq_f32(mean_high_f32x4, epsilon_f32x4));
|
|
224
|
+
float32x4_t ratio_b_low_f32x4 = vdivq_f32(vaddq_f32(b_low_f32x4, epsilon_f32x4),
|
|
225
|
+
vaddq_f32(mean_low_f32x4, epsilon_f32x4));
|
|
226
|
+
float32x4_t ratio_b_high_f32x4 = vdivq_f32(vaddq_f32(b_high_f32x4, epsilon_f32x4),
|
|
227
|
+
vaddq_f32(mean_high_f32x4, epsilon_f32x4));
|
|
228
|
+
float32x4_t log_ratio_a_low_f32x4 = nk_log2_f32x4_neon_(ratio_a_low_f32x4);
|
|
229
|
+
float32x4_t log_ratio_a_high_f32x4 = nk_log2_f32x4_neon_(ratio_a_high_f32x4);
|
|
230
|
+
float32x4_t log_ratio_b_low_f32x4 = nk_log2_f32x4_neon_(ratio_b_low_f32x4);
|
|
231
|
+
float32x4_t log_ratio_b_high_f32x4 = nk_log2_f32x4_neon_(ratio_b_high_f32x4);
|
|
232
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, log_ratio_a_low_f32x4);
|
|
233
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, log_ratio_a_high_f32x4);
|
|
234
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, b_low_f32x4, log_ratio_b_low_f32x4);
|
|
235
|
+
sum_f32x4 = vfmaq_f32(sum_f32x4, b_high_f32x4, log_ratio_b_high_f32x4);
|
|
236
|
+
if (n) goto nk_jsd_f16_neon_cycle;
|
|
207
237
|
|
|
208
238
|
nk_f32_t log2_normalizer = 0.693147181f;
|
|
209
239
|
nk_f32_t sum = vaddvq_f32(sum_f32x4) * log2_normalizer / 2;
|
|
@@ -215,7 +245,7 @@ nk_jsd_f16_neonhalf_cycle:
|
|
|
215
245
|
#elif defined(__GNUC__)
|
|
216
246
|
#pragma GCC pop_options
|
|
217
247
|
#endif
|
|
218
|
-
#endif //
|
|
248
|
+
#endif // NK_TARGET_NEON
|
|
219
249
|
|
|
220
250
|
#if defined(__cplusplus)
|
|
221
251
|
} // extern "C"
|