numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,34 +8,34 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section neon_cast_instructions ARM NEON Conversion Instructions
|
|
10
10
|
*
|
|
11
|
-
* Float ↔ integer conversions
|
|
11
|
+
* Float ↔ integer conversions:
|
|
12
12
|
*
|
|
13
|
-
* Intrinsic
|
|
14
|
-
* vcvtq_f32_s32
|
|
15
|
-
* vcvtq_f32_u32
|
|
16
|
-
* vcvtq_s32_f32
|
|
17
|
-
* vcvtq_u32_f32
|
|
13
|
+
* Intrinsic Instruction A76 M5
|
|
14
|
+
* vcvtq_f32_s32 SCVTF (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
15
|
+
* vcvtq_f32_u32 UCVTF (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
16
|
+
* vcvtq_s32_f32 FCVTZS (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
17
|
+
* vcvtq_u32_f32 FCVTZU (V.4S, V.4S) 3cy @ 2p 3cy @ 4p
|
|
18
18
|
*
|
|
19
19
|
* Float precision conversions:
|
|
20
20
|
*
|
|
21
|
-
* Intrinsic
|
|
22
|
-
* vcvt_f32_f16
|
|
23
|
-
* vcvt_f16_f32
|
|
24
|
-
* vcvt_f64_f32
|
|
25
|
-
* vcvt_f32_f64
|
|
21
|
+
* Intrinsic Instruction A76 M5
|
|
22
|
+
* vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy @ 2p 3cy @ 4p
|
|
23
|
+
* vcvt_f16_f32 FCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
24
|
+
* vcvt_f64_f32 FCVTL (V.2D, V.2S) 3cy @ 2p 3cy @ 4p
|
|
25
|
+
* vcvt_f32_f64 FCVTN (V.2S, V.2D) 3cy @ 2p 3cy @ 4p
|
|
26
26
|
*
|
|
27
27
|
* Integer narrowing with saturation:
|
|
28
28
|
*
|
|
29
|
-
* Intrinsic
|
|
30
|
-
* vqmovn_s32
|
|
31
|
-
* vqmovn_u32
|
|
32
|
-
* vqmovun_s32
|
|
29
|
+
* Intrinsic Instruction A76 M5
|
|
30
|
+
* vqmovn_s32 SQXTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
31
|
+
* vqmovn_u32 UQXTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
32
|
+
* vqmovun_s32 SQXTUN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
33
33
|
*
|
|
34
34
|
* BF16 support (ARMv8.6-A+):
|
|
35
35
|
*
|
|
36
|
-
* Intrinsic
|
|
37
|
-
* vcvtq_low_bf16_f32
|
|
38
|
-
* vcvtq_high_bf16_f32
|
|
36
|
+
* Intrinsic Instruction A76 M5
|
|
37
|
+
* vcvtq_low_bf16_f32 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
38
|
+
* vcvtq_high_bf16_f32 BFCVTN2 (V.8H, V.4S) 3cy @ 2p 3cy @ 4p
|
|
39
39
|
*
|
|
40
40
|
* BF16 conversions on baseline NEON (emulated via bit shifts):
|
|
41
41
|
* - bf16 → f32: vmovl_u16 + vshlq_n_u32 by 16
|
|
@@ -68,18 +68,18 @@ extern "C" {
|
|
|
68
68
|
#endif
|
|
69
69
|
|
|
70
70
|
NK_PUBLIC void nk_f16_to_f32_neon(nk_f16_t const *src, nk_f32_t *dest) {
|
|
71
|
-
float16x4_t
|
|
72
|
-
float32x4_t
|
|
73
|
-
*dest = vgetq_lane_f32(
|
|
71
|
+
float16x4_t f16_f16x4 = vreinterpret_f16_u16(vld1_dup_u16((nk_u16_t const *)src));
|
|
72
|
+
float32x4_t f32_f32x4 = vcvt_f32_f16(f16_f16x4);
|
|
73
|
+
*dest = vgetq_lane_f32(f32_f32x4, 0);
|
|
74
74
|
}
|
|
75
75
|
|
|
76
76
|
NK_PUBLIC void nk_f32_to_f16_neon(nk_f32_t const *src, nk_f16_t *dest) {
|
|
77
|
-
float32x4_t
|
|
78
|
-
float16x4_t
|
|
79
|
-
vst1_lane_u16((nk_u16_t *)dest, vreinterpret_u16_f16(
|
|
77
|
+
float32x4_t f32_f32x4 = vdupq_n_f32(*src);
|
|
78
|
+
float16x4_t f16_f16x4 = vcvt_f16_f32(f32_f32x4);
|
|
79
|
+
vst1_lane_u16((nk_u16_t *)dest, vreinterpret_u16_f16(f16_f16x4), 0);
|
|
80
80
|
}
|
|
81
81
|
|
|
82
|
-
#pragma region
|
|
82
|
+
#pragma region Type Punned Loads and Stores
|
|
83
83
|
|
|
84
84
|
/** @brief Type-agnostic 128-bit full load (NEON). */
|
|
85
85
|
NK_INTERNAL void nk_load_b128_neon_(void const *src, nk_b128_vec_t *dst) {
|
|
@@ -104,73 +104,64 @@ NK_INTERNAL void nk_store_b256_neon_(nk_b256_vec_t const *src, void *dst) {
|
|
|
104
104
|
/** @brief Type-agnostic 64-bit full load (NEON). */
|
|
105
105
|
NK_INTERNAL void nk_load_b64_neon_(void const *src, nk_b64_vec_t *dst) { dst->u8x8 = vld1_u8((nk_u8_t const *)src); }
|
|
106
106
|
|
|
107
|
-
#pragma endregion
|
|
107
|
+
#pragma endregion Type Punned Loads and Stores
|
|
108
108
|
|
|
109
|
-
#pragma region
|
|
109
|
+
#pragma region Vectorized Conversions
|
|
110
110
|
|
|
111
|
-
/** @brief Convert 4x e4m3 → f32x4 via
|
|
112
|
-
*
|
|
113
|
-
*
|
|
111
|
+
/** @brief Convert 4x e4m3 → f32x4 via Giesen magic-multiply (NEON).
|
|
112
|
+
* Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
|
|
113
|
+
* Handles zero, subnormals, and normals in a single VMUL. NaN fixup for magnitude 0x7F.
|
|
114
|
+
* https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/ */
|
|
114
115
|
NK_INTERNAL float32x4_t nk_e4m3x4_to_f32x4_neon_(nk_b32_vec_t src) {
|
|
115
116
|
uint8x8_t e4m3_u8x8 = vcreate_u8(src.u32);
|
|
116
117
|
uint16x8_t e4m3_u16x8 = vmovl_u8(e4m3_u8x8);
|
|
117
118
|
uint32x4_t e4m3_u32x4 = vmovl_u16(vget_low_u16(e4m3_u16x8));
|
|
118
|
-
uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e4m3_u32x4, vdupq_n_u32(0x80)), 24);
|
|
119
|
-
uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e4m3_u32x4, 3), vdupq_n_u32(0x0F));
|
|
120
|
-
uint32x4_t mant_u32x4 = vandq_u32(e4m3_u32x4, vdupq_n_u32(0x07));
|
|
121
119
|
|
|
122
|
-
//
|
|
123
|
-
uint32x4_t
|
|
124
|
-
|
|
125
|
-
uint32x4_t
|
|
120
|
+
// Extract sign: (raw & 0x80) << 24 → f32 sign bit
|
|
121
|
+
uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e4m3_u32x4, vdupq_n_u32(0x80)), 24);
|
|
122
|
+
// Strip sign to get 7-bit magnitude, shift left by 20 so E4M3 exponent overlaps f32 exponent
|
|
123
|
+
uint32x4_t nonsign_u32x4 = vandq_u32(e4m3_u32x4, vdupq_n_u32(0x7F));
|
|
124
|
+
uint32x4_t shifted_u32x4 = vshlq_n_u32(nonsign_u32x4, 20);
|
|
126
125
|
|
|
127
|
-
//
|
|
128
|
-
float32x4_t
|
|
129
|
-
|
|
126
|
+
// Magic multiply: reinterpret as f32 × 2^120 rebiases from E4M3 (bias=7) to f32 (bias=127).
|
|
127
|
+
float32x4_t result_f32x4 = vmulq_f32(vreinterpretq_f32_u32(shifted_u32x4),
|
|
128
|
+
vreinterpretq_f32_u32(vdupq_n_u32(0x7B800000))); // 2^120
|
|
130
129
|
|
|
131
|
-
// NaN
|
|
130
|
+
// NaN fixup: E4M3FN NaN only at magnitude 0x7F → force to f32 quiet NaN
|
|
131
|
+
uint32x4_t is_nan_mask_u32x4 = vceqq_u32(nonsign_u32x4, vdupq_n_u32(0x7F));
|
|
132
132
|
uint32x4_t nan_u32x4 = vorrq_u32(sign_u32x4, vdupq_n_u32(0x7FC00000));
|
|
133
|
-
uint32x4_t
|
|
133
|
+
uint32x4_t result_u32x4 = vbslq_u32(is_nan_mask_u32x4, nan_u32x4, vreinterpretq_u32_f32(result_f32x4));
|
|
134
134
|
|
|
135
|
-
//
|
|
136
|
-
|
|
137
|
-
uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
|
|
138
|
-
result_u32x4 = vbslq_u32(is_nan_mask, nan_u32x4, result_u32x4);
|
|
139
|
-
return vreinterpretq_f32_u32(result_u32x4);
|
|
135
|
+
// Restore sign
|
|
136
|
+
return vreinterpretq_f32_u32(vorrq_u32(result_u32x4, sign_u32x4));
|
|
140
137
|
}
|
|
141
138
|
|
|
142
|
-
/** @brief Convert 4x e5m2 → f32x4 via
|
|
143
|
-
*
|
|
144
|
-
* Handles subnormals
|
|
139
|
+
/** @brief Convert 4x e5m2 → f32x4 via Giesen magic-multiply (NEON).
|
|
140
|
+
* Reinterprets magnitude bits as a tiny f32, then multiplies by 2^(127-bias) to rebias.
|
|
141
|
+
* Handles zero, subnormals, and normals in a single VMUL. Inf/NaN fixup for exp=31.
|
|
142
|
+
* https://fgiesen.wordpress.com/2012/03/28/half-to-float-done-quic/ */
|
|
145
143
|
NK_INTERNAL float32x4_t nk_e5m2x4_to_f32x4_neon_(nk_b32_vec_t src) {
|
|
146
144
|
uint8x8_t e5m2_u8x8 = vcreate_u8(src.u32);
|
|
147
145
|
uint16x8_t e5m2_u16x8 = vmovl_u8(e5m2_u8x8);
|
|
148
146
|
uint32x4_t e5m2_u32x4 = vmovl_u16(vget_low_u16(e5m2_u16x8));
|
|
149
|
-
uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e5m2_u32x4, vdupq_n_u32(0x80)), 24);
|
|
150
|
-
uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e5m2_u32x4, 2), vdupq_n_u32(0x1F));
|
|
151
|
-
uint32x4_t mant_u32x4 = vandq_u32(e5m2_u32x4, vdupq_n_u32(0x03));
|
|
152
147
|
|
|
153
|
-
//
|
|
154
|
-
uint32x4_t
|
|
155
|
-
|
|
156
|
-
uint32x4_t
|
|
148
|
+
// Extract sign: (raw & 0x80) << 24 → f32 sign bit
|
|
149
|
+
uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e5m2_u32x4, vdupq_n_u32(0x80)), 24);
|
|
150
|
+
// Strip sign to get 7-bit magnitude, shift left by 21 so E5M2 exponent overlaps f32 exponent
|
|
151
|
+
uint32x4_t nonsign_u32x4 = vandq_u32(e5m2_u32x4, vdupq_n_u32(0x7F));
|
|
152
|
+
uint32x4_t shifted_u32x4 = vshlq_n_u32(nonsign_u32x4, 21);
|
|
157
153
|
|
|
158
|
-
//
|
|
159
|
-
float32x4_t
|
|
160
|
-
|
|
154
|
+
// Magic multiply: reinterpret as f32 × 2^112 rebiases from E5M2 (bias=15) to f32 (bias=127).
|
|
155
|
+
float32x4_t result_f32x4 = vmulq_f32(vreinterpretq_f32_u32(shifted_u32x4),
|
|
156
|
+
vreinterpretq_f32_u32(vdupq_n_u32(0x77800000))); // 2^112
|
|
161
157
|
|
|
162
|
-
//
|
|
163
|
-
uint32x4_t
|
|
164
|
-
uint32x4_t
|
|
165
|
-
|
|
166
|
-
uint32x4_t special_u32x4 = vbslq_u32(mant_zero_mask, infinity_u32x4, nan_u32x4);
|
|
158
|
+
// Inf/NaN fixup: nonsign > 123 means exp=31 → force f32 exponent to 255
|
|
159
|
+
uint32x4_t is_infnan_u32x4 = vcgtq_u32(nonsign_u32x4, vdupq_n_u32(123));
|
|
160
|
+
uint32x4_t result_u32x4 = vorrq_u32(vreinterpretq_u32_f32(result_f32x4),
|
|
161
|
+
vandq_u32(is_infnan_u32x4, vdupq_n_u32(0x7F800000)));
|
|
167
162
|
|
|
168
|
-
//
|
|
169
|
-
|
|
170
|
-
uint32x4_t exp_max_mask = vceqq_u32(exp_u32x4, vdupq_n_u32(31));
|
|
171
|
-
uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask, subnormal_u32x4, normal_u32x4);
|
|
172
|
-
result_u32x4 = vbslq_u32(exp_max_mask, special_u32x4, result_u32x4);
|
|
173
|
-
return vreinterpretq_f32_u32(result_u32x4);
|
|
163
|
+
// Restore sign
|
|
164
|
+
return vreinterpretq_f32_u32(vorrq_u32(result_u32x4, sign_u32x4));
|
|
174
165
|
}
|
|
175
166
|
|
|
176
167
|
/** @brief Convert 8x e4m3 → f16x8 via bit manipulation (NEON).
|
|
@@ -190,19 +181,20 @@ NK_INTERNAL float16x8_t nk_e4m3x8_to_f16x8_neon_(uint8x8_t e4m3_u8x8) {
|
|
|
190
181
|
// Subnormal path (exp=0, mant ≠ 0): E4M3 subnormal value = mant × 2⁻⁹ = mant ÷ 512
|
|
191
182
|
// Compute arithmetically: mant → f32 → multiply → f16
|
|
192
183
|
float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 512.0f);
|
|
193
|
-
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(
|
|
184
|
+
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 512.0f);
|
|
194
185
|
uint16x8_t subnormal_abs_u16x8 = vreinterpretq_u16_f16(
|
|
195
186
|
vcombine_f16(vcvt_f16_f32(subnormal_low_f32x4), vcvt_f16_f32(subnormal_high_f32x4)));
|
|
196
187
|
uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
|
|
197
188
|
|
|
198
189
|
// NaN path: E4M3FN only has NaN when exp=15 AND mant=7 (0x7F or 0xFF)
|
|
199
190
|
uint16x8_t nan_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7E00)); // F16 quiet NaN
|
|
200
|
-
uint16x8_t
|
|
191
|
+
uint16x8_t is_nan_mask_u16x8 = vandq_u16(vceqq_u16(exp_u16x8, vdupq_n_u16(15)),
|
|
192
|
+
vceqq_u16(mant_u16x8, vdupq_n_u16(7)));
|
|
201
193
|
|
|
202
194
|
// Blend paths: subnormal when exp=0, NaN when exp=15 && mant=7, else normal
|
|
203
|
-
uint16x8_t
|
|
204
|
-
uint16x8_t result_u16x8 = vbslq_u16(
|
|
205
|
-
result_u16x8 = vbslq_u16(
|
|
195
|
+
uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
|
|
196
|
+
uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
|
|
197
|
+
result_u16x8 = vbslq_u16(is_nan_mask_u16x8, nan_u16x8, result_u16x8);
|
|
206
198
|
return vreinterpretq_f16_u16(result_u16x8);
|
|
207
199
|
}
|
|
208
200
|
|
|
@@ -232,8 +224,8 @@ NK_INTERNAL void nk_e4m3x16_to_f16x8x2_neon_(uint8x16_t input_u8x16, float16x8_t
|
|
|
232
224
|
0x58, 0x58, 0x59, 0x59, 0x5A, 0x5A, 0x5B, 0x5B, 0x5C, 0x5C, 0x5D, 0x5D, 0x5E, 0x5E, 0x5F, 0x7E,
|
|
233
225
|
};
|
|
234
226
|
|
|
235
|
-
uint8x16x4_t
|
|
236
|
-
uint8x16x4_t
|
|
227
|
+
uint8x16x4_t lut_q0_u8x16x4 = vld1q_u8_x4(table_q0_u8x64);
|
|
228
|
+
uint8x16x4_t lut_q1_u8x16x4 = vld1q_u8_x4(table_q1_u8x64);
|
|
237
229
|
|
|
238
230
|
// Strip sign bit, work with 7-bit absolute value
|
|
239
231
|
uint8x16_t sign_u8x16 = vandq_u8(input_u8x16, vdupq_n_u8(0x80));
|
|
@@ -241,9 +233,9 @@ NK_INTERNAL void nk_e4m3x16_to_f16x8x2_neon_(uint8x16_t input_u8x16, float16x8_t
|
|
|
241
233
|
|
|
242
234
|
// High byte via 2× VQTBL4 on unsigned index, then OR sign back.
|
|
243
235
|
// VQTBL4 returns 0 for out-of-range indices (>= 64), so results OR together cleanly.
|
|
244
|
-
uint8x16_t high_q0_u8x16 = vqtbl4q_u8(
|
|
236
|
+
uint8x16_t high_q0_u8x16 = vqtbl4q_u8(lut_q0_u8x16x4, abs_u8x16);
|
|
245
237
|
uint8x16_t offset_q1_u8x16 = vsubq_u8(abs_u8x16, vdupq_n_u8(64));
|
|
246
|
-
uint8x16_t high_q1_u8x16 = vqtbl4q_u8(
|
|
238
|
+
uint8x16_t high_q1_u8x16 = vqtbl4q_u8(lut_q1_u8x16x4, offset_q1_u8x16);
|
|
247
239
|
uint8x16_t high_bytes_u8x16 = vorrq_u8(vorrq_u8(high_q0_u8x16, high_q1_u8x16), sign_u8x16);
|
|
248
240
|
|
|
249
241
|
// Low byte: (lsb << 7), masked to 0 for subnormals (exp=0) and NaN (exp=15, mant=7)
|
|
@@ -290,14 +282,14 @@ NK_INTERNAL float16x8_t nk_e2m3x8_to_f16x8_neon_(uint8x8_t e2m3_u8x8) {
|
|
|
290
282
|
// Subnormal path (exp=0): E2M3 subnormal = mant / 8
|
|
291
283
|
// Compute via f32: mant → f32 → multiply → f16
|
|
292
284
|
float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 8.0f);
|
|
293
|
-
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(
|
|
285
|
+
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 8.0f);
|
|
294
286
|
uint16x8_t subnormal_abs_u16x8 = vreinterpretq_u16_f16(
|
|
295
287
|
vcombine_f16(vcvt_f16_f32(subnormal_low_f32x4), vcvt_f16_f32(subnormal_high_f32x4)));
|
|
296
288
|
uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
|
|
297
289
|
|
|
298
290
|
// Blend: use subnormal result when exp=0, else normal
|
|
299
|
-
uint16x8_t
|
|
300
|
-
uint16x8_t result_u16x8 = vbslq_u16(
|
|
291
|
+
uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
|
|
292
|
+
uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
|
|
301
293
|
|
|
302
294
|
return vreinterpretq_f16_u16(result_u16x8);
|
|
303
295
|
}
|
|
@@ -323,14 +315,14 @@ NK_INTERNAL float16x8_t nk_e3m2x8_to_f16x8_neon_(uint8x8_t e3m2_u8x8) {
|
|
|
323
315
|
// Subnormal path (exp=0): E3M2 subnormal = mant × 2^(-2) × (1/4) = mant / 16
|
|
324
316
|
// Compute via f32: mant → f32 → multiply → f16
|
|
325
317
|
float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 16.0f);
|
|
326
|
-
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(
|
|
318
|
+
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 16.0f);
|
|
327
319
|
uint16x8_t subnormal_abs_u16x8 = vreinterpretq_u16_f16(
|
|
328
320
|
vcombine_f16(vcvt_f16_f32(subnormal_low_f32x4), vcvt_f16_f32(subnormal_high_f32x4)));
|
|
329
321
|
uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
|
|
330
322
|
|
|
331
323
|
// Blend: use subnormal result when exp=0, else normal
|
|
332
|
-
uint16x8_t
|
|
333
|
-
uint16x8_t result_u16x8 = vbslq_u16(
|
|
324
|
+
uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
|
|
325
|
+
uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
|
|
334
326
|
|
|
335
327
|
return vreinterpretq_f16_u16(result_u16x8);
|
|
336
328
|
}
|
|
@@ -442,43 +434,43 @@ NK_INTERNAL uint8x8_t nk_f16x8_to_e4m3x8_neon_(float16x8_t f16x8) {
|
|
|
442
434
|
uint16x8_t f16_mant_u16x8 = vandq_u16(bits_u16x8, vdupq_n_u16(0x03FF));
|
|
443
435
|
|
|
444
436
|
// Rebias exponent: F16 bias=15 → E4M3 bias=7, subtract 8
|
|
445
|
-
int16x8_t
|
|
437
|
+
int16x8_t e4m3_exp_i16x8 = vsubq_s16(vreinterpretq_s16_u16(f16_exp_u16x8), vdupq_n_s16(8));
|
|
446
438
|
|
|
447
439
|
// Detect special cases
|
|
448
|
-
uint16x8_t
|
|
449
|
-
uint16x8_t
|
|
450
|
-
uint16x8_t
|
|
451
|
-
uint16x8_t
|
|
452
|
-
uint16x8_t
|
|
440
|
+
uint16x8_t is_f16_zero_u16x8 = vceqq_u16(vandq_u16(bits_u16x8, vdupq_n_u16(0x7FFF)), vdupq_n_u16(0));
|
|
441
|
+
uint16x8_t is_f16_special_u16x8 = vceqq_u16(f16_exp_u16x8, vdupq_n_u16(31)); // inf or nan
|
|
442
|
+
uint16x8_t is_f16_nan_u16x8 = vandq_u16(is_f16_special_u16x8, vcgtq_u16(f16_mant_u16x8, vdupq_n_u16(0)));
|
|
443
|
+
uint16x8_t is_underflow_u16x8 = vcltq_s16(e4m3_exp_i16x8, vdupq_n_s16(1)); // exp < 1 → subnormal/zero
|
|
444
|
+
uint16x8_t is_overflow_u16x8 = vcgtq_s16(e4m3_exp_i16x8, vdupq_n_s16(15)); // exp > 15 → overflow
|
|
453
445
|
|
|
454
446
|
// Normal path with RNE rounding: round mantissa from 10 to 3 bits
|
|
455
447
|
// RNE: add (0x3F + lsb) where lsb = bit 7 of mantissa
|
|
456
448
|
uint16x8_t lsb_u16x8 = vandq_u16(vshrq_n_u16(f16_mant_u16x8, 7), vdupq_n_u16(1));
|
|
457
449
|
uint16x8_t rounded_mant_u16x8 = vaddq_u16(f16_mant_u16x8, vaddq_u16(vdupq_n_u16(0x3F), lsb_u16x8));
|
|
458
450
|
uint16x8_t carry_u16x8 = vshrq_n_u16(rounded_mant_u16x8, 10); // Mantissa overflow → carry to exponent
|
|
459
|
-
|
|
451
|
+
e4m3_exp_i16x8 = vaddq_s16(e4m3_exp_i16x8, vreinterpretq_s16_u16(carry_u16x8));
|
|
460
452
|
uint16x8_t e4m3_mant_u16x8 = vandq_u16(vshrq_n_u16(rounded_mant_u16x8, 7), vdupq_n_u16(0x07));
|
|
461
453
|
e4m3_mant_u16x8 = vbicq_u16(e4m3_mant_u16x8, vceqq_u16(carry_u16x8, vdupq_n_u16(1))); // Clear mant if carry
|
|
462
454
|
|
|
463
455
|
// Recheck overflow after rounding (carry might have pushed us over)
|
|
464
|
-
|
|
456
|
+
is_overflow_u16x8 = vorrq_u16(is_overflow_u16x8, vcgtq_s16(e4m3_exp_i16x8, vdupq_n_s16(15)));
|
|
465
457
|
|
|
466
458
|
// Clamp exponent to [1, 15] for normal values
|
|
467
|
-
int16x8_t
|
|
468
|
-
|
|
459
|
+
int16x8_t clamped_exp_i16x8 = vmaxq_s16(e4m3_exp_i16x8, vdupq_n_s16(1));
|
|
460
|
+
clamped_exp_i16x8 = vminq_s16(clamped_exp_i16x8, vdupq_n_s16(15));
|
|
469
461
|
|
|
470
462
|
// E4M3FN quirk: exp=15, mant=7 is NaN, so clamp mantissa to 6 when exp=15
|
|
471
|
-
uint16x8_t
|
|
472
|
-
e4m3_mant_u16x8 = vbslq_u16(
|
|
463
|
+
uint16x8_t is_max_exp_u16x8 = vceqq_s16(clamped_exp_i16x8, vdupq_n_s16(15));
|
|
464
|
+
e4m3_mant_u16x8 = vbslq_u16(is_max_exp_u16x8, vminq_u16(e4m3_mant_u16x8, vdupq_n_u16(6)), e4m3_mant_u16x8);
|
|
473
465
|
|
|
474
466
|
// Assemble normal result
|
|
475
467
|
uint16x8_t normal_result_u16x8 = vorrq_u16(
|
|
476
|
-
sign_byte_u16x8, vorrq_u16(vshlq_n_u16(vreinterpretq_u16_s16(
|
|
468
|
+
sign_byte_u16x8, vorrq_u16(vshlq_n_u16(vreinterpretq_u16_s16(clamped_exp_i16x8), 3), e4m3_mant_u16x8));
|
|
477
469
|
|
|
478
470
|
// Subnormal path: E4M3 subnormal = mant × 2⁻⁹
|
|
479
471
|
// Use float conversion for correctness: abs(f16) × 512, round to int, clamp to [0,7]
|
|
480
472
|
float32x4_t abs_low_f32x4 = vabsq_f32(vcvt_f32_f16(vget_low_f16(f16x8)));
|
|
481
|
-
float32x4_t abs_high_f32x4 = vabsq_f32(
|
|
473
|
+
float32x4_t abs_high_f32x4 = vabsq_f32(vcvt_high_f32_f16(f16x8));
|
|
482
474
|
float32x4_t scaled_low_f32x4 = vmulq_n_f32(abs_low_f32x4, 512.0f);
|
|
483
475
|
float32x4_t scaled_high_f32x4 = vmulq_n_f32(abs_high_f32x4, 512.0f);
|
|
484
476
|
int32x4_t subnormal_mantissa_low_i32x4 = vcvtnq_s32_f32(scaled_low_f32x4); // Round to nearest even
|
|
@@ -492,17 +484,18 @@ NK_INTERNAL uint8x8_t nk_f16x8_to_e4m3x8_neon_(float16x8_t f16x8) {
|
|
|
492
484
|
uint16x8_t subnormal_result_u16x8 = vorrq_u16(sign_byte_u16x8, subnormal_mant_u16x8);
|
|
493
485
|
|
|
494
486
|
// Special values: E4M3FN has no ∞, max normal = 0x7E (exp=15, mant=6 = 448)
|
|
495
|
-
uint16x8_t
|
|
496
|
-
uint16x8_t
|
|
497
|
-
uint16x8_t
|
|
487
|
+
uint16x8_t e4m3_max_u16x8 = vorrq_u16(sign_byte_u16x8, vdupq_n_u16(0x7E)); // ±448 (exp=15, mant=6)
|
|
488
|
+
uint16x8_t e4m3_nan_u16x8 = vorrq_u16(sign_byte_u16x8, vdupq_n_u16(0x7F)); // ±NaN (exp=15, mant=7)
|
|
489
|
+
uint16x8_t e4m3_zero_u16x8 = sign_byte_u16x8; // ±0
|
|
498
490
|
|
|
499
491
|
// Blend results (order matters: later conditions override earlier)
|
|
500
492
|
uint16x8_t result_u16x8 = normal_result_u16x8;
|
|
501
|
-
result_u16x8 = vbslq_u16(
|
|
502
|
-
result_u16x8 = vbslq_u16(
|
|
503
|
-
result_u16x8 = vbslq_u16(
|
|
504
|
-
|
|
505
|
-
result_u16x8 = vbslq_u16(
|
|
493
|
+
result_u16x8 = vbslq_u16(is_underflow_u16x8, subnormal_result_u16x8, result_u16x8);
|
|
494
|
+
result_u16x8 = vbslq_u16(is_overflow_u16x8, e4m3_max_u16x8, result_u16x8);
|
|
495
|
+
result_u16x8 = vbslq_u16(is_f16_special_u16x8, e4m3_max_u16x8,
|
|
496
|
+
result_u16x8); // F16 inf → E4M3 max (no inf in E4M3FN)
|
|
497
|
+
result_u16x8 = vbslq_u16(is_f16_nan_u16x8, e4m3_nan_u16x8, result_u16x8); // F16 nan → E4M3 nan
|
|
498
|
+
result_u16x8 = vbslq_u16(is_f16_zero_u16x8, e4m3_zero_u16x8, result_u16x8); // Preserve ±0
|
|
506
499
|
|
|
507
500
|
return vmovn_u16(result_u16x8);
|
|
508
501
|
}
|
|
@@ -515,7 +508,7 @@ NK_INTERNAL uint8x8_t nk_f16x8_to_e5m2x8_neon_(float16x8_t f16x8) {
|
|
|
515
508
|
|
|
516
509
|
// Detect inf/nan (exp=31) - these should not be rounded, just truncated
|
|
517
510
|
uint16x8_t exp_u16x8 = vandq_u16(vshrq_n_u16(bits_u16x8, 10), vdupq_n_u16(0x1F));
|
|
518
|
-
uint16x8_t
|
|
511
|
+
uint16x8_t is_special_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(31));
|
|
519
512
|
|
|
520
513
|
// RNE rounding: add (0x7F + lsb) where lsb = bit 8 of F16
|
|
521
514
|
// This rounds the lower 8 bits correctly and may carry into exponent
|
|
@@ -524,7 +517,7 @@ NK_INTERNAL uint8x8_t nk_f16x8_to_e5m2x8_neon_(float16x8_t f16x8) {
|
|
|
524
517
|
uint16x8_t rounded_bits_u16x8 = vaddq_u16(bits_u16x8, rounding_bias_u16x8);
|
|
525
518
|
|
|
526
519
|
// For special values (inf/nan), use original bits without rounding
|
|
527
|
-
uint16x8_t final_bits_u16x8 = vbslq_u16(
|
|
520
|
+
uint16x8_t final_bits_u16x8 = vbslq_u16(is_special_mask_u16x8, bits_u16x8, rounded_bits_u16x8);
|
|
528
521
|
|
|
529
522
|
// Shift right by 8 to get E5M2 format
|
|
530
523
|
uint16x8_t e5m2_u16x8 = vshrq_n_u16(final_bits_u16x8, 8);
|
|
@@ -539,32 +532,6 @@ NK_INTERNAL float32x4_t nk_bf16x4_to_f32x4_neon_(uint16x4_t bf16_u16x4) {
|
|
|
539
532
|
return vreinterpretq_f32_u32(bits_u32x4);
|
|
540
533
|
}
|
|
541
534
|
|
|
542
|
-
/** @brief Convert 4x f16 (as u16 bits) → f32x4 via integer bit manipulation (NEON).
|
|
543
|
-
* F16 format: S EEEEE MMMMMMMMMM (bias=15, 5-bit exponent, 10-bit mantissa).
|
|
544
|
-
* Works on ARMv8.0 without the FP16 arithmetic extension. Treats denormals as zero. */
|
|
545
|
-
NK_INTERNAL float32x4_t nk_f16x4_to_f32x4_neon_(uint16x4_t half_u16x4) {
|
|
546
|
-
// Widen u16 to u32
|
|
547
|
-
uint32x4_t bits_u32x4 = vmovl_u16(half_u16x4);
|
|
548
|
-
// Extract sign, exponent, mantissa
|
|
549
|
-
uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(bits_u32x4, vdupq_n_u32(0x8000)), 16);
|
|
550
|
-
uint32x4_t exponent_u32x4 = vandq_u32(bits_u32x4, vdupq_n_u32(0x7C00));
|
|
551
|
-
uint32x4_t mantissa_u32x4 = vandq_u32(bits_u32x4, vdupq_n_u32(0x03FF));
|
|
552
|
-
// Normal path: ((exponent + mantissa) << 13) + rebias(112 << 23 = 0x38000000)
|
|
553
|
-
uint32x4_t exponent_mantissa_u32x4 = vandq_u32(bits_u32x4, vdupq_n_u32(0x7FFF));
|
|
554
|
-
uint32x4_t normal_u32x4 = vaddq_u32(vshlq_n_u32(exponent_mantissa_u32x4, 13), vdupq_n_u32(0x38000000));
|
|
555
|
-
// Inf/NaN path (exponent == 0x7C00): 0x7F800000 | (mantissa << 13)
|
|
556
|
-
uint32x4_t inf_nan_u32x4 = vorrq_u32(vdupq_n_u32(0x7F800000), vshlq_n_u32(mantissa_u32x4, 13));
|
|
557
|
-
// Select inf/NaN where exponent == 31 (0x7C00)
|
|
558
|
-
uint32x4_t is_inf_nan_u32x4 = vceqq_u32(exponent_u32x4, vdupq_n_u32(0x7C00));
|
|
559
|
-
uint32x4_t result_u32x4 = vbslq_u32(is_inf_nan_u32x4, inf_nan_u32x4, normal_u32x4);
|
|
560
|
-
// Zero path (exponent == 0): treat denormals as zero for simplicity
|
|
561
|
-
uint32x4_t is_zero_u32x4 = vceqq_u32(exponent_u32x4, vdupq_n_u32(0));
|
|
562
|
-
result_u32x4 = vbslq_u32(is_zero_u32x4, vdupq_n_u32(0), result_u32x4);
|
|
563
|
-
// OR sign back
|
|
564
|
-
result_u32x4 = vorrq_u32(result_u32x4, sign_u32x4);
|
|
565
|
-
return vreinterpretq_f32_u32(result_u32x4);
|
|
566
|
-
}
|
|
567
|
-
|
|
568
535
|
/** @brief Convert f32x4 → 4x bf16 with RNE rounding (NEON).
|
|
569
536
|
* Round-to-nearest-even: add (0x7FFF + lsb) before truncation. */
|
|
570
537
|
NK_INTERNAL uint16x4_t nk_f32x4_to_bf16x4_neon_(float32x4_t f32x4) {
|
|
@@ -592,19 +559,20 @@ NK_INTERNAL uint16x8_t nk_e4m3x8_to_bf16x8_neon_(uint8x8_t e4m3_u8x8) {
|
|
|
592
559
|
// Subnormal path (exp=0): E4M3 subnormal = mant × 2⁻⁹ = mant ÷ 512 → BF16
|
|
593
560
|
// Compute via f32: mant → f32 → multiply → truncate to bf16
|
|
594
561
|
float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 512.0f);
|
|
595
|
-
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(
|
|
562
|
+
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 512.0f);
|
|
596
563
|
uint16x8_t subnormal_abs_u16x8 = vcombine_u16(nk_f32x4_to_bf16x4_neon_(subnormal_low_f32x4),
|
|
597
564
|
nk_f32x4_to_bf16x4_neon_(subnormal_high_f32x4));
|
|
598
565
|
uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
|
|
599
566
|
|
|
600
567
|
// NaN path: E4M3FN only has NaN when exp=15 AND mant=7 (0x7F or 0xFF)
|
|
601
568
|
uint16x8_t nan_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7FC0)); // BF16 quiet NaN
|
|
602
|
-
uint16x8_t
|
|
569
|
+
uint16x8_t is_nan_mask_u16x8 = vandq_u16(vceqq_u16(exp_u16x8, vdupq_n_u16(15)),
|
|
570
|
+
vceqq_u16(mant_u16x8, vdupq_n_u16(7)));
|
|
603
571
|
|
|
604
572
|
// Blend paths: subnormal when exp=0, NaN when exp=15 && mant=7, else normal
|
|
605
|
-
uint16x8_t
|
|
606
|
-
uint16x8_t result_u16x8 = vbslq_u16(
|
|
607
|
-
result_u16x8 = vbslq_u16(
|
|
573
|
+
uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
|
|
574
|
+
uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
|
|
575
|
+
result_u16x8 = vbslq_u16(is_nan_mask_u16x8, nan_u16x8, result_u16x8);
|
|
608
576
|
return result_u16x8;
|
|
609
577
|
}
|
|
610
578
|
|
|
@@ -625,8 +593,7 @@ NK_INTERNAL uint16x8_t nk_e5m2x8_to_bf16x8_neon_(uint8x8_t e5m2_u8x8) {
|
|
|
625
593
|
// Subnormal path (exp=0): E5M2 subnormal = mant × 2⁻¹⁶ = mant ÷ 65536 → BF16
|
|
626
594
|
// Compute via f32: mant → f32 → multiply → truncate to bf16
|
|
627
595
|
float32x4_t subnormal_low_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(mant_u16x8))), 1.0f / 65536.0f);
|
|
628
|
-
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(
|
|
629
|
-
1.0f / 65536.0f);
|
|
596
|
+
float32x4_t subnormal_high_f32x4 = vmulq_n_f32(vcvtq_f32_u32(vmovl_high_u16(mant_u16x8)), 1.0f / 65536.0f);
|
|
630
597
|
uint16x8_t subnormal_abs_u16x8 = vcombine_u16(nk_f32x4_to_bf16x4_neon_(subnormal_low_f32x4),
|
|
631
598
|
nk_f32x4_to_bf16x4_neon_(subnormal_high_f32x4));
|
|
632
599
|
uint16x8_t subnormal_u16x8 = vorrq_u16(subnormal_abs_u16x8, sign_u16x8);
|
|
@@ -634,14 +601,14 @@ NK_INTERNAL uint16x8_t nk_e5m2x8_to_bf16x8_neon_(uint8x8_t e5m2_u8x8) {
|
|
|
634
601
|
// Special path (exp=31): inf (mant=0) or nan (mant≠0)
|
|
635
602
|
uint16x8_t infinity_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7F80));
|
|
636
603
|
uint16x8_t nan_u16x8 = vorrq_u16(sign_u16x8, vdupq_n_u16(0x7FC0));
|
|
637
|
-
uint16x8_t
|
|
638
|
-
uint16x8_t special_u16x8 = vbslq_u16(
|
|
604
|
+
uint16x8_t mant_zero_mask_u16x8 = vceqq_u16(mant_u16x8, vdupq_n_u16(0));
|
|
605
|
+
uint16x8_t special_u16x8 = vbslq_u16(mant_zero_mask_u16x8, infinity_u16x8, nan_u16x8);
|
|
639
606
|
|
|
640
607
|
// Blend paths based on exponent value
|
|
641
|
-
uint16x8_t
|
|
642
|
-
uint16x8_t
|
|
643
|
-
uint16x8_t result_u16x8 = vbslq_u16(
|
|
644
|
-
result_u16x8 = vbslq_u16(
|
|
608
|
+
uint16x8_t exp_zero_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(0));
|
|
609
|
+
uint16x8_t exp_max_mask_u16x8 = vceqq_u16(exp_u16x8, vdupq_n_u16(31));
|
|
610
|
+
uint16x8_t result_u16x8 = vbslq_u16(exp_zero_mask_u16x8, subnormal_u16x8, normal_u16x8);
|
|
611
|
+
result_u16x8 = vbslq_u16(exp_max_mask_u16x8, special_u16x8, result_u16x8);
|
|
645
612
|
return result_u16x8;
|
|
646
613
|
}
|
|
647
614
|
|
|
@@ -678,21 +645,23 @@ NK_INTERNAL uint16x4_t nk_f32x4_to_u16x4_neon_(float32x4_t f32x4) {
|
|
|
678
645
|
}
|
|
679
646
|
|
|
680
647
|
/** @brief Convert f32x4 → 4x i8 with saturation (NEON). Convert to i32, narrow twice. */
|
|
681
|
-
NK_INTERNAL
|
|
648
|
+
NK_INTERNAL nk_b32_vec_t nk_f32x4_to_i8x4_neon_(float32x4_t f32x4) {
|
|
682
649
|
int32x4_t i32x4 = vcvtnq_s32_f32(f32x4);
|
|
683
650
|
int16x4_t i16x4 = vqmovn_s32(i32x4);
|
|
684
651
|
int8x8_t i8x8 = vqmovn_s16(vcombine_s16(i16x4, i16x4));
|
|
685
|
-
|
|
686
|
-
|
|
652
|
+
nk_b32_vec_t result_vec;
|
|
653
|
+
result_vec.u32 = vget_lane_u32(vreinterpret_u32_s8(i8x8), 0);
|
|
654
|
+
return result_vec;
|
|
687
655
|
}
|
|
688
656
|
|
|
689
657
|
/** @brief Convert f32x4 → 4x u8 with saturation (NEON). Convert to u32, narrow twice. */
|
|
690
|
-
NK_INTERNAL
|
|
658
|
+
NK_INTERNAL nk_b32_vec_t nk_f32x4_to_u8x4_neon_(float32x4_t f32x4) {
|
|
691
659
|
uint32x4_t u32x4 = vcvtnq_u32_f32(f32x4);
|
|
692
660
|
uint16x4_t u16x4 = vqmovn_u32(u32x4);
|
|
693
661
|
uint8x8_t u8x8 = vqmovn_u16(vcombine_u16(u16x4, u16x4));
|
|
694
|
-
|
|
695
|
-
|
|
662
|
+
nk_b32_vec_t result_vec;
|
|
663
|
+
result_vec.u32 = vget_lane_u32(vreinterpret_u32_u8(u8x8), 0);
|
|
664
|
+
return result_vec;
|
|
696
665
|
}
|
|
697
666
|
|
|
698
667
|
/** @brief Convert f32x4 → 4x e4m3 via bit manipulation (NEON).
|
|
@@ -830,6 +799,8 @@ NK_INTERNAL float32x4_t nk_e2m3x4_to_f32x4_neon_(nk_b32_vec_t src) {
|
|
|
830
799
|
uint8x8_t e2m3_u8x8 = vcreate_u8(src.u32);
|
|
831
800
|
uint16x8_t e2m3_u16x8 = vmovl_u8(e2m3_u8x8);
|
|
832
801
|
uint32x4_t e2m3_u32x4 = vmovl_u16(vget_low_u16(e2m3_u16x8));
|
|
802
|
+
|
|
803
|
+
// Extract sign: bit 5 → bit 31
|
|
833
804
|
uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e2m3_u32x4, vdupq_n_u32(0x20)), 26);
|
|
834
805
|
uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e2m3_u32x4, 3), vdupq_n_u32(0x03));
|
|
835
806
|
uint32x4_t mant_u32x4 = vandq_u32(e2m3_u32x4, vdupq_n_u32(0x07));
|
|
@@ -844,8 +815,8 @@ NK_INTERNAL float32x4_t nk_e2m3x4_to_f32x4_neon_(nk_b32_vec_t src) {
|
|
|
844
815
|
uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
|
|
845
816
|
|
|
846
817
|
// Blend paths: subnormal when exp=0, else normal
|
|
847
|
-
uint32x4_t
|
|
848
|
-
uint32x4_t result_u32x4 = vbslq_u32(
|
|
818
|
+
uint32x4_t exp_zero_mask_u32x4 = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
|
|
819
|
+
uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask_u32x4, subnormal_u32x4, normal_u32x4);
|
|
849
820
|
return vreinterpretq_f32_u32(result_u32x4);
|
|
850
821
|
}
|
|
851
822
|
|
|
@@ -856,6 +827,8 @@ NK_INTERNAL float32x4_t nk_e3m2x4_to_f32x4_neon_(nk_b32_vec_t src) {
|
|
|
856
827
|
uint8x8_t e3m2_u8x8 = vcreate_u8(src.u32);
|
|
857
828
|
uint16x8_t e3m2_u16x8 = vmovl_u8(e3m2_u8x8);
|
|
858
829
|
uint32x4_t e3m2_u32x4 = vmovl_u16(vget_low_u16(e3m2_u16x8));
|
|
830
|
+
|
|
831
|
+
// Extract sign: bit 5 → bit 31
|
|
859
832
|
uint32x4_t sign_u32x4 = vshlq_n_u32(vandq_u32(e3m2_u32x4, vdupq_n_u32(0x20)), 26);
|
|
860
833
|
uint32x4_t exp_u32x4 = vandq_u32(vshrq_n_u32(e3m2_u32x4, 2), vdupq_n_u32(0x07));
|
|
861
834
|
uint32x4_t mant_u32x4 = vandq_u32(e3m2_u32x4, vdupq_n_u32(0x03));
|
|
@@ -870,8 +843,8 @@ NK_INTERNAL float32x4_t nk_e3m2x4_to_f32x4_neon_(nk_b32_vec_t src) {
|
|
|
870
843
|
uint32x4_t subnormal_u32x4 = vorrq_u32(vreinterpretq_u32_f32(subnormal_f32x4), sign_u32x4);
|
|
871
844
|
|
|
872
845
|
// Blend paths: subnormal when exp=0, else normal
|
|
873
|
-
uint32x4_t
|
|
874
|
-
uint32x4_t result_u32x4 = vbslq_u32(
|
|
846
|
+
uint32x4_t exp_zero_mask_u32x4 = vceqq_u32(exp_u32x4, vdupq_n_u32(0));
|
|
847
|
+
uint32x4_t result_u32x4 = vbslq_u32(exp_zero_mask_u32x4, subnormal_u32x4, normal_u32x4);
|
|
875
848
|
return vreinterpretq_f32_u32(result_u32x4);
|
|
876
849
|
}
|
|
877
850
|
|
|
@@ -997,9 +970,9 @@ NK_INTERNAL nk_b32_vec_t nk_f32x4_to_e3m2x4_neon_(float32x4_t f32x4) {
|
|
|
997
970
|
return result;
|
|
998
971
|
}
|
|
999
972
|
|
|
1000
|
-
#pragma endregion
|
|
973
|
+
#pragma endregion Vectorized Conversions
|
|
1001
974
|
|
|
1002
|
-
#pragma region
|
|
975
|
+
#pragma region Public API
|
|
1003
976
|
|
|
1004
977
|
NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
|
|
1005
978
|
// Same-type fast path
|
|
@@ -1044,38 +1017,37 @@ NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n,
|
|
|
1044
1017
|
nk_u8_t *to_ptr = (nk_u8_t *)to;
|
|
1045
1018
|
|
|
1046
1019
|
for (nk_size_t idx = 0; idx < batches; ++idx, from_ptr += from_step, to_ptr += to_step) {
|
|
1020
|
+
nk_b128_vec_t hub_vec;
|
|
1021
|
+
|
|
1047
1022
|
// Upcast to f16x8 hub
|
|
1048
|
-
float16x8_t hub_f16x8;
|
|
1049
1023
|
switch (from_type) {
|
|
1050
|
-
case nk_e4m3_k:
|
|
1051
|
-
case nk_e5m2_k:
|
|
1052
|
-
case nk_e2m3_k:
|
|
1053
|
-
case nk_e3m2_k:
|
|
1054
|
-
case nk_f16_k:
|
|
1024
|
+
case nk_e4m3_k: hub_vec.u16x8 = vreinterpretq_u16_f16(nk_e4m3x8_to_f16x8_neon_(vld1_u8(from_ptr))); break;
|
|
1025
|
+
case nk_e5m2_k: hub_vec.u16x8 = vreinterpretq_u16_f16(nk_e5m2x8_to_f16x8_neon_(vld1_u8(from_ptr))); break;
|
|
1026
|
+
case nk_e2m3_k: hub_vec.u16x8 = vreinterpretq_u16_f16(nk_e2m3x8_to_f16x8_neon_(vld1_u8(from_ptr))); break;
|
|
1027
|
+
case nk_e3m2_k: hub_vec.u16x8 = vreinterpretq_u16_f16(nk_e3m2x8_to_f16x8_neon_(vld1_u8(from_ptr))); break;
|
|
1028
|
+
case nk_f16_k: hub_vec.u16x8 = vld1q_u16((nk_u16_t const *)from_ptr); break;
|
|
1055
1029
|
case nk_bf16_k: {
|
|
1056
|
-
|
|
1057
|
-
|
|
1058
|
-
|
|
1059
|
-
float32x4_t ieee_high_f32x4 = nk_bf16x4_to_f32x4_neon_(brain_high_u16x4);
|
|
1060
|
-
hub_f16x8 = vcombine_f16(vcvt_f16_f32(ieee_low_f32x4), vcvt_f16_f32(ieee_high_f32x4));
|
|
1030
|
+
float32x4_t low_f32x4 = nk_bf16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr));
|
|
1031
|
+
float32x4_t high_f32x4 = nk_bf16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)(from_ptr + 8)));
|
|
1032
|
+
hub_vec.u16x8 = vreinterpretq_u16_f16(vcombine_f16(vcvt_f16_f32(low_f32x4), vcvt_f16_f32(high_f32x4)));
|
|
1061
1033
|
} break;
|
|
1062
|
-
default:
|
|
1034
|
+
default: hub_vec.u16x8 = vdupq_n_u16(0); break;
|
|
1063
1035
|
}
|
|
1064
1036
|
|
|
1065
1037
|
// Downcast from f16x8 hub
|
|
1066
1038
|
switch (to_type) {
|
|
1067
|
-
case nk_e4m3_k: vst1_u8(to_ptr, nk_f16x8_to_e4m3x8_neon_(
|
|
1068
|
-
case nk_e5m2_k: vst1_u8(to_ptr, nk_f16x8_to_e5m2x8_neon_(
|
|
1069
|
-
case nk_f16_k: vst1q_u16((nk_u16_t *)to_ptr,
|
|
1039
|
+
case nk_e4m3_k: vst1_u8(to_ptr, nk_f16x8_to_e4m3x8_neon_(vreinterpretq_f16_u16(hub_vec.u16x8))); break;
|
|
1040
|
+
case nk_e5m2_k: vst1_u8(to_ptr, nk_f16x8_to_e5m2x8_neon_(vreinterpretq_f16_u16(hub_vec.u16x8))); break;
|
|
1041
|
+
case nk_f16_k: vst1q_u16((nk_u16_t *)to_ptr, hub_vec.u16x8); break;
|
|
1070
1042
|
case nk_bf16_k: {
|
|
1071
|
-
float32x4_t
|
|
1072
|
-
float32x4_t
|
|
1073
|
-
vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(
|
|
1074
|
-
vst1_u16((nk_u16_t *)(to_ptr + 8), nk_f32x4_to_bf16x4_neon_(
|
|
1043
|
+
float32x4_t low_f32x4 = vcvt_f32_f16(vget_low_f16(vreinterpretq_f16_u16(hub_vec.u16x8)));
|
|
1044
|
+
float32x4_t high_f32x4 = vcvt_high_f32_f16(vreinterpretq_f16_u16(hub_vec.u16x8));
|
|
1045
|
+
vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(low_f32x4));
|
|
1046
|
+
vst1_u16((nk_u16_t *)(to_ptr + 8), nk_f32x4_to_bf16x4_neon_(high_f32x4));
|
|
1075
1047
|
} break;
|
|
1076
1048
|
case nk_f32_k: {
|
|
1077
|
-
vst1q_f32((nk_f32_t *)to_ptr, vcvt_f32_f16(vget_low_f16(
|
|
1078
|
-
vst1q_f32((nk_f32_t *)(to_ptr + 16),
|
|
1049
|
+
vst1q_f32((nk_f32_t *)to_ptr, vcvt_f32_f16(vget_low_f16(vreinterpretq_f16_u16(hub_vec.u16x8))));
|
|
1050
|
+
vst1q_f32((nk_f32_t *)(to_ptr + 16), vcvt_high_f32_f16(vreinterpretq_f16_u16(hub_vec.u16x8)));
|
|
1079
1051
|
} break;
|
|
1080
1052
|
default: break;
|
|
1081
1053
|
}
|
|
@@ -1097,76 +1069,71 @@ NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n,
|
|
|
1097
1069
|
nk_u8_t *to_ptr = (nk_u8_t *)to;
|
|
1098
1070
|
|
|
1099
1071
|
for (nk_size_t idx = 0; idx < batches; ++idx, from_ptr += from_step, to_ptr += to_step) {
|
|
1100
|
-
|
|
1101
|
-
|
|
1072
|
+
nk_b128_vec_t hub_vec;
|
|
1073
|
+
|
|
1074
|
+
// Upcast to f32x4 hub
|
|
1102
1075
|
switch (from_type) {
|
|
1103
|
-
case nk_f32_k:
|
|
1104
|
-
case nk_f16_k:
|
|
1105
|
-
case nk_bf16_k:
|
|
1106
|
-
case nk_e4m3_k:
|
|
1107
|
-
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
nk_b32_vec_t
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
case
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
case
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
case nk_u8_k: {
|
|
1136
|
-
nk_b32_vec_t in_vec;
|
|
1137
|
-
nk_load_b32_serial_(from_ptr, &in_vec);
|
|
1138
|
-
hub_f32x4 = nk_u8x4_to_f32x4_neon_(in_vec);
|
|
1139
|
-
} break;
|
|
1140
|
-
default: hub_f32x4 = vdupq_n_f32(0); break;
|
|
1076
|
+
case nk_f32_k: hub_vec.f32x4 = vld1q_f32((nk_f32_t const *)from_ptr); break;
|
|
1077
|
+
case nk_f16_k: hub_vec.f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16((nk_u16_t const *)from_ptr))); break;
|
|
1078
|
+
case nk_bf16_k: hub_vec.f32x4 = nk_bf16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr)); break;
|
|
1079
|
+
case nk_e4m3_k:
|
|
1080
|
+
hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
|
|
1081
|
+
hub_vec.f32x4 = nk_e4m3x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
|
|
1082
|
+
break;
|
|
1083
|
+
case nk_e5m2_k:
|
|
1084
|
+
hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
|
|
1085
|
+
hub_vec.f32x4 = nk_e5m2x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
|
|
1086
|
+
break;
|
|
1087
|
+
case nk_e2m3_k:
|
|
1088
|
+
hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
|
|
1089
|
+
hub_vec.f32x4 = nk_e2m3x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
|
|
1090
|
+
break;
|
|
1091
|
+
case nk_e3m2_k:
|
|
1092
|
+
hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
|
|
1093
|
+
hub_vec.f32x4 = nk_e3m2x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
|
|
1094
|
+
break;
|
|
1095
|
+
case nk_i32_k: hub_vec.f32x4 = vcvtq_f32_s32(vld1q_s32((nk_i32_t const *)from_ptr)); break;
|
|
1096
|
+
case nk_u32_k: hub_vec.f32x4 = vcvtq_f32_u32(vld1q_u32((nk_u32_t const *)from_ptr)); break;
|
|
1097
|
+
case nk_i16_k: hub_vec.f32x4 = nk_i16x4_to_f32x4_neon_(vld1_s16((nk_i16_t const *)from_ptr)); break;
|
|
1098
|
+
case nk_u16_k: hub_vec.f32x4 = nk_u16x4_to_f32x4_neon_(vld1_u16((nk_u16_t const *)from_ptr)); break;
|
|
1099
|
+
case nk_i8_k:
|
|
1100
|
+
hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
|
|
1101
|
+
hub_vec.f32x4 = nk_i8x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
|
|
1102
|
+
break;
|
|
1103
|
+
case nk_u8_k:
|
|
1104
|
+
hub_vec.u32s[0] = *(nk_u32_t const *)from_ptr;
|
|
1105
|
+
hub_vec.f32x4 = nk_u8x4_to_f32x4_neon_(*(nk_b32_vec_t *)&hub_vec);
|
|
1106
|
+
break;
|
|
1107
|
+
default: hub_vec.f32x4 = vdupq_n_f32(0); break;
|
|
1141
1108
|
}
|
|
1142
1109
|
|
|
1143
|
-
// Downcast from f32x4 and store
|
|
1110
|
+
// Downcast from f32x4 hub and store
|
|
1144
1111
|
switch (to_type) {
|
|
1145
|
-
case nk_f32_k: vst1q_f32((nk_f32_t *)to_ptr,
|
|
1146
|
-
case nk_f16_k: vst1_u16((nk_u16_t *)to_ptr, vreinterpret_u16_f16(vcvt_f16_f32(
|
|
1147
|
-
case nk_bf16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(
|
|
1148
|
-
case nk_e4m3_k:
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
case
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1164
|
-
case
|
|
1165
|
-
|
|
1166
|
-
|
|
1167
|
-
case
|
|
1168
|
-
|
|
1169
|
-
|
|
1112
|
+
case nk_f32_k: vst1q_f32((nk_f32_t *)to_ptr, hub_vec.f32x4); break;
|
|
1113
|
+
case nk_f16_k: vst1_u16((nk_u16_t *)to_ptr, vreinterpret_u16_f16(vcvt_f16_f32(hub_vec.f32x4))); break;
|
|
1114
|
+
case nk_bf16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_bf16x4_neon_(hub_vec.f32x4)); break;
|
|
1115
|
+
case nk_e4m3_k:
|
|
1116
|
+
vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_e4m3x4_neon_(hub_vec.f32x4).u32), 0);
|
|
1117
|
+
break;
|
|
1118
|
+
case nk_e5m2_k:
|
|
1119
|
+
vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_e5m2x4_neon_(hub_vec.f32x4).u32), 0);
|
|
1120
|
+
break;
|
|
1121
|
+
case nk_e2m3_k:
|
|
1122
|
+
vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_e2m3x4_neon_(hub_vec.f32x4).u32), 0);
|
|
1123
|
+
break;
|
|
1124
|
+
case nk_e3m2_k:
|
|
1125
|
+
vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_e3m2x4_neon_(hub_vec.f32x4).u32), 0);
|
|
1126
|
+
break;
|
|
1127
|
+
case nk_i32_k: vst1q_s32((nk_i32_t *)to_ptr, vcvtnq_s32_f32(hub_vec.f32x4)); break;
|
|
1128
|
+
case nk_u32_k: vst1q_u32((nk_u32_t *)to_ptr, vcvtnq_u32_f32(hub_vec.f32x4)); break;
|
|
1129
|
+
case nk_i16_k: vst1_s16((nk_i16_t *)to_ptr, nk_f32x4_to_i16x4_neon_(hub_vec.f32x4)); break;
|
|
1130
|
+
case nk_u16_k: vst1_u16((nk_u16_t *)to_ptr, nk_f32x4_to_u16x4_neon_(hub_vec.f32x4)); break;
|
|
1131
|
+
case nk_i8_k:
|
|
1132
|
+
vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_i8x4_neon_(hub_vec.f32x4).u32), 0);
|
|
1133
|
+
break;
|
|
1134
|
+
case nk_u8_k:
|
|
1135
|
+
vst1_lane_u32((nk_u32_t *)to_ptr, vcreate_u32(nk_f32x4_to_u8x4_neon_(hub_vec.f32x4).u32), 0);
|
|
1136
|
+
break;
|
|
1170
1137
|
default: break;
|
|
1171
1138
|
}
|
|
1172
1139
|
}
|
|
@@ -1175,7 +1142,7 @@ NK_PUBLIC void nk_cast_neon(void const *from, nk_dtype_t from_type, nk_size_t n,
|
|
|
1175
1142
|
if (tail) nk_cast_serial(from_ptr, from_type, tail, to_ptr, to_type);
|
|
1176
1143
|
}
|
|
1177
1144
|
|
|
1178
|
-
#pragma endregion
|
|
1145
|
+
#pragma endregion Public API
|
|
1179
1146
|
|
|
1180
1147
|
#if defined(__clang__)
|
|
1181
1148
|
#pragma clang attribute pop
|