numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -13,14 +13,32 @@
|
|
|
13
13
|
extern "C" {
|
|
14
14
|
#endif
|
|
15
15
|
|
|
16
|
-
#pragma region
|
|
16
|
+
#pragma region Type Punned Loads and Stores
|
|
17
17
|
|
|
18
18
|
/** @brief Type-agnostic 32-bit full load (scalar). */
|
|
19
19
|
NK_INTERNAL void nk_load_b32_serial_(void const *src, nk_b32_vec_t *dst) { dst->u32 = *(nk_u32_t const *)src; }
|
|
20
20
|
|
|
21
|
+
/** @brief Type-agnostic 64-bit full load. */
|
|
22
|
+
NK_INTERNAL void nk_load_b64_serial_(void const *src, nk_b64_vec_t *dst) { dst->u64 = *(nk_u64_t const *)src; }
|
|
23
|
+
|
|
24
|
+
/** @brief Type-agnostic 128-bit full load. */
|
|
25
|
+
NK_INTERNAL void nk_load_b128_serial_(void const *src, nk_b128_vec_t *dst) {
|
|
26
|
+
nk_u64_t const *s = (nk_u64_t const *)src;
|
|
27
|
+
dst->u64s[0] = s[0], dst->u64s[1] = s[1];
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
/** @brief Type-agnostic 256-bit full load. */
|
|
31
|
+
NK_INTERNAL void nk_load_b256_serial_(void const *src, nk_b256_vec_t *dst) {
|
|
32
|
+
nk_u64_t const *s = (nk_u64_t const *)src;
|
|
33
|
+
dst->u64s[0] = s[0], dst->u64s[1] = s[1], dst->u64s[2] = s[2], dst->u64s[3] = s[3];
|
|
34
|
+
}
|
|
35
|
+
|
|
21
36
|
/** @brief Type-agnostic 32-bit full store (scalar). */
|
|
22
37
|
NK_INTERNAL void nk_store_b32_serial_(nk_b32_vec_t const *src, void *dst) { *(nk_u32_t *)dst = src->u32; }
|
|
23
38
|
|
|
39
|
+
/** @brief Type-agnostic 64-bit full store (scalar). */
|
|
40
|
+
NK_INTERNAL void nk_store_b64_serial_(nk_b64_vec_t const *src, void *dst) { *(nk_u64_t *)dst = src->u64; }
|
|
41
|
+
|
|
24
42
|
/** @brief Type-agnostic 128-bit store (serial, word-by-word). */
|
|
25
43
|
NK_INTERNAL void nk_store_b128_serial_(nk_b128_vec_t const *src, void *dst) {
|
|
26
44
|
nk_u64_t *d = (nk_u64_t *)dst;
|
|
@@ -37,164 +55,681 @@ NK_INTERNAL void nk_store_b256_serial_(nk_b256_vec_t const *src, void *dst) {
|
|
|
37
55
|
d[3] = src->u64s[3];
|
|
38
56
|
}
|
|
39
57
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
* -inf 0xFC00 0xFF800000 Negative infinity
|
|
52
|
-
* NaN 0x7E00 0x7FC00000 Quiet NaN (payload preserved)
|
|
53
|
-
* Min normal 0x0400 0x38800000 2⁻¹⁴
|
|
54
|
-
* Max normal 0x7BFF 0x477FE000 65504
|
|
55
|
-
* Min denorm 0x0001 0x33800000 2⁻²⁴
|
|
56
|
-
* Max denorm 0x03FF 0x387FC000 2⁻¹⁴ - 2⁻²⁴
|
|
57
|
-
*
|
|
58
|
-
* https://stackoverflow.com/a/60047308
|
|
59
|
-
* https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
|
|
60
|
-
* https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
|
|
61
|
-
*/
|
|
62
|
-
NK_PUBLIC void nk_f16_to_f32_serial(nk_f16_t const *src, nk_f32_t *dest) {
|
|
63
|
-
#if NK_NATIVE_F16
|
|
64
|
-
*dest = (nk_f32_t)(*src);
|
|
65
|
-
#else
|
|
66
|
-
unsigned short x;
|
|
67
|
-
nk_copy_bytes_(&x, src, 2);
|
|
68
|
-
|
|
69
|
-
unsigned int sign = (x >> 15) & 1;
|
|
70
|
-
unsigned int exponent = (x >> 10) & 0x1F;
|
|
71
|
-
unsigned int mantissa = x & 0x03FF;
|
|
72
|
-
|
|
73
|
-
nk_fui32_t conv;
|
|
74
|
-
|
|
75
|
-
if (exponent == 0) {
|
|
76
|
-
if (mantissa == 0) {
|
|
77
|
-
// Zero (preserve sign)
|
|
78
|
-
conv.u = sign << 31;
|
|
79
|
-
}
|
|
80
|
-
else {
|
|
81
|
-
// Denormal: value = mantissa × 2⁻²⁴
|
|
82
|
-
// Use FPU normalization, then subtract 24 from exponent
|
|
83
|
-
nk_fui32_t temp;
|
|
84
|
-
temp.f = (float)mantissa;
|
|
85
|
-
conv.u = (sign << 31) | (temp.u - 0x0C000000);
|
|
86
|
-
}
|
|
87
|
-
}
|
|
88
|
-
else if (exponent == 31) {
|
|
89
|
-
// Infinity (mantissa=0) or NaN (mantissa!=0)
|
|
90
|
-
conv.u = (sign << 31) | 0x7F800000 | (mantissa << 13);
|
|
91
|
-
}
|
|
92
|
-
else {
|
|
93
|
-
// Normal: rebias exponent (127-15=112), shift mantissa
|
|
94
|
-
conv.u = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13);
|
|
58
|
+
/** @brief Type-agnostic partial load for 64-bit elements (4 elements max) into 256-bit vector. */
|
|
59
|
+
NK_INTERNAL void nk_partial_load_b64x4_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
60
|
+
nk_u64_t const *s = (nk_u64_t const *)src;
|
|
61
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
|
|
62
|
+
switch (n) {
|
|
63
|
+
default:
|
|
64
|
+
case 4: dst->u64s[3] = s[3]; // fallthrough
|
|
65
|
+
case 3: dst->u64s[2] = s[2]; // fallthrough
|
|
66
|
+
case 2: dst->u64s[1] = s[1]; // fallthrough
|
|
67
|
+
case 1: dst->u64s[0] = s[0]; // fallthrough
|
|
68
|
+
case 0: break;
|
|
95
69
|
}
|
|
96
|
-
|
|
97
|
-
*dest = conv.f;
|
|
98
|
-
#endif
|
|
99
70
|
}
|
|
100
71
|
|
|
101
|
-
/**
|
|
102
|
-
|
|
103
|
-
*
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
* NaN 0x7FC00000 0x7E00 Quiet NaN (payload truncated)
|
|
112
|
-
* 1.0 0x3F800000 0x3C00 Normal number
|
|
113
|
-
* 65504 0x477FE000 0x7BFF Max f16 normal
|
|
114
|
-
* 65520+ >0x477FE000 0x7C00 Overflow → infinity
|
|
115
|
-
* 2⁻¹⁴ 0x38800000 0x0400 Min f16 normal
|
|
116
|
-
* 2⁻²⁴ 0x33800000 0x0001 Min f16 denormal
|
|
117
|
-
* <2⁻²⁵ <0x33000000 0x0000 Underflow → zero
|
|
118
|
-
*
|
|
119
|
-
* https://stackoverflow.com/a/60047308
|
|
120
|
-
* https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
|
|
121
|
-
* https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
|
|
122
|
-
*/
|
|
123
|
-
NK_PUBLIC void nk_f32_to_f16_serial(nk_f32_t const *src, nk_f16_t *dest) {
|
|
124
|
-
#if NK_NATIVE_F16
|
|
125
|
-
*dest = (nk_f16_t)(*src);
|
|
126
|
-
#else
|
|
127
|
-
nk_fui32_t conv;
|
|
128
|
-
conv.f = *src;
|
|
129
|
-
|
|
130
|
-
unsigned int sign = (conv.u >> 31) & 1;
|
|
131
|
-
unsigned int exponent = (conv.u >> 23) & 0xFF;
|
|
132
|
-
unsigned int mantissa = conv.u & 0x007FFFFF;
|
|
133
|
-
|
|
134
|
-
unsigned short result;
|
|
135
|
-
|
|
136
|
-
if (exponent == 0) {
|
|
137
|
-
// Zero or f32 denormal → f16 zero
|
|
138
|
-
result = (unsigned short)(sign << 15);
|
|
139
|
-
}
|
|
140
|
-
else if (exponent == 255) {
|
|
141
|
-
// Infinity or NaN
|
|
142
|
-
unsigned short payload = (unsigned short)(mantissa >> 13);
|
|
143
|
-
if (mantissa != 0 && payload == 0) payload = 1; // Preserve NaN-ness
|
|
144
|
-
result = (unsigned short)((sign << 15) | 0x7C00 | payload);
|
|
145
|
-
}
|
|
146
|
-
else if (exponent <= 102) {
|
|
147
|
-
// Below or at f16 denormal threshold
|
|
148
|
-
// exp=102 with mant=0 is exactly 2^-25 (tie point, rounds to 0 per round-to-even)
|
|
149
|
-
// exp=102 with mant>0 is above tie point (rounds to smallest denormal 0x0001)
|
|
150
|
-
if (exponent == 102 && mantissa > 0) result = (unsigned short)((sign << 15) | 0x0001);
|
|
151
|
-
else result = (unsigned short)(sign << 15);
|
|
72
|
+
/** @brief Type-agnostic partial store for 64-bit elements (4 elements max) from 256-bit vector. */
|
|
73
|
+
NK_INTERNAL void nk_partial_store_b64x4_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
|
|
74
|
+
nk_u64_t *d = (nk_u64_t *)dst;
|
|
75
|
+
switch (n) {
|
|
76
|
+
default:
|
|
77
|
+
case 4: d[3] = src->u64s[3]; // fallthrough
|
|
78
|
+
case 3: d[2] = src->u64s[2]; // fallthrough
|
|
79
|
+
case 2: d[1] = src->u64s[1]; // fallthrough
|
|
80
|
+
case 1: d[0] = src->u64s[0]; // fallthrough
|
|
81
|
+
case 0: break;
|
|
152
82
|
}
|
|
153
|
-
|
|
154
|
-
// F16 denormal range (exp 103-112) with IEEE 754 round-to-nearest-even
|
|
155
|
-
unsigned int shift = 113 - exponent;
|
|
156
|
-
unsigned int shift_amount = shift + 13;
|
|
157
|
-
unsigned long long full_mant = 0x00800000ULL | mantissa;
|
|
158
|
-
|
|
159
|
-
// Extract result before rounding
|
|
160
|
-
unsigned int mant = (unsigned int)(full_mant >> shift_amount);
|
|
161
|
-
|
|
162
|
-
// IEEE 754 round-to-nearest-even: round up if round_bit is set AND
|
|
163
|
-
// (sticky_bits are nonzero OR result is odd)
|
|
164
|
-
unsigned int round_bit = (full_mant >> (shift_amount - 1)) & 1;
|
|
165
|
-
unsigned long long sticky_bits = full_mant & ((1ULL << (shift_amount - 1)) - 1);
|
|
166
|
-
|
|
167
|
-
if (round_bit && (sticky_bits || (mant & 1))) mant++;
|
|
83
|
+
}
|
|
168
84
|
|
|
169
|
-
|
|
85
|
+
NK_INTERNAL void nk_partial_load_b64x2_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
86
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
87
|
+
nk_u64_t const *s = (nk_u64_t const *)src;
|
|
88
|
+
switch (n) {
|
|
89
|
+
default:
|
|
90
|
+
case 2: dst->u64s[1] = s[1]; // fallthrough
|
|
91
|
+
case 1: dst->u64s[0] = s[0]; // fallthrough
|
|
92
|
+
case 0: break;
|
|
170
93
|
}
|
|
171
|
-
|
|
172
|
-
// Normal f16 range with IEEE 754 round-to-nearest-even
|
|
173
|
-
unsigned int f16_exp = exponent - 112;
|
|
174
|
-
unsigned int f16_mant = mantissa >> 13;
|
|
94
|
+
}
|
|
175
95
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
96
|
+
/** @brief Type-agnostic partial store for 64-bit elements (2 elements max) from 128-bit vector. */
|
|
97
|
+
NK_INTERNAL void nk_partial_store_b64x2_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
|
|
98
|
+
nk_u64_t *d = (nk_u64_t *)dst;
|
|
99
|
+
switch (n) {
|
|
100
|
+
default:
|
|
101
|
+
case 2: d[1] = src->u64s[1]; // fallthrough
|
|
102
|
+
case 1: d[0] = src->u64s[0]; // fallthrough
|
|
103
|
+
case 0: break;
|
|
104
|
+
}
|
|
105
|
+
}
|
|
179
106
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
107
|
+
/** @brief Type-agnostic partial load for 32-bit elements (8 elements max) into 256-bit vector. */
|
|
108
|
+
NK_INTERNAL void nk_partial_load_b32x8_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
109
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
|
|
110
|
+
nk_u32_t const *s = (nk_u32_t const *)src;
|
|
111
|
+
switch (n) {
|
|
112
|
+
default:
|
|
113
|
+
case 8: dst->u32s[7] = s[7]; // fallthrough
|
|
114
|
+
case 7: dst->u32s[6] = s[6]; // fallthrough
|
|
115
|
+
case 6: dst->u32s[5] = s[5]; // fallthrough
|
|
116
|
+
case 5: dst->u32s[4] = s[4]; // fallthrough
|
|
117
|
+
case 4: dst->u32s[3] = s[3]; // fallthrough
|
|
118
|
+
case 3: dst->u32s[2] = s[2]; // fallthrough
|
|
119
|
+
case 2: dst->u32s[1] = s[1]; // fallthrough
|
|
120
|
+
case 1: dst->u32s[0] = s[0]; // fallthrough
|
|
121
|
+
case 0: break;
|
|
122
|
+
}
|
|
123
|
+
}
|
|
184
124
|
|
|
185
|
-
|
|
186
|
-
|
|
125
|
+
/** @brief Type-agnostic partial store for 32-bit elements (8 elements max) from 256-bit vector. */
|
|
126
|
+
NK_INTERNAL void nk_partial_store_b32x8_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
|
|
127
|
+
nk_u32_t *d = (nk_u32_t *)dst;
|
|
128
|
+
switch (n) {
|
|
129
|
+
default:
|
|
130
|
+
case 8: d[7] = src->u32s[7]; // fallthrough
|
|
131
|
+
case 7: d[6] = src->u32s[6]; // fallthrough
|
|
132
|
+
case 6: d[5] = src->u32s[5]; // fallthrough
|
|
133
|
+
case 5: d[4] = src->u32s[4]; // fallthrough
|
|
134
|
+
case 4: d[3] = src->u32s[3]; // fallthrough
|
|
135
|
+
case 3: d[2] = src->u32s[2]; // fallthrough
|
|
136
|
+
case 2: d[1] = src->u32s[1]; // fallthrough
|
|
137
|
+
case 1: d[0] = src->u32s[0]; // fallthrough
|
|
138
|
+
case 0: break;
|
|
187
139
|
}
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
/** @brief Type-agnostic partial load for 32-bit elements (4 elements max) into 128-bit vector. */
|
|
143
|
+
NK_INTERNAL void nk_partial_load_b32x4_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
144
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
145
|
+
nk_u32_t const *s = (nk_u32_t const *)src;
|
|
146
|
+
switch (n) {
|
|
147
|
+
default:
|
|
148
|
+
case 4: dst->u32s[3] = s[3]; // fallthrough
|
|
149
|
+
case 3: dst->u32s[2] = s[2]; // fallthrough
|
|
150
|
+
case 2: dst->u32s[1] = s[1]; // fallthrough
|
|
151
|
+
case 1: dst->u32s[0] = s[0]; // fallthrough
|
|
152
|
+
case 0: break;
|
|
191
153
|
}
|
|
154
|
+
}
|
|
192
155
|
|
|
193
|
-
|
|
194
|
-
|
|
156
|
+
/** @brief Type-agnostic partial store for 32-bit elements (4 elements max) from 128-bit vector. */
|
|
157
|
+
NK_INTERNAL void nk_partial_store_b32x4_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
|
|
158
|
+
nk_u32_t *d = (nk_u32_t *)dst;
|
|
159
|
+
switch (n) {
|
|
160
|
+
default:
|
|
161
|
+
case 4: d[3] = src->u32s[3]; // fallthrough
|
|
162
|
+
case 3: d[2] = src->u32s[2]; // fallthrough
|
|
163
|
+
case 2: d[1] = src->u32s[1]; // fallthrough
|
|
164
|
+
case 1: d[0] = src->u32s[0]; // fallthrough
|
|
165
|
+
case 0: break;
|
|
166
|
+
}
|
|
195
167
|
}
|
|
196
168
|
|
|
197
|
-
/**
|
|
169
|
+
/** @brief Type-agnostic partial load for 32-bit elements (2 elements max) into 64-bit vector. */
|
|
170
|
+
NK_INTERNAL void nk_partial_load_b32x2_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
|
|
171
|
+
dst->u64 = 0;
|
|
172
|
+
nk_u32_t const *s = (nk_u32_t const *)src;
|
|
173
|
+
switch (n) {
|
|
174
|
+
default:
|
|
175
|
+
case 2: dst->u32s[1] = s[1]; // fallthrough
|
|
176
|
+
case 1: dst->u32s[0] = s[0]; // fallthrough
|
|
177
|
+
case 0: break;
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
/** @brief Type-agnostic partial load for 16-bit elements (8 elements max) into 128-bit vector. */
|
|
182
|
+
NK_INTERNAL void nk_partial_load_b16x8_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
183
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
184
|
+
nk_u16_t const *s = (nk_u16_t const *)src;
|
|
185
|
+
switch (n) {
|
|
186
|
+
default:
|
|
187
|
+
case 8: dst->u16s[7] = s[7]; // fallthrough
|
|
188
|
+
case 7: dst->u16s[6] = s[6]; // fallthrough
|
|
189
|
+
case 6: dst->u16s[5] = s[5]; // fallthrough
|
|
190
|
+
case 5: dst->u16s[4] = s[4]; // fallthrough
|
|
191
|
+
case 4: dst->u16s[3] = s[3]; // fallthrough
|
|
192
|
+
case 3: dst->u16s[2] = s[2]; // fallthrough
|
|
193
|
+
case 2: dst->u16s[1] = s[1]; // fallthrough
|
|
194
|
+
case 1: dst->u16s[0] = s[0]; // fallthrough
|
|
195
|
+
case 0: break;
|
|
196
|
+
}
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
/** @brief Type-agnostic partial store for 16-bit elements (8 elements max) from 128-bit vector. */
|
|
200
|
+
NK_INTERNAL void nk_partial_store_b16x8_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
|
|
201
|
+
nk_u16_t *d = (nk_u16_t *)dst;
|
|
202
|
+
switch (n) {
|
|
203
|
+
default:
|
|
204
|
+
case 8: d[7] = src->u16s[7]; // fallthrough
|
|
205
|
+
case 7: d[6] = src->u16s[6]; // fallthrough
|
|
206
|
+
case 6: d[5] = src->u16s[5]; // fallthrough
|
|
207
|
+
case 5: d[4] = src->u16s[4]; // fallthrough
|
|
208
|
+
case 4: d[3] = src->u16s[3]; // fallthrough
|
|
209
|
+
case 3: d[2] = src->u16s[2]; // fallthrough
|
|
210
|
+
case 2: d[1] = src->u16s[1]; // fallthrough
|
|
211
|
+
case 1: d[0] = src->u16s[0]; // fallthrough
|
|
212
|
+
case 0: break;
|
|
213
|
+
}
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
/** @brief Type-agnostic partial load for 16-bit elements (16 elements max) into 256-bit vector. */
|
|
217
|
+
NK_INTERNAL void nk_partial_load_b16x16_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
218
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
|
|
219
|
+
nk_u16_t const *s = (nk_u16_t const *)src;
|
|
220
|
+
switch (n) {
|
|
221
|
+
default:
|
|
222
|
+
case 16: dst->u16s[15] = s[15]; // fallthrough
|
|
223
|
+
case 15: dst->u16s[14] = s[14]; // fallthrough
|
|
224
|
+
case 14: dst->u16s[13] = s[13]; // fallthrough
|
|
225
|
+
case 13: dst->u16s[12] = s[12]; // fallthrough
|
|
226
|
+
case 12: dst->u16s[11] = s[11]; // fallthrough
|
|
227
|
+
case 11: dst->u16s[10] = s[10]; // fallthrough
|
|
228
|
+
case 10: dst->u16s[9] = s[9]; // fallthrough
|
|
229
|
+
case 9: dst->u16s[8] = s[8]; // fallthrough
|
|
230
|
+
case 8: dst->u16s[7] = s[7]; // fallthrough
|
|
231
|
+
case 7: dst->u16s[6] = s[6]; // fallthrough
|
|
232
|
+
case 6: dst->u16s[5] = s[5]; // fallthrough
|
|
233
|
+
case 5: dst->u16s[4] = s[4]; // fallthrough
|
|
234
|
+
case 4: dst->u16s[3] = s[3]; // fallthrough
|
|
235
|
+
case 3: dst->u16s[2] = s[2]; // fallthrough
|
|
236
|
+
case 2: dst->u16s[1] = s[1]; // fallthrough
|
|
237
|
+
case 1: dst->u16s[0] = s[0]; // fallthrough
|
|
238
|
+
case 0: break;
|
|
239
|
+
}
|
|
240
|
+
}
|
|
241
|
+
|
|
242
|
+
/** @brief Type-agnostic partial store for 16-bit elements (16 elements max) from 256-bit vector. */
|
|
243
|
+
NK_INTERNAL void nk_partial_store_b16x16_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
|
|
244
|
+
nk_u16_t *d = (nk_u16_t *)dst;
|
|
245
|
+
switch (n) {
|
|
246
|
+
default:
|
|
247
|
+
case 16: d[15] = src->u16s[15]; // fallthrough
|
|
248
|
+
case 15: d[14] = src->u16s[14]; // fallthrough
|
|
249
|
+
case 14: d[13] = src->u16s[13]; // fallthrough
|
|
250
|
+
case 13: d[12] = src->u16s[12]; // fallthrough
|
|
251
|
+
case 12: d[11] = src->u16s[11]; // fallthrough
|
|
252
|
+
case 11: d[10] = src->u16s[10]; // fallthrough
|
|
253
|
+
case 10: d[9] = src->u16s[9]; // fallthrough
|
|
254
|
+
case 9: d[8] = src->u16s[8]; // fallthrough
|
|
255
|
+
case 8: d[7] = src->u16s[7]; // fallthrough
|
|
256
|
+
case 7: d[6] = src->u16s[6]; // fallthrough
|
|
257
|
+
case 6: d[5] = src->u16s[5]; // fallthrough
|
|
258
|
+
case 5: d[4] = src->u16s[4]; // fallthrough
|
|
259
|
+
case 4: d[3] = src->u16s[3]; // fallthrough
|
|
260
|
+
case 3: d[2] = src->u16s[2]; // fallthrough
|
|
261
|
+
case 2: d[1] = src->u16s[1]; // fallthrough
|
|
262
|
+
case 1: d[0] = src->u16s[0]; // fallthrough
|
|
263
|
+
case 0: break;
|
|
264
|
+
}
|
|
265
|
+
}
|
|
266
|
+
|
|
267
|
+
/** @brief Type-agnostic partial load for 16-bit elements (4 elements max) into 64-bit vector. */
|
|
268
|
+
NK_INTERNAL void nk_partial_load_b16x4_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
|
|
269
|
+
dst->u64 = 0;
|
|
270
|
+
nk_u16_t const *s = (nk_u16_t const *)src;
|
|
271
|
+
switch (n) {
|
|
272
|
+
default:
|
|
273
|
+
case 4: dst->u16s[3] = s[3]; // fallthrough
|
|
274
|
+
case 3: dst->u16s[2] = s[2]; // fallthrough
|
|
275
|
+
case 2: dst->u16s[1] = s[1]; // fallthrough
|
|
276
|
+
case 1: dst->u16s[0] = s[0]; // fallthrough
|
|
277
|
+
case 0: break;
|
|
278
|
+
}
|
|
279
|
+
}
|
|
280
|
+
|
|
281
|
+
/** @brief Type-agnostic partial store for 16-bit elements (4 elements max) from 64-bit vector. */
|
|
282
|
+
NK_INTERNAL void nk_partial_store_b16x4_serial_(void *dst, nk_b64_vec_t const *src, nk_size_t n) {
|
|
283
|
+
nk_u16_t *d = (nk_u16_t *)dst;
|
|
284
|
+
switch (n) {
|
|
285
|
+
default:
|
|
286
|
+
case 4: d[3] = src->u16s[3]; // fallthrough
|
|
287
|
+
case 3: d[2] = src->u16s[2]; // fallthrough
|
|
288
|
+
case 2: d[1] = src->u16s[1]; // fallthrough
|
|
289
|
+
case 1: d[0] = src->u16s[0]; // fallthrough
|
|
290
|
+
case 0: break;
|
|
291
|
+
}
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
/** @brief Type-agnostic partial load for 8-bit elements (8 elements max) into 64-bit vector. */
|
|
295
|
+
NK_INTERNAL void nk_partial_load_b8x8_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
|
|
296
|
+
dst->u64 = 0;
|
|
297
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
298
|
+
switch (n) {
|
|
299
|
+
default:
|
|
300
|
+
case 8: dst->u8s[7] = s[7]; // fallthrough
|
|
301
|
+
case 7: dst->u8s[6] = s[6]; // fallthrough
|
|
302
|
+
case 6: dst->u8s[5] = s[5]; // fallthrough
|
|
303
|
+
case 5: dst->u8s[4] = s[4]; // fallthrough
|
|
304
|
+
case 4: dst->u8s[3] = s[3]; // fallthrough
|
|
305
|
+
case 3: dst->u8s[2] = s[2]; // fallthrough
|
|
306
|
+
case 2: dst->u8s[1] = s[1]; // fallthrough
|
|
307
|
+
case 1: dst->u8s[0] = s[0]; // fallthrough
|
|
308
|
+
case 0: break;
|
|
309
|
+
}
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
/** @brief Type-agnostic partial store for 8-bit elements (8 elements max) from 64-bit vector. */
|
|
313
|
+
NK_INTERNAL void nk_partial_store_b8x8_serial_(nk_b64_vec_t const *src, void *dst, nk_size_t n) {
|
|
314
|
+
nk_u8_t *d = (nk_u8_t *)dst;
|
|
315
|
+
switch (n) {
|
|
316
|
+
default:
|
|
317
|
+
case 8: d[7] = src->u8s[7]; // fallthrough
|
|
318
|
+
case 7: d[6] = src->u8s[6]; // fallthrough
|
|
319
|
+
case 6: d[5] = src->u8s[5]; // fallthrough
|
|
320
|
+
case 5: d[4] = src->u8s[4]; // fallthrough
|
|
321
|
+
case 4: d[3] = src->u8s[3]; // fallthrough
|
|
322
|
+
case 3: d[2] = src->u8s[2]; // fallthrough
|
|
323
|
+
case 2: d[1] = src->u8s[1]; // fallthrough
|
|
324
|
+
case 1: d[0] = src->u8s[0]; // fallthrough
|
|
325
|
+
case 0: break;
|
|
326
|
+
}
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
/** @brief Type-agnostic partial store for 8-bit elements (16 elements max) from 128-bit vector. */
|
|
330
|
+
NK_INTERNAL void nk_partial_store_b8x16_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
|
|
331
|
+
nk_u8_t *d = (nk_u8_t *)dst;
|
|
332
|
+
switch (n) {
|
|
333
|
+
default:
|
|
334
|
+
case 16: d[15] = src->u8s[15]; // fallthrough
|
|
335
|
+
case 15: d[14] = src->u8s[14]; // fallthrough
|
|
336
|
+
case 14: d[13] = src->u8s[13]; // fallthrough
|
|
337
|
+
case 13: d[12] = src->u8s[12]; // fallthrough
|
|
338
|
+
case 12: d[11] = src->u8s[11]; // fallthrough
|
|
339
|
+
case 11: d[10] = src->u8s[10]; // fallthrough
|
|
340
|
+
case 10: d[9] = src->u8s[9]; // fallthrough
|
|
341
|
+
case 9: d[8] = src->u8s[8]; // fallthrough
|
|
342
|
+
case 8: d[7] = src->u8s[7]; // fallthrough
|
|
343
|
+
case 7: d[6] = src->u8s[6]; // fallthrough
|
|
344
|
+
case 6: d[5] = src->u8s[5]; // fallthrough
|
|
345
|
+
case 5: d[4] = src->u8s[4]; // fallthrough
|
|
346
|
+
case 4: d[3] = src->u8s[3]; // fallthrough
|
|
347
|
+
case 3: d[2] = src->u8s[2]; // fallthrough
|
|
348
|
+
case 2: d[1] = src->u8s[1]; // fallthrough
|
|
349
|
+
case 1: d[0] = src->u8s[0]; // fallthrough
|
|
350
|
+
case 0: break;
|
|
351
|
+
}
|
|
352
|
+
}
|
|
353
|
+
|
|
354
|
+
/** @brief Type-agnostic partial store for 8-bit elements (32 elements max) from 256-bit vector. */
|
|
355
|
+
NK_INTERNAL void nk_partial_store_b8x32_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
|
|
356
|
+
nk_u8_t *d = (nk_u8_t *)dst;
|
|
357
|
+
switch (n) {
|
|
358
|
+
default:
|
|
359
|
+
case 32: d[31] = src->u8s[31]; // fallthrough
|
|
360
|
+
case 31: d[30] = src->u8s[30]; // fallthrough
|
|
361
|
+
case 30: d[29] = src->u8s[29]; // fallthrough
|
|
362
|
+
case 29: d[28] = src->u8s[28]; // fallthrough
|
|
363
|
+
case 28: d[27] = src->u8s[27]; // fallthrough
|
|
364
|
+
case 27: d[26] = src->u8s[26]; // fallthrough
|
|
365
|
+
case 26: d[25] = src->u8s[25]; // fallthrough
|
|
366
|
+
case 25: d[24] = src->u8s[24]; // fallthrough
|
|
367
|
+
case 24: d[23] = src->u8s[23]; // fallthrough
|
|
368
|
+
case 23: d[22] = src->u8s[22]; // fallthrough
|
|
369
|
+
case 22: d[21] = src->u8s[21]; // fallthrough
|
|
370
|
+
case 21: d[20] = src->u8s[20]; // fallthrough
|
|
371
|
+
case 20: d[19] = src->u8s[19]; // fallthrough
|
|
372
|
+
case 19: d[18] = src->u8s[18]; // fallthrough
|
|
373
|
+
case 18: d[17] = src->u8s[17]; // fallthrough
|
|
374
|
+
case 17: d[16] = src->u8s[16]; // fallthrough
|
|
375
|
+
case 16: d[15] = src->u8s[15]; // fallthrough
|
|
376
|
+
case 15: d[14] = src->u8s[14]; // fallthrough
|
|
377
|
+
case 14: d[13] = src->u8s[13]; // fallthrough
|
|
378
|
+
case 13: d[12] = src->u8s[12]; // fallthrough
|
|
379
|
+
case 12: d[11] = src->u8s[11]; // fallthrough
|
|
380
|
+
case 11: d[10] = src->u8s[10]; // fallthrough
|
|
381
|
+
case 10: d[9] = src->u8s[9]; // fallthrough
|
|
382
|
+
case 9: d[8] = src->u8s[8]; // fallthrough
|
|
383
|
+
case 8: d[7] = src->u8s[7]; // fallthrough
|
|
384
|
+
case 7: d[6] = src->u8s[6]; // fallthrough
|
|
385
|
+
case 6: d[5] = src->u8s[5]; // fallthrough
|
|
386
|
+
case 5: d[4] = src->u8s[4]; // fallthrough
|
|
387
|
+
case 4: d[3] = src->u8s[3]; // fallthrough
|
|
388
|
+
case 3: d[2] = src->u8s[2]; // fallthrough
|
|
389
|
+
case 2: d[1] = src->u8s[1]; // fallthrough
|
|
390
|
+
case 1: d[0] = src->u8s[0]; // fallthrough
|
|
391
|
+
case 0: break;
|
|
392
|
+
}
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
/** @brief Type-agnostic partial load for 8-bit elements (16 elements max) into 128-bit vector. */
|
|
396
|
+
NK_INTERNAL void nk_partial_load_b8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
397
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
398
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
399
|
+
switch (n) {
|
|
400
|
+
default:
|
|
401
|
+
case 16: dst->u8s[15] = s[15]; // fallthrough
|
|
402
|
+
case 15: dst->u8s[14] = s[14]; // fallthrough
|
|
403
|
+
case 14: dst->u8s[13] = s[13]; // fallthrough
|
|
404
|
+
case 13: dst->u8s[12] = s[12]; // fallthrough
|
|
405
|
+
case 12: dst->u8s[11] = s[11]; // fallthrough
|
|
406
|
+
case 11: dst->u8s[10] = s[10]; // fallthrough
|
|
407
|
+
case 10: dst->u8s[9] = s[9]; // fallthrough
|
|
408
|
+
case 9: dst->u8s[8] = s[8]; // fallthrough
|
|
409
|
+
case 8: dst->u8s[7] = s[7]; // fallthrough
|
|
410
|
+
case 7: dst->u8s[6] = s[6]; // fallthrough
|
|
411
|
+
case 6: dst->u8s[5] = s[5]; // fallthrough
|
|
412
|
+
case 5: dst->u8s[4] = s[4]; // fallthrough
|
|
413
|
+
case 4: dst->u8s[3] = s[3]; // fallthrough
|
|
414
|
+
case 3: dst->u8s[2] = s[2]; // fallthrough
|
|
415
|
+
case 2: dst->u8s[1] = s[1]; // fallthrough
|
|
416
|
+
case 1: dst->u8s[0] = s[0]; // fallthrough
|
|
417
|
+
case 0: break;
|
|
418
|
+
}
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
/** @brief Type-agnostic partial load for 8-bit elements (4 elements max) into 32-bit vector. */
|
|
422
|
+
NK_INTERNAL nk_b32_vec_t nk_partial_load_b8x4_serial_(void const *src, nk_size_t n) {
|
|
423
|
+
nk_b32_vec_t dst = {0};
|
|
424
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
425
|
+
switch (n) {
|
|
426
|
+
default:
|
|
427
|
+
case 4: dst.u8s[3] = s[3]; // fallthrough
|
|
428
|
+
case 3: dst.u8s[2] = s[2]; // fallthrough
|
|
429
|
+
case 2: dst.u8s[1] = s[1]; // fallthrough
|
|
430
|
+
case 1: dst.u8s[0] = s[0]; // fallthrough
|
|
431
|
+
case 0: break;
|
|
432
|
+
}
|
|
433
|
+
return dst;
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
/** @brief Partial store for 8-bit elements (up to 4) from nk_b32_vec_t. */
|
|
437
|
+
NK_INTERNAL void nk_partial_store_b8x4_serial_(nk_b32_vec_t const *src, void *dst, nk_size_t n) {
|
|
438
|
+
nk_u8_t *d = (nk_u8_t *)dst;
|
|
439
|
+
switch (n) {
|
|
440
|
+
default:
|
|
441
|
+
case 4: d[3] = src->u8s[3]; // fallthrough
|
|
442
|
+
case 3: d[2] = src->u8s[2]; // fallthrough
|
|
443
|
+
case 2: d[1] = src->u8s[1]; // fallthrough
|
|
444
|
+
case 1: d[0] = src->u8s[0]; // fallthrough
|
|
445
|
+
case 0: break;
|
|
446
|
+
}
|
|
447
|
+
}
|
|
448
|
+
|
|
449
|
+
/** @brief Partial load for 8-bit elements (32 max) into 256-bit vector (zeros in remaining slots). */
|
|
450
|
+
NK_INTERNAL void nk_partial_load_b8x32_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
451
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
|
|
452
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
453
|
+
switch (n) {
|
|
454
|
+
default:
|
|
455
|
+
case 32: dst->u8s[31] = s[31]; // fallthrough
|
|
456
|
+
case 31: dst->u8s[30] = s[30]; // fallthrough
|
|
457
|
+
case 30: dst->u8s[29] = s[29]; // fallthrough
|
|
458
|
+
case 29: dst->u8s[28] = s[28]; // fallthrough
|
|
459
|
+
case 28: dst->u8s[27] = s[27]; // fallthrough
|
|
460
|
+
case 27: dst->u8s[26] = s[26]; // fallthrough
|
|
461
|
+
case 26: dst->u8s[25] = s[25]; // fallthrough
|
|
462
|
+
case 25: dst->u8s[24] = s[24]; // fallthrough
|
|
463
|
+
case 24: dst->u8s[23] = s[23]; // fallthrough
|
|
464
|
+
case 23: dst->u8s[22] = s[22]; // fallthrough
|
|
465
|
+
case 22: dst->u8s[21] = s[21]; // fallthrough
|
|
466
|
+
case 21: dst->u8s[20] = s[20]; // fallthrough
|
|
467
|
+
case 20: dst->u8s[19] = s[19]; // fallthrough
|
|
468
|
+
case 19: dst->u8s[18] = s[18]; // fallthrough
|
|
469
|
+
case 18: dst->u8s[17] = s[17]; // fallthrough
|
|
470
|
+
case 17: dst->u8s[16] = s[16]; // fallthrough
|
|
471
|
+
case 16: dst->u8s[15] = s[15]; // fallthrough
|
|
472
|
+
case 15: dst->u8s[14] = s[14]; // fallthrough
|
|
473
|
+
case 14: dst->u8s[13] = s[13]; // fallthrough
|
|
474
|
+
case 13: dst->u8s[12] = s[12]; // fallthrough
|
|
475
|
+
case 12: dst->u8s[11] = s[11]; // fallthrough
|
|
476
|
+
case 11: dst->u8s[10] = s[10]; // fallthrough
|
|
477
|
+
case 10: dst->u8s[9] = s[9]; // fallthrough
|
|
478
|
+
case 9: dst->u8s[8] = s[8]; // fallthrough
|
|
479
|
+
case 8: dst->u8s[7] = s[7]; // fallthrough
|
|
480
|
+
case 7: dst->u8s[6] = s[6]; // fallthrough
|
|
481
|
+
case 6: dst->u8s[5] = s[5]; // fallthrough
|
|
482
|
+
case 5: dst->u8s[4] = s[4]; // fallthrough
|
|
483
|
+
case 4: dst->u8s[3] = s[3]; // fallthrough
|
|
484
|
+
case 3: dst->u8s[2] = s[2]; // fallthrough
|
|
485
|
+
case 2: dst->u8s[1] = s[1]; // fallthrough
|
|
486
|
+
case 1: dst->u8s[0] = s[0]; // fallthrough
|
|
487
|
+
case 0: break;
|
|
488
|
+
}
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
/** @brief Partial load for 4-bit nibbles (64 max = 32 bytes) into 256-bit vector (zeros in remaining slots). */
|
|
492
|
+
NK_INTERNAL void nk_partial_load_b4x64_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
493
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
|
|
494
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
495
|
+
nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
|
|
496
|
+
for (nk_size_t i = 0; i < n_bytes && i < 32; i++) dst->u8s[i] = s[i];
|
|
497
|
+
}
|
|
498
|
+
|
|
499
|
+
/** @brief Partial load for 4-bit nibbles (32 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
|
|
500
|
+
NK_INTERNAL void nk_partial_load_b4x32_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
501
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
502
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
503
|
+
nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
|
|
504
|
+
for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
/** @brief Partial load for 1-bit elements (128 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
|
|
508
|
+
NK_INTERNAL void nk_partial_load_b1x128_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n_bits) {
|
|
509
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
510
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
511
|
+
nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8);
|
|
512
|
+
for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
|
|
513
|
+
}
|
|
514
|
+
|
|
515
|
+
/** @brief Partial load for binary (u1) data into 256-bit vector, converting n_bits → n_bytes. */
|
|
516
|
+
NK_INTERNAL void nk_partial_load_b1x256_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n_bits) {
|
|
517
|
+
nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8);
|
|
518
|
+
nk_partial_load_b8x32_serial_(src, dst, n_bytes);
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
/** @brief Partial load for 4-bit nibbles (16 max = 8 bytes) into 64-bit vector (zeros in remaining slots). */
|
|
522
|
+
NK_INTERNAL void nk_partial_load_b4x16_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
|
|
523
|
+
dst->u64 = 0;
|
|
524
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
525
|
+
nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
|
|
526
|
+
for (nk_size_t i = 0; i < n_bytes && i < 8; i++) ((nk_u8_t *)&dst->u64)[i] = s[i];
|
|
527
|
+
}
|
|
528
|
+
|
|
529
|
+
/** @brief Strided partial load for 32-bit elements (4 max) into 128-bit vector. */
|
|
530
|
+
NK_INTERNAL void nk_strided_load_b32x4_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
|
|
531
|
+
nk_size_t n) {
|
|
532
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
533
|
+
nk_u32_t const *s = (nk_u32_t const *)src;
|
|
534
|
+
for (nk_size_t i = 0; i < n && i < 4; ++i) dst->u32s[i] = s[i * stride_elements];
|
|
535
|
+
}
|
|
536
|
+
|
|
537
|
+
/** @brief Strided partial load for 16-bit elements (8 max) into 128-bit vector. */
|
|
538
|
+
NK_INTERNAL void nk_strided_load_b16x8_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
|
|
539
|
+
nk_size_t n) {
|
|
540
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
541
|
+
nk_u16_t const *s = (nk_u16_t const *)src;
|
|
542
|
+
for (nk_size_t i = 0; i < n && i < 8; ++i) dst->u16s[i] = s[i * stride_elements];
|
|
543
|
+
}
|
|
544
|
+
|
|
545
|
+
/** @brief Strided partial load for 8-bit elements (16 max) into 128-bit vector. */
|
|
546
|
+
NK_INTERNAL void nk_strided_load_b8x16_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
|
|
547
|
+
nk_size_t n) {
|
|
548
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
549
|
+
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
550
|
+
for (nk_size_t i = 0; i < n && i < 16; ++i) dst->u8s[i] = s[i * stride_elements];
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
#pragma endregion Type Punned Loads and Stores
|
|
554
|
+
|
|
555
|
+
/**
 * @brief Expands an `f16` (IEEE-754 16-bit) to a `float`.
 *
 * Handles all IEEE-754 edge cases:
 *
 *  Input        F16 Hex    F32 Hex        Description
 *  +0           0x0000     0x00000000     Positive zero
 *  -0           0x8000     0x80000000     Negative zero
 *  +inf         0x7C00     0x7F800000     Positive infinity
 *  -inf         0xFC00     0xFF800000     Negative infinity
 *  NaN          0x7E00     0x7FC00000     Quiet NaN (payload preserved)
 *  Min normal   0x0400     0x38800000     2⁻¹⁴
 *  Max normal   0x7BFF     0x477FE000     65504
 *  Min denorm   0x0001     0x33800000     2⁻²⁴
 *  Max denorm   0x03FF     0x387FC000     2⁻¹⁴ - 2⁻²⁴
 *
 *  https://stackoverflow.com/a/60047308
 *  https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
 *  https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
 */
NK_PUBLIC void nk_f16_to_f32_serial(nk_f16_t const *src, nk_f32_t *dest) {
#if NK_NATIVE_F16
    *dest = (nk_f32_t)(*src);
#else
    // Read the raw 16-bit pattern via byte copy (no alignment/aliasing assumptions).
    unsigned short x;
    nk_copy_bytes_(&x, src, 2);

    // Decompose the half-precision pattern: 1 sign bit, 5 exponent bits, 10 mantissa bits.
    unsigned int sign = (x >> 15) & 1;
    unsigned int exponent = (x >> 10) & 0x1F;
    unsigned int mantissa = x & 0x03FF;

    nk_fui32_t conv;

    if (exponent == 0) {
        if (mantissa == 0) {
            // Zero (preserve sign)
            conv.u = sign << 31;
        }
        else {
            // Denormal: value = mantissa × 2⁻²⁴
            // Use FPU normalization, then subtract 24 from exponent
            // (0x0C000000 == 24 << 23, i.e. 24 in the f32 exponent field).
            nk_fui32_t temp;
            temp.f = (float)mantissa;
            conv.u = (sign << 31) | (temp.u - 0x0C000000);
        }
    }
    else if (exponent == 31) {
        // Infinity (mantissa=0) or NaN (mantissa!=0); payload shifts into the
        // top of the f32 mantissa so NaN payloads are preserved.
        conv.u = (sign << 31) | 0x7F800000 | (mantissa << 13);
    }
    else {
        // Normal: rebias exponent (127-15=112), shift mantissa
        conv.u = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13);
    }

    *dest = conv.f;
#endif
}
|
|
613
|
+
|
|
614
|
+
/** @brief Load 4 × f16 from memory and upcast them to 4 × f32. */
|
|
615
|
+
NK_INTERNAL void nk_load_f16x4_to_f32x4_serial_(void const *src, nk_b128_vec_t *dst) {
|
|
616
|
+
nk_f16_t const *scalars = (nk_f16_t const *)src;
|
|
617
|
+
nk_f16_to_f32_serial(scalars + 0, dst->f32s + 0);
|
|
618
|
+
nk_f16_to_f32_serial(scalars + 1, dst->f32s + 1);
|
|
619
|
+
nk_f16_to_f32_serial(scalars + 2, dst->f32s + 2);
|
|
620
|
+
nk_f16_to_f32_serial(scalars + 3, dst->f32s + 3);
|
|
621
|
+
}
|
|
622
|
+
|
|
623
|
+
/** @brief Partial load for up to 4 × f16 with upcast to 4 × f32. */
|
|
624
|
+
NK_INTERNAL void nk_partial_load_f16x4_to_f32x4_serial_(nk_f16_t const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
625
|
+
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
626
|
+
switch (n) {
|
|
627
|
+
default:
|
|
628
|
+
case 4: nk_f16_to_f32_serial(src + 3, dst->f32s + 3); // fallthrough
|
|
629
|
+
case 3: nk_f16_to_f32_serial(src + 2, dst->f32s + 2); // fallthrough
|
|
630
|
+
case 2: nk_f16_to_f32_serial(src + 1, dst->f32s + 1); // fallthrough
|
|
631
|
+
case 1: nk_f16_to_f32_serial(src + 0, dst->f32s + 0); // fallthrough
|
|
632
|
+
case 0: break;
|
|
633
|
+
}
|
|
634
|
+
}
|
|
635
|
+
|
|
636
|
+
/**
 * @brief Compresses a `float` to an `f16` (IEEE-754 16-bit).
 *
 * Handles all IEEE-754 edge cases with round-to-nearest:
 *
 *  Input        F32 Hex        F16 Hex    Description
 *  +0           0x00000000     0x0000     Positive zero
 *  -0           0x80000000     0x8000     Negative zero
 *  +inf         0x7F800000     0x7C00     Positive infinity
 *  -inf         0xFF800000     0xFC00     Negative infinity
 *  NaN          0x7FC00000     0x7E00     Quiet NaN (payload truncated)
 *  1.0          0x3F800000     0x3C00     Normal number
 *  65504        0x477FE000     0x7BFF     Max f16 normal
 *  65520+       >0x477FE000    0x7C00     Overflow → infinity
 *  2⁻¹⁴         0x38800000     0x0400     Min f16 normal
 *  2⁻²⁴         0x33800000     0x0001     Min f16 denormal
 *  <2⁻²⁵        <0x33000000    0x0000     Underflow → zero
 *
 *  https://stackoverflow.com/a/60047308
 *  https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
 *  https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
 */
NK_PUBLIC void nk_f32_to_f16_serial(nk_f32_t const *src, nk_f16_t *dest) {
#if NK_NATIVE_F16
    *dest = (nk_f16_t)(*src);
#else
    // Reinterpret the f32 bits through the float/uint union.
    nk_fui32_t conv;
    conv.f = *src;

    // Decompose: 1 sign bit, 8 exponent bits, 23 mantissa bits.
    unsigned int sign = (conv.u >> 31) & 1;
    unsigned int exponent = (conv.u >> 23) & 0xFF;
    unsigned int mantissa = conv.u & 0x007FFFFF;

    unsigned short result;

    if (exponent == 0) {
        // Zero or f32 denormal → f16 zero (f32 denormals are far below f16 range)
        result = (unsigned short)(sign << 15);
    }
    else if (exponent == 255) {
        // Infinity or NaN
        unsigned short payload = (unsigned short)(mantissa >> 13);
        if (mantissa != 0 && payload == 0) payload = 1; // Preserve NaN-ness
        result = (unsigned short)((sign << 15) | 0x7C00 | payload);
    }
    else if (exponent <= 102) {
        // Below or at f16 denormal threshold
        // exp=102 with mant=0 is exactly 2^-25 (tie point, rounds to 0 per round-to-even)
        // exp=102 with mant>0 is above tie point (rounds to smallest denormal 0x0001)
        if (exponent == 102 && mantissa > 0) result = (unsigned short)((sign << 15) | 0x0001);
        else result = (unsigned short)(sign << 15);
    }
    else if (exponent < 113) {
        // F16 denormal range (exp 103-112) with IEEE 754 round-to-nearest-even
        unsigned int shift = 113 - exponent;
        unsigned int shift_amount = shift + 13; // 13 = f32→f16 mantissa width difference
        unsigned long long full_mant = 0x00800000ULL | mantissa; // Restore the implicit leading 1

        // Extract result before rounding
        unsigned int mant = (unsigned int)(full_mant >> shift_amount);

        // IEEE 754 round-to-nearest-even: round up if round_bit is set AND
        // (sticky_bits are nonzero OR result is odd)
        unsigned int round_bit = (full_mant >> (shift_amount - 1)) & 1;
        unsigned long long sticky_bits = full_mant & ((1ULL << (shift_amount - 1)) - 1);

        if (round_bit && (sticky_bits || (mant & 1))) mant++;

        result = (unsigned short)((sign << 15) | mant);
    }
    else if (exponent < 143) {
        // Normal f16 range with IEEE 754 round-to-nearest-even
        unsigned int f16_exp = exponent - 112; // Rebias: 127 - 15 = 112
        unsigned int f16_mant = mantissa >> 13;

        // IEEE 754 rounding: check round bit (bit 12) and sticky bits (bits 0-11)
        unsigned int round_bit = (mantissa >> 12) & 1;
        unsigned int sticky_bits = mantissa & 0xFFF;

        if (round_bit && (sticky_bits || (f16_mant & 1))) {
            f16_mant++;
            // Mantissa carry-out bumps the exponent (e.g. 0x3FF → 0 with exp+1)
            if (f16_mant > 0x3FF) f16_mant = 0, f16_exp++;
        }

        // Rounding may push the value past the largest finite f16 → infinity
        if (f16_exp > 30) result = (unsigned short)((sign << 15) | 0x7C00);
        else result = (unsigned short)((sign << 15) | (f16_exp << 10) | f16_mant);
    }
    else {
        // Overflow → infinity
        result = (unsigned short)((sign << 15) | 0x7C00);
    }

    // Write the raw 16-bit pattern via byte copy (no alignment/aliasing assumptions).
    nk_copy_bytes_(dest, &result, 2);
#endif
}
|
|
731
|
+
|
|
732
|
+
/**
|
|
198
733
|
* @brief For compilers that don't natively support the `__bf16` type,
|
|
199
734
|
* upcasts contents into a more conventional `float`.
|
|
200
735
|
*
|
|
@@ -309,8 +844,8 @@ NK_PUBLIC void nk_e4m3_to_f32_serial(nk_e4m3_t const *src, nk_f32_t *dest) {
|
|
|
309
844
|
* NaN 0x7FC00000 0x7F Quiet NaN
|
|
310
845
|
* 1.0 0x3F800000 0x38 Normal (exp=7, mant=0)
|
|
311
846
|
* 448+ >0x43E00000 0x7E Overflow → max
|
|
312
|
-
* 2⁻⁶
|
|
313
|
-
*
|
|
847
|
+
* 2⁻⁶ 0x3E800000 0x08 Min normal
|
|
848
|
+
* ≤2⁻¹⁰ ≤0x3A800000 0x00 Underflow → zero (RNE boundary)
|
|
314
849
|
*
|
|
315
850
|
* References:
|
|
316
851
|
* https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
|
|
@@ -552,8 +1087,8 @@ NK_PUBLIC void nk_e5m2_to_f32_serial(nk_e5m2_t const *src, nk_f32_t *dest) {
|
|
|
552
1087
|
* NaN 0x7FC00000 0x7D Quiet NaN
|
|
553
1088
|
* 1.0 0x3F800000 0x3C Normal (exp=15, mant=0)
|
|
554
1089
|
* 57344+ >0x47600000 0x7C Overflow → infinity
|
|
555
|
-
* 2⁻¹⁴
|
|
556
|
-
*
|
|
1090
|
+
* 2⁻¹⁴ 0x38800000 0x04 Min normal
|
|
1091
|
+
* ≤2⁻¹⁷ ≤0x37000000 0x00 Underflow → zero (RNE boundary)
|
|
557
1092
|
*
|
|
558
1093
|
* References:
|
|
559
1094
|
* https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
|
|
@@ -1050,565 +1585,156 @@ NK_INTERNAL nk_u64_t nk_rint_even_f64_to_u64_serial_(nk_f64_t x) {
|
|
|
1050
1585
|
}
|
|
1051
1586
|
|
|
1052
1587
|
NK_INTERNAL void nk_f32_to_i8_serial(nk_f32_t const *x, nk_i8_t *y) {
|
|
1053
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1054
|
-
else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0f ? 127.0 : (*x < -128.0f ? -128.0 : (nk_f64_t)*x));
|
|
1055
|
-
}
|
|
1056
|
-
|
|
1057
|
-
NK_INTERNAL void nk_f32_to_u8_serial(nk_f32_t const *x, nk_u8_t *y) {
|
|
1058
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1059
|
-
else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0f ? 255.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
|
|
1060
|
-
}
|
|
1061
|
-
|
|
1062
|
-
NK_INTERNAL void nk_f32_to_i16_serial(nk_f32_t const *x, nk_i16_t *y) {
|
|
1063
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1064
|
-
else
|
|
1065
|
-
*y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0f ? 32767.0
|
|
1066
|
-
: (*x < -32768.0f ? -32768.0 : (nk_f64_t)*x));
|
|
1067
|
-
}
|
|
1068
|
-
|
|
1069
|
-
NK_INTERNAL void nk_f32_to_u16_serial(nk_f32_t const *x, nk_u16_t *y) {
|
|
1070
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1071
|
-
else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0f ? 65535.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
|
|
1072
|
-
}
|
|
1073
|
-
|
|
1074
|
-
NK_INTERNAL void nk_f64_to_i8_serial(nk_f64_t const *x, nk_i8_t *y) {
|
|
1075
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1076
|
-
else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0 ? 127.0 : (*x < -128.0 ? -128.0 : *x));
|
|
1077
|
-
}
|
|
1078
|
-
|
|
1079
|
-
NK_INTERNAL void nk_f64_to_u8_serial(nk_f64_t const *x, nk_u8_t *y) {
|
|
1080
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1081
|
-
else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0 ? 255.0 : (*x < 0 ? 0.0 : *x));
|
|
1082
|
-
}
|
|
1083
|
-
|
|
1084
|
-
NK_INTERNAL void nk_f64_to_i16_serial(nk_f64_t const *x, nk_i16_t *y) {
|
|
1085
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1086
|
-
else *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0 ? 32767.0 : (*x < -32768.0 ? -32768.0 : *x));
|
|
1087
|
-
}
|
|
1088
|
-
|
|
1089
|
-
NK_INTERNAL void nk_f64_to_u16_serial(nk_f64_t const *x, nk_u16_t *y) {
|
|
1090
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1091
|
-
else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0 ? 65535.0 : (*x < 0 ? 0.0 : *x));
|
|
1092
|
-
}
|
|
1093
|
-
|
|
1094
|
-
NK_INTERNAL void nk_f64_to_i32_serial(nk_f64_t const *x, nk_i32_t *y) {
|
|
1095
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1096
|
-
else
|
|
1097
|
-
*y = (nk_i32_t)nk_rint_even_f64_to_i64_serial_(*x > 2147483647.0 ? 2147483647.0
|
|
1098
|
-
: (*x < -2147483648.0 ? -2147483648.0 : *x));
|
|
1099
|
-
}
|
|
1100
|
-
|
|
1101
|
-
NK_INTERNAL void nk_f64_to_u32_serial(nk_f64_t const *x, nk_u32_t *y) {
|
|
1102
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1103
|
-
else *y = (nk_u32_t)nk_rint_even_f64_to_u64_serial_(*x > 4294967295.0 ? 4294967295.0 : (*x < 0 ? 0.0 : *x));
|
|
1104
|
-
}
|
|
1105
|
-
|
|
1106
|
-
NK_INTERNAL void nk_f64_to_i64_serial(nk_f64_t const *x, nk_i64_t *y) {
|
|
1107
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1108
|
-
else
|
|
1109
|
-
*y = nk_rint_even_f64_to_i64_serial_(*x > 9223372036854775807.0
|
|
1110
|
-
? 9223372036854775807.0
|
|
1111
|
-
: (*x < -9223372036854775808.0 ? -9223372036854775808.0 : *x));
|
|
1112
|
-
}
|
|
1113
|
-
|
|
1114
|
-
NK_INTERNAL void nk_f64_to_u64_serial(nk_f64_t const *x, nk_u64_t *y) {
|
|
1115
|
-
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1116
|
-
else
|
|
1117
|
-
*y = nk_rint_even_f64_to_u64_serial_(*x > 18446744073709551615.0 ? 18446744073709551615.0
|
|
1118
|
-
: (*x < 0 ? 0.0 : *x));
|
|
1119
|
-
}
|
|
1120
|
-
|
|
1121
|
-
NK_INTERNAL void nk_i64_to_i8_serial(nk_i64_t const *x, nk_i8_t *y) {
|
|
1122
|
-
*y = (nk_i8_t)(*x > 127ll ? 127ll : (*x < -128ll ? -128ll : *x));
|
|
1123
|
-
}
|
|
1124
|
-
|
|
1125
|
-
NK_INTERNAL void nk_i64_to_u8_serial(nk_i64_t const *x, nk_u8_t *y) {
|
|
1126
|
-
*y = (nk_u8_t)(*x > 255ll ? 255ll : (*x < 0ll ? 0ll : *x));
|
|
1127
|
-
}
|
|
1128
|
-
|
|
1129
|
-
NK_INTERNAL void nk_i64_to_i16_serial(nk_i64_t const *x, nk_i16_t *y) {
|
|
1130
|
-
*y = (nk_i16_t)(*x > 32767ll ? 32767ll : (*x < -32768ll ? -32768ll : *x));
|
|
1131
|
-
}
|
|
1132
|
-
|
|
1133
|
-
NK_INTERNAL void nk_i64_to_u16_serial(nk_i64_t const *x, nk_u16_t *y) {
|
|
1134
|
-
*y = (nk_u16_t)(*x > 65535ll ? 65535ll : (*x < 0ll ? 0ll : *x));
|
|
1135
|
-
}
|
|
1136
|
-
|
|
1137
|
-
NK_INTERNAL void nk_i64_to_i32_serial(nk_i64_t const *x, nk_i32_t *y) {
|
|
1138
|
-
*y = (nk_i32_t)(*x > 2147483647ll ? 2147483647ll : (*x < -2147483648ll ? -2147483648ll : *x));
|
|
1139
|
-
}
|
|
1140
|
-
|
|
1141
|
-
NK_INTERNAL void nk_i64_to_u32_serial(nk_i64_t const *x, nk_u32_t *y) {
|
|
1142
|
-
*y = (nk_u32_t)(*x > 4294967295ll ? 4294967295ll : (*x < 0ll ? 0ll : *x));
|
|
1143
|
-
}
|
|
1144
|
-
|
|
1145
|
-
NK_INTERNAL void nk_u64_to_i8_serial(nk_u64_t const *x, nk_i8_t *y) { *y = (nk_i8_t)(*x > 127ull ? 127ull : *x); }
|
|
1146
|
-
NK_INTERNAL void nk_u64_to_u8_serial(nk_u64_t const *x, nk_u8_t *y) { *y = (nk_u8_t)(*x > 255ull ? 255ull : *x); }
|
|
1147
|
-
NK_INTERNAL void nk_u64_to_i16_serial(nk_u64_t const *x, nk_i16_t *y) {
|
|
1148
|
-
*y = (nk_i16_t)(*x > 32767ull ? 32767ull : *x);
|
|
1149
|
-
}
|
|
1150
|
-
NK_INTERNAL void nk_u64_to_u16_serial(nk_u64_t const *x, nk_u16_t *y) {
|
|
1151
|
-
*y = (nk_u16_t)(*x > 65535ull ? 65535ull : *x);
|
|
1152
|
-
}
|
|
1153
|
-
|
|
1154
|
-
NK_INTERNAL void nk_u64_to_i32_serial(nk_u64_t const *x, nk_i32_t *y) {
|
|
1155
|
-
*y = (nk_i32_t)(*x > 2147483647ull ? 2147483647ull : *x);
|
|
1156
|
-
}
|
|
1157
|
-
|
|
1158
|
-
NK_INTERNAL void nk_u64_to_u32_serial(nk_u64_t const *x, nk_u32_t *y) {
|
|
1159
|
-
*y = (nk_u32_t)(*x > 4294967295ull ? 4294967295ull : *x);
|
|
1160
|
-
}
|
|
1161
|
-
|
|
1162
|
-
NK_PUBLIC void nk_f16_to_f64_(nk_f16_t const *src, nk_f64_t *dest) {
|
|
1163
|
-
nk_f32_t f32;
|
|
1164
|
-
nk_f16_to_f32_serial(src, &f32);
|
|
1165
|
-
*dest = f32;
|
|
1166
|
-
}
|
|
1167
|
-
NK_PUBLIC void nk_bf16_to_f64_(nk_bf16_t const *src, nk_f64_t *dest) {
|
|
1168
|
-
nk_f32_t f32;
|
|
1169
|
-
nk_bf16_to_f32_serial(src, &f32);
|
|
1170
|
-
*dest = f32;
|
|
1171
|
-
}
|
|
1172
|
-
|
|
1173
|
-
NK_INTERNAL void nk_u64_to_i64_serial(nk_u64_t const *x, nk_i64_t *y) {
|
|
1174
|
-
*y = (nk_i64_t)(*x >= 9223372036854775807ull ? 9223372036854775807ll : *x);
|
|
1175
|
-
}
|
|
1176
|
-
|
|
1177
|
-
NK_INTERNAL void nk_i8_to_u64_serial(nk_i8_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
|
|
1178
|
-
NK_INTERNAL void nk_i16_to_u64_serial(nk_i16_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
|
|
1179
|
-
NK_INTERNAL void nk_i32_to_u64_serial(nk_i32_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
|
|
1180
|
-
NK_INTERNAL void nk_i64_to_u64_serial(nk_i64_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
|
|
1181
|
-
|
|
1182
|
-
NK_INTERNAL void nk_i64_to_f16_serial(nk_i64_t const *x, nk_f16_t *y) {
|
|
1183
|
-
nk_f32_t f32 = (nk_f32_t)*x;
|
|
1184
|
-
nk_f32_to_f16_serial(&f32, y);
|
|
1185
|
-
}
|
|
1186
|
-
NK_INTERNAL void nk_i64_to_bf16_serial(nk_i64_t const *x, nk_bf16_t *y) {
|
|
1187
|
-
nk_f32_t f32 = (nk_f32_t)*x;
|
|
1188
|
-
nk_f32_to_bf16_serial(&f32, y);
|
|
1189
|
-
}
|
|
1190
|
-
NK_INTERNAL void nk_u64_to_f16_serial(nk_u64_t const *x, nk_f16_t *y) {
|
|
1191
|
-
nk_f32_t f32 = (nk_f32_t)*x;
|
|
1192
|
-
nk_f32_to_f16_serial(&f32, y);
|
|
1193
|
-
}
|
|
1194
|
-
NK_INTERNAL void nk_u64_to_bf16_serial(nk_u64_t const *x, nk_bf16_t *y) {
|
|
1195
|
-
nk_f32_t f32 = (nk_f32_t)*x;
|
|
1196
|
-
nk_f32_to_bf16_serial(&f32, y);
|
|
1197
|
-
}
|
|
1198
|
-
|
|
1199
|
-
#pragma region - Type Punned Loads and Stores
|
|
1200
|
-
|
|
1201
|
-
/** @brief Type-agnostic 256-bit full load. */
|
|
1202
|
-
NK_INTERNAL void nk_load_b256_serial_(void const *src, nk_b256_vec_t *dst) {
|
|
1203
|
-
nk_u64_t const *s = (nk_u64_t const *)src;
|
|
1204
|
-
dst->u64s[0] = s[0], dst->u64s[1] = s[1], dst->u64s[2] = s[2], dst->u64s[3] = s[3];
|
|
1205
|
-
}
|
|
1206
|
-
|
|
1207
|
-
/** @brief Type-agnostic 128-bit full load. */
|
|
1208
|
-
NK_INTERNAL void nk_load_b128_serial_(void const *src, nk_b128_vec_t *dst) {
|
|
1209
|
-
nk_u64_t const *s = (nk_u64_t const *)src;
|
|
1210
|
-
dst->u64s[0] = s[0], dst->u64s[1] = s[1];
|
|
1211
|
-
}
|
|
1212
|
-
|
|
1213
|
-
/** @brief Type-agnostic 64-bit full load. */
|
|
1214
|
-
NK_INTERNAL void nk_load_b64_serial_(void const *src, nk_b64_vec_t *dst) { dst->u64 = *(nk_u64_t const *)src; }
|
|
1215
|
-
|
|
1216
|
-
/** @brief Type-agnostic partial load for 32-bit elements (8 elements max) into 256-bit vector. */
|
|
1217
|
-
NK_INTERNAL void nk_partial_load_b32x8_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
1218
|
-
dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
|
|
1219
|
-
nk_u32_t const *s = (nk_u32_t const *)src;
|
|
1220
|
-
switch (n) {
|
|
1221
|
-
default:
|
|
1222
|
-
case 8: dst->u32s[7] = s[7]; // fallthrough
|
|
1223
|
-
case 7: dst->u32s[6] = s[6]; // fallthrough
|
|
1224
|
-
case 6: dst->u32s[5] = s[5]; // fallthrough
|
|
1225
|
-
case 5: dst->u32s[4] = s[4]; // fallthrough
|
|
1226
|
-
case 4: dst->u32s[3] = s[3]; // fallthrough
|
|
1227
|
-
case 3: dst->u32s[2] = s[2]; // fallthrough
|
|
1228
|
-
case 2: dst->u32s[1] = s[1]; // fallthrough
|
|
1229
|
-
case 1: dst->u32s[0] = s[0]; // fallthrough
|
|
1230
|
-
case 0: break;
|
|
1231
|
-
}
|
|
1232
|
-
}
|
|
1233
|
-
|
|
1234
|
-
/** @brief Type-agnostic partial load for 32-bit elements (4 elements max) into 128-bit vector. */
|
|
1235
|
-
NK_INTERNAL void nk_partial_load_b32x4_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
1236
|
-
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
1237
|
-
nk_u32_t const *s = (nk_u32_t const *)src;
|
|
1238
|
-
switch (n) {
|
|
1239
|
-
default:
|
|
1240
|
-
case 4: dst->u32s[3] = s[3]; // fallthrough
|
|
1241
|
-
case 3: dst->u32s[2] = s[2]; // fallthrough
|
|
1242
|
-
case 2: dst->u32s[1] = s[1]; // fallthrough
|
|
1243
|
-
case 1: dst->u32s[0] = s[0]; // fallthrough
|
|
1244
|
-
case 0: break;
|
|
1245
|
-
}
|
|
1246
|
-
}
|
|
1247
|
-
|
|
1248
|
-
/** @brief Type-agnostic partial load for 8-bit elements (8 elements max) into 64-bit vector. */
|
|
1249
|
-
NK_INTERNAL void nk_partial_load_b8x8_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
|
|
1250
|
-
dst->u64 = 0;
|
|
1251
|
-
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
1252
|
-
switch (n) {
|
|
1253
|
-
default:
|
|
1254
|
-
case 8: dst->u8s[7] = s[7]; // fallthrough
|
|
1255
|
-
case 7: dst->u8s[6] = s[6]; // fallthrough
|
|
1256
|
-
case 6: dst->u8s[5] = s[5]; // fallthrough
|
|
1257
|
-
case 5: dst->u8s[4] = s[4]; // fallthrough
|
|
1258
|
-
case 4: dst->u8s[3] = s[3]; // fallthrough
|
|
1259
|
-
case 3: dst->u8s[2] = s[2]; // fallthrough
|
|
1260
|
-
case 2: dst->u8s[1] = s[1]; // fallthrough
|
|
1261
|
-
case 1: dst->u8s[0] = s[0]; // fallthrough
|
|
1262
|
-
case 0: break;
|
|
1263
|
-
}
|
|
1264
|
-
}
|
|
1265
|
-
|
|
1266
|
-
/** @brief Type-agnostic partial load for 8-bit elements (4 elements max) into 32-bit vector. */
|
|
1267
|
-
NK_INTERNAL nk_b32_vec_t nk_partial_load_b8x4_serial_(void const *src, nk_size_t n) {
|
|
1268
|
-
nk_b32_vec_t dst = {0};
|
|
1269
|
-
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
1270
|
-
switch (n) {
|
|
1271
|
-
default:
|
|
1272
|
-
case 4: dst.u8s[3] = s[3]; // fallthrough
|
|
1273
|
-
case 3: dst.u8s[2] = s[2]; // fallthrough
|
|
1274
|
-
case 2: dst.u8s[1] = s[1]; // fallthrough
|
|
1275
|
-
case 1: dst.u8s[0] = s[0]; // fallthrough
|
|
1276
|
-
case 0: break;
|
|
1277
|
-
}
|
|
1278
|
-
return dst;
|
|
1279
|
-
}
|
|
1280
|
-
|
|
1281
|
-
/** @brief Partial store for 8-bit elements (up to 4) from nk_b32_vec_t. */
|
|
1282
|
-
NK_INTERNAL void nk_partial_store_b8x4_serial_(nk_b32_vec_t const *src, void *dst, nk_size_t n) {
|
|
1283
|
-
nk_u8_t *d = (nk_u8_t *)dst;
|
|
1284
|
-
switch (n) {
|
|
1285
|
-
default:
|
|
1286
|
-
case 4: d[3] = src->u8s[3]; // fallthrough
|
|
1287
|
-
case 3: d[2] = src->u8s[2]; // fallthrough
|
|
1288
|
-
case 2: d[1] = src->u8s[1]; // fallthrough
|
|
1289
|
-
case 1: d[0] = src->u8s[0]; // fallthrough
|
|
1290
|
-
case 0: break;
|
|
1291
|
-
}
|
|
1588
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1589
|
+
else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0f ? 127.0 : (*x < -128.0f ? -128.0 : (nk_f64_t)*x));
|
|
1292
1590
|
}
|
|
1293
1591
|
|
|
1294
|
-
|
|
1295
|
-
|
|
1296
|
-
|
|
1297
|
-
nk_u16_t const *s = (nk_u16_t const *)src;
|
|
1298
|
-
switch (n) {
|
|
1299
|
-
default:
|
|
1300
|
-
case 8: dst->u16s[7] = s[7]; // fallthrough
|
|
1301
|
-
case 7: dst->u16s[6] = s[6]; // fallthrough
|
|
1302
|
-
case 6: dst->u16s[5] = s[5]; // fallthrough
|
|
1303
|
-
case 5: dst->u16s[4] = s[4]; // fallthrough
|
|
1304
|
-
case 4: dst->u16s[3] = s[3]; // fallthrough
|
|
1305
|
-
case 3: dst->u16s[2] = s[2]; // fallthrough
|
|
1306
|
-
case 2: dst->u16s[1] = s[1]; // fallthrough
|
|
1307
|
-
case 1: dst->u16s[0] = s[0]; // fallthrough
|
|
1308
|
-
case 0: break;
|
|
1309
|
-
}
|
|
1592
|
+
NK_INTERNAL void nk_f32_to_u8_serial(nk_f32_t const *x, nk_u8_t *y) {
|
|
1593
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1594
|
+
else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0f ? 255.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
|
|
1310
1595
|
}
|
|
1311
1596
|
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
default:
|
|
1318
|
-
case 16: dst->u8s[15] = s[15]; // fallthrough
|
|
1319
|
-
case 15: dst->u8s[14] = s[14]; // fallthrough
|
|
1320
|
-
case 14: dst->u8s[13] = s[13]; // fallthrough
|
|
1321
|
-
case 13: dst->u8s[12] = s[12]; // fallthrough
|
|
1322
|
-
case 12: dst->u8s[11] = s[11]; // fallthrough
|
|
1323
|
-
case 11: dst->u8s[10] = s[10]; // fallthrough
|
|
1324
|
-
case 10: dst->u8s[9] = s[9]; // fallthrough
|
|
1325
|
-
case 9: dst->u8s[8] = s[8]; // fallthrough
|
|
1326
|
-
case 8: dst->u8s[7] = s[7]; // fallthrough
|
|
1327
|
-
case 7: dst->u8s[6] = s[6]; // fallthrough
|
|
1328
|
-
case 6: dst->u8s[5] = s[5]; // fallthrough
|
|
1329
|
-
case 5: dst->u8s[4] = s[4]; // fallthrough
|
|
1330
|
-
case 4: dst->u8s[3] = s[3]; // fallthrough
|
|
1331
|
-
case 3: dst->u8s[2] = s[2]; // fallthrough
|
|
1332
|
-
case 2: dst->u8s[1] = s[1]; // fallthrough
|
|
1333
|
-
case 1: dst->u8s[0] = s[0]; // fallthrough
|
|
1334
|
-
case 0: break;
|
|
1335
|
-
}
|
|
1597
|
+
NK_INTERNAL void nk_f32_to_i16_serial(nk_f32_t const *x, nk_i16_t *y) {
|
|
1598
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1599
|
+
else
|
|
1600
|
+
*y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0f ? 32767.0
|
|
1601
|
+
: (*x < -32768.0f ? -32768.0 : (nk_f64_t)*x));
|
|
1336
1602
|
}
|
|
1337
1603
|
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
nk_u16_t const *s = (nk_u16_t const *)src;
|
|
1342
|
-
switch (n) {
|
|
1343
|
-
default:
|
|
1344
|
-
case 16: dst->u16s[15] = s[15]; // fallthrough
|
|
1345
|
-
case 15: dst->u16s[14] = s[14]; // fallthrough
|
|
1346
|
-
case 14: dst->u16s[13] = s[13]; // fallthrough
|
|
1347
|
-
case 13: dst->u16s[12] = s[12]; // fallthrough
|
|
1348
|
-
case 12: dst->u16s[11] = s[11]; // fallthrough
|
|
1349
|
-
case 11: dst->u16s[10] = s[10]; // fallthrough
|
|
1350
|
-
case 10: dst->u16s[9] = s[9]; // fallthrough
|
|
1351
|
-
case 9: dst->u16s[8] = s[8]; // fallthrough
|
|
1352
|
-
case 8: dst->u16s[7] = s[7]; // fallthrough
|
|
1353
|
-
case 7: dst->u16s[6] = s[6]; // fallthrough
|
|
1354
|
-
case 6: dst->u16s[5] = s[5]; // fallthrough
|
|
1355
|
-
case 5: dst->u16s[4] = s[4]; // fallthrough
|
|
1356
|
-
case 4: dst->u16s[3] = s[3]; // fallthrough
|
|
1357
|
-
case 3: dst->u16s[2] = s[2]; // fallthrough
|
|
1358
|
-
case 2: dst->u16s[1] = s[1]; // fallthrough
|
|
1359
|
-
case 1: dst->u16s[0] = s[0]; // fallthrough
|
|
1360
|
-
case 0: break;
|
|
1361
|
-
}
|
|
1604
|
+
NK_INTERNAL void nk_f32_to_u16_serial(nk_f32_t const *x, nk_u16_t *y) {
|
|
1605
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1606
|
+
else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0f ? 65535.0 : (*x < 0 ? 0.0 : (nk_f64_t)*x));
|
|
1362
1607
|
}
|
|
1363
1608
|
|
|
1364
|
-
|
|
1365
|
-
|
|
1366
|
-
|
|
1367
|
-
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
1368
|
-
switch (n) {
|
|
1369
|
-
default:
|
|
1370
|
-
case 32: dst->u8s[31] = s[31]; // fallthrough
|
|
1371
|
-
case 31: dst->u8s[30] = s[30]; // fallthrough
|
|
1372
|
-
case 30: dst->u8s[29] = s[29]; // fallthrough
|
|
1373
|
-
case 29: dst->u8s[28] = s[28]; // fallthrough
|
|
1374
|
-
case 28: dst->u8s[27] = s[27]; // fallthrough
|
|
1375
|
-
case 27: dst->u8s[26] = s[26]; // fallthrough
|
|
1376
|
-
case 26: dst->u8s[25] = s[25]; // fallthrough
|
|
1377
|
-
case 25: dst->u8s[24] = s[24]; // fallthrough
|
|
1378
|
-
case 24: dst->u8s[23] = s[23]; // fallthrough
|
|
1379
|
-
case 23: dst->u8s[22] = s[22]; // fallthrough
|
|
1380
|
-
case 22: dst->u8s[21] = s[21]; // fallthrough
|
|
1381
|
-
case 21: dst->u8s[20] = s[20]; // fallthrough
|
|
1382
|
-
case 20: dst->u8s[19] = s[19]; // fallthrough
|
|
1383
|
-
case 19: dst->u8s[18] = s[18]; // fallthrough
|
|
1384
|
-
case 18: dst->u8s[17] = s[17]; // fallthrough
|
|
1385
|
-
case 17: dst->u8s[16] = s[16]; // fallthrough
|
|
1386
|
-
case 16: dst->u8s[15] = s[15]; // fallthrough
|
|
1387
|
-
case 15: dst->u8s[14] = s[14]; // fallthrough
|
|
1388
|
-
case 14: dst->u8s[13] = s[13]; // fallthrough
|
|
1389
|
-
case 13: dst->u8s[12] = s[12]; // fallthrough
|
|
1390
|
-
case 12: dst->u8s[11] = s[11]; // fallthrough
|
|
1391
|
-
case 11: dst->u8s[10] = s[10]; // fallthrough
|
|
1392
|
-
case 10: dst->u8s[9] = s[9]; // fallthrough
|
|
1393
|
-
case 9: dst->u8s[8] = s[8]; // fallthrough
|
|
1394
|
-
case 8: dst->u8s[7] = s[7]; // fallthrough
|
|
1395
|
-
case 7: dst->u8s[6] = s[6]; // fallthrough
|
|
1396
|
-
case 6: dst->u8s[5] = s[5]; // fallthrough
|
|
1397
|
-
case 5: dst->u8s[4] = s[4]; // fallthrough
|
|
1398
|
-
case 4: dst->u8s[3] = s[3]; // fallthrough
|
|
1399
|
-
case 3: dst->u8s[2] = s[2]; // fallthrough
|
|
1400
|
-
case 2: dst->u8s[1] = s[1]; // fallthrough
|
|
1401
|
-
case 1: dst->u8s[0] = s[0]; // fallthrough
|
|
1402
|
-
case 0: break;
|
|
1403
|
-
}
|
|
1609
|
+
NK_INTERNAL void nk_f64_to_i8_serial(nk_f64_t const *x, nk_i8_t *y) {
|
|
1610
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1611
|
+
else *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(*x > 127.0 ? 127.0 : (*x < -128.0 ? -128.0 : *x));
|
|
1404
1612
|
}
|
|
1405
1613
|
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
|
|
1409
|
-
switch (n) {
|
|
1410
|
-
default:
|
|
1411
|
-
case 8: d[7] = src->u32s[7]; // fallthrough
|
|
1412
|
-
case 7: d[6] = src->u32s[6]; // fallthrough
|
|
1413
|
-
case 6: d[5] = src->u32s[5]; // fallthrough
|
|
1414
|
-
case 5: d[4] = src->u32s[4]; // fallthrough
|
|
1415
|
-
case 4: d[3] = src->u32s[3]; // fallthrough
|
|
1416
|
-
case 3: d[2] = src->u32s[2]; // fallthrough
|
|
1417
|
-
case 2: d[1] = src->u32s[1]; // fallthrough
|
|
1418
|
-
case 1: d[0] = src->u32s[0]; // fallthrough
|
|
1419
|
-
case 0: break;
|
|
1420
|
-
}
|
|
1614
|
+
NK_INTERNAL void nk_f64_to_u8_serial(nk_f64_t const *x, nk_u8_t *y) {
|
|
1615
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1616
|
+
else *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(*x > 255.0 ? 255.0 : (*x < 0 ? 0.0 : *x));
|
|
1421
1617
|
}
|
|
1422
1618
|
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
1426
|
-
switch (n) {
|
|
1427
|
-
default:
|
|
1428
|
-
case 4: d[3] = src->u32s[3]; // fallthrough
|
|
1429
|
-
case 3: d[2] = src->u32s[2]; // fallthrough
|
|
1430
|
-
case 2: d[1] = src->u32s[1]; // fallthrough
|
|
1431
|
-
case 1: d[0] = src->u32s[0]; // fallthrough
|
|
1432
|
-
case 0: break;
|
|
1433
|
-
}
|
|
1619
|
+
NK_INTERNAL void nk_f64_to_i16_serial(nk_f64_t const *x, nk_i16_t *y) {
|
|
1620
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1621
|
+
else *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(*x > 32767.0 ? 32767.0 : (*x < -32768.0 ? -32768.0 : *x));
|
|
1434
1622
|
}
|
|
1435
1623
|
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
|
|
1439
|
-
switch (n) {
|
|
1440
|
-
default:
|
|
1441
|
-
case 8: d[7] = src->u16s[7]; // fallthrough
|
|
1442
|
-
case 7: d[6] = src->u16s[6]; // fallthrough
|
|
1443
|
-
case 6: d[5] = src->u16s[5]; // fallthrough
|
|
1444
|
-
case 5: d[4] = src->u16s[4]; // fallthrough
|
|
1445
|
-
case 4: d[3] = src->u16s[3]; // fallthrough
|
|
1446
|
-
case 3: d[2] = src->u16s[2]; // fallthrough
|
|
1447
|
-
case 2: d[1] = src->u16s[1]; // fallthrough
|
|
1448
|
-
case 1: d[0] = src->u16s[0]; // fallthrough
|
|
1449
|
-
case 0: break;
|
|
1450
|
-
}
|
|
1624
|
+
NK_INTERNAL void nk_f64_to_u16_serial(nk_f64_t const *x, nk_u16_t *y) {
|
|
1625
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1626
|
+
else *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(*x > 65535.0 ? 65535.0 : (*x < 0 ? 0.0 : *x));
|
|
1451
1627
|
}
|
|
1452
1628
|
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
case 4: d[3] = src->u16s[3]; // fallthrough
|
|
1459
|
-
case 3: d[2] = src->u16s[2]; // fallthrough
|
|
1460
|
-
case 2: d[1] = src->u16s[1]; // fallthrough
|
|
1461
|
-
case 1: d[0] = src->u16s[0]; // fallthrough
|
|
1462
|
-
case 0: break;
|
|
1463
|
-
}
|
|
1629
|
+
NK_INTERNAL void nk_f64_to_i32_serial(nk_f64_t const *x, nk_i32_t *y) {
|
|
1630
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1631
|
+
else
|
|
1632
|
+
*y = (nk_i32_t)nk_rint_even_f64_to_i64_serial_(*x > 2147483647.0 ? 2147483647.0
|
|
1633
|
+
: (*x < -2147483648.0 ? -2147483648.0 : *x));
|
|
1464
1634
|
}
|
|
1465
1635
|
|
|
1466
|
-
|
|
1467
|
-
|
|
1468
|
-
|
|
1469
|
-
switch (n) {
|
|
1470
|
-
default:
|
|
1471
|
-
case 8: d[7] = src->u8s[7]; // fallthrough
|
|
1472
|
-
case 7: d[6] = src->u8s[6]; // fallthrough
|
|
1473
|
-
case 6: d[5] = src->u8s[5]; // fallthrough
|
|
1474
|
-
case 5: d[4] = src->u8s[4]; // fallthrough
|
|
1475
|
-
case 4: d[3] = src->u8s[3]; // fallthrough
|
|
1476
|
-
case 3: d[2] = src->u8s[2]; // fallthrough
|
|
1477
|
-
case 2: d[1] = src->u8s[1]; // fallthrough
|
|
1478
|
-
case 1: d[0] = src->u8s[0]; // fallthrough
|
|
1479
|
-
case 0: break;
|
|
1480
|
-
}
|
|
1636
|
+
NK_INTERNAL void nk_f64_to_u32_serial(nk_f64_t const *x, nk_u32_t *y) {
|
|
1637
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1638
|
+
else *y = (nk_u32_t)nk_rint_even_f64_to_u64_serial_(*x > 4294967295.0 ? 4294967295.0 : (*x < 0 ? 0.0 : *x));
|
|
1481
1639
|
}
|
|
1482
1640
|
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
case 4: dst->u64s[3] = s[3]; // fallthrough
|
|
1490
|
-
case 3: dst->u64s[2] = s[2]; // fallthrough
|
|
1491
|
-
case 2: dst->u64s[1] = s[1]; // fallthrough
|
|
1492
|
-
case 1: dst->u64s[0] = s[0]; // fallthrough
|
|
1493
|
-
case 0: break;
|
|
1494
|
-
}
|
|
1641
|
+
NK_INTERNAL void nk_f64_to_i64_serial(nk_f64_t const *x, nk_i64_t *y) {
|
|
1642
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1643
|
+
else
|
|
1644
|
+
*y = nk_rint_even_f64_to_i64_serial_(*x > 9223372036854775807.0
|
|
1645
|
+
? 9223372036854775807.0
|
|
1646
|
+
: (*x < -9223372036854775808.0 ? -9223372036854775808.0 : *x));
|
|
1495
1647
|
}
|
|
1496
1648
|
|
|
1497
|
-
|
|
1498
|
-
|
|
1499
|
-
|
|
1500
|
-
|
|
1501
|
-
|
|
1502
|
-
case 4: d[3] = src->u64s[3]; // fallthrough
|
|
1503
|
-
case 3: d[2] = src->u64s[2]; // fallthrough
|
|
1504
|
-
case 2: d[1] = src->u64s[1]; // fallthrough
|
|
1505
|
-
case 1: d[0] = src->u64s[0]; // fallthrough
|
|
1506
|
-
case 0: break;
|
|
1507
|
-
}
|
|
1649
|
+
NK_INTERNAL void nk_f64_to_u64_serial(nk_f64_t const *x, nk_u64_t *y) {
|
|
1650
|
+
if (*x != *x) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
|
|
1651
|
+
else
|
|
1652
|
+
*y = nk_rint_even_f64_to_u64_serial_(*x > 18446744073709551615.0 ? 18446744073709551615.0
|
|
1653
|
+
: (*x < 0 ? 0.0 : *x));
|
|
1508
1654
|
}
|
|
1509
1655
|
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
|
|
1656
|
+
NK_INTERNAL void nk_i64_to_i8_serial(nk_i64_t const *x, nk_i8_t *y) {
|
|
1657
|
+
*y = (nk_i8_t)(*x > 127ll ? 127ll : (*x < -128ll ? -128ll : *x));
|
|
1658
|
+
}
|
|
1659
|
+
|
|
1660
|
+
NK_INTERNAL void nk_i64_to_u8_serial(nk_i64_t const *x, nk_u8_t *y) {
|
|
1661
|
+
*y = (nk_u8_t)(*x > 255ll ? 255ll : (*x < 0ll ? 0ll : *x));
|
|
1662
|
+
}
|
|
1663
|
+
|
|
1664
|
+
NK_INTERNAL void nk_i64_to_i16_serial(nk_i64_t const *x, nk_i16_t *y) {
|
|
1665
|
+
*y = (nk_i16_t)(*x > 32767ll ? 32767ll : (*x < -32768ll ? -32768ll : *x));
|
|
1520
1666
|
}
|
|
1521
1667
|
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
dst->u64 = 0;
|
|
1525
|
-
nk_u16_t const *s = (nk_u16_t const *)src;
|
|
1526
|
-
switch (n) {
|
|
1527
|
-
default:
|
|
1528
|
-
case 4: dst->u16s[3] = s[3]; // fallthrough
|
|
1529
|
-
case 3: dst->u16s[2] = s[2]; // fallthrough
|
|
1530
|
-
case 2: dst->u16s[1] = s[1]; // fallthrough
|
|
1531
|
-
case 1: dst->u16s[0] = s[0]; // fallthrough
|
|
1532
|
-
case 0: break;
|
|
1533
|
-
}
|
|
1668
|
+
NK_INTERNAL void nk_i64_to_u16_serial(nk_i64_t const *x, nk_u16_t *y) {
|
|
1669
|
+
*y = (nk_u16_t)(*x > 65535ll ? 65535ll : (*x < 0ll ? 0ll : *x));
|
|
1534
1670
|
}
|
|
1535
1671
|
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
|
|
1539
|
-
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
1540
|
-
nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
|
|
1541
|
-
for (nk_size_t i = 0; i < n_bytes && i < 32; i++) dst->u8s[i] = s[i];
|
|
1672
|
+
NK_INTERNAL void nk_i64_to_i32_serial(nk_i64_t const *x, nk_i32_t *y) {
|
|
1673
|
+
*y = (nk_i32_t)(*x > 2147483647ll ? 2147483647ll : (*x < -2147483648ll ? -2147483648ll : *x));
|
|
1542
1674
|
}
|
|
1543
1675
|
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
|
-
dst->u64s[0] = 0, dst->u64s[1] = 0;
|
|
1547
|
-
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
1548
|
-
nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
|
|
1549
|
-
for (nk_size_t i = 0; i < n_bytes && i < 16; i++) dst->u8s[i] = s[i];
|
|
1676
|
+
NK_INTERNAL void nk_i64_to_u32_serial(nk_i64_t const *x, nk_u32_t *y) {
|
|
1677
|
+
*y = (nk_u32_t)(*x > 4294967295ll ? 4294967295ll : (*x < 0ll ? 0ll : *x));
|
|
1550
1678
|
}
|
|
1551
1679
|
|
|
1552
|
-
|
|
1553
|
-
NK_INTERNAL void
|
|
1554
|
-
|
|
1555
|
-
|
|
1556
|
-
|
|
1557
|
-
|
|
1680
|
+
NK_INTERNAL void nk_u64_to_i8_serial(nk_u64_t const *x, nk_i8_t *y) { *y = (nk_i8_t)(*x > 127ull ? 127ull : *x); }
|
|
1681
|
+
NK_INTERNAL void nk_u64_to_u8_serial(nk_u64_t const *x, nk_u8_t *y) { *y = (nk_u8_t)(*x > 255ull ? 255ull : *x); }
|
|
1682
|
+
NK_INTERNAL void nk_u64_to_i16_serial(nk_u64_t const *x, nk_i16_t *y) {
|
|
1683
|
+
*y = (nk_i16_t)(*x > 32767ull ? 32767ull : *x);
|
|
1684
|
+
}
|
|
1685
|
+
NK_INTERNAL void nk_u64_to_u16_serial(nk_u64_t const *x, nk_u16_t *y) {
|
|
1686
|
+
*y = (nk_u16_t)(*x > 65535ull ? 65535ull : *x);
|
|
1558
1687
|
}
|
|
1559
1688
|
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
dst->u64 = 0;
|
|
1563
|
-
nk_u8_t const *s = (nk_u8_t const *)src;
|
|
1564
|
-
nk_size_t n_bytes = nk_size_divide_round_up_(n, 2);
|
|
1565
|
-
for (nk_size_t i = 0; i < n_bytes && i < 8; i++) ((nk_u8_t *)&dst->u64)[i] = s[i];
|
|
1689
|
+
NK_INTERNAL void nk_u64_to_i32_serial(nk_u64_t const *x, nk_i32_t *y) {
|
|
1690
|
+
*y = (nk_i32_t)(*x > 2147483647ull ? 2147483647ull : *x);
|
|
1566
1691
|
}
|
|
1567
1692
|
|
|
1568
|
-
NK_INTERNAL void
|
|
1569
|
-
|
|
1570
|
-
nk_u64_t const *s = (nk_u64_t const *)src;
|
|
1571
|
-
switch (n) {
|
|
1572
|
-
default:
|
|
1573
|
-
case 2: dst->u64s[1] = s[1]; // fallthrough
|
|
1574
|
-
case 1: dst->u64s[0] = s[0]; // fallthrough
|
|
1575
|
-
case 0: break;
|
|
1576
|
-
}
|
|
1693
|
+
NK_INTERNAL void nk_u64_to_u32_serial(nk_u64_t const *x, nk_u32_t *y) {
|
|
1694
|
+
*y = (nk_u32_t)(*x > 4294967295ull ? 4294967295ull : *x);
|
|
1577
1695
|
}
|
|
1578
1696
|
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
nk_u64_t *d = (nk_u64_t *)dst;
|
|
1582
|
-
switch (n) {
|
|
1583
|
-
default:
|
|
1584
|
-
case 2: d[1] = src->u64s[1]; // fallthrough
|
|
1585
|
-
case 1: d[0] = src->u64s[0]; // fallthrough
|
|
1586
|
-
case 0: break;
|
|
1587
|
-
}
|
|
1697
|
+
NK_INTERNAL void nk_u64_to_i64_serial(nk_u64_t const *x, nk_i64_t *y) {
|
|
1698
|
+
*y = (nk_i64_t)(*x >= 9223372036854775807ull ? 9223372036854775807ll : *x);
|
|
1588
1699
|
}
|
|
1589
1700
|
|
|
1590
|
-
|
|
1591
|
-
NK_INTERNAL void
|
|
1592
|
-
|
|
1593
|
-
|
|
1594
|
-
|
|
1595
|
-
|
|
1701
|
+
NK_INTERNAL void nk_i8_to_u64_serial(nk_i8_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
|
|
1702
|
+
NK_INTERNAL void nk_i16_to_u64_serial(nk_i16_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
|
|
1703
|
+
NK_INTERNAL void nk_i32_to_u64_serial(nk_i32_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
|
|
1704
|
+
NK_INTERNAL void nk_i64_to_u64_serial(nk_i64_t const *x, nk_u64_t *y) { *y = (nk_u64_t)(*x < 0 ? 0 : *x); }
|
|
1705
|
+
|
|
1706
|
+
NK_INTERNAL void nk_i64_to_f16_serial(nk_i64_t const *x, nk_f16_t *y) {
|
|
1707
|
+
nk_f32_t f32 = (nk_f32_t)*x;
|
|
1708
|
+
nk_f32_to_f16_serial(&f32, y);
|
|
1709
|
+
}
|
|
1710
|
+
NK_INTERNAL void nk_i64_to_bf16_serial(nk_i64_t const *x, nk_bf16_t *y) {
|
|
1711
|
+
nk_f32_t f32 = (nk_f32_t)*x;
|
|
1712
|
+
nk_f32_to_bf16_serial(&f32, y);
|
|
1713
|
+
}
|
|
1714
|
+
NK_INTERNAL void nk_u64_to_f16_serial(nk_u64_t const *x, nk_f16_t *y) {
|
|
1715
|
+
nk_f32_t f32 = (nk_f32_t)*x;
|
|
1716
|
+
nk_f32_to_f16_serial(&f32, y);
|
|
1717
|
+
}
|
|
1718
|
+
NK_INTERNAL void nk_u64_to_bf16_serial(nk_u64_t const *x, nk_bf16_t *y) {
|
|
1719
|
+
nk_f32_t f32 = (nk_f32_t)*x;
|
|
1720
|
+
nk_f32_to_bf16_serial(&f32, y);
|
|
1596
1721
|
}
|
|
1597
1722
|
|
|
1598
|
-
/** @brief
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1723
|
+
/** @brief Convert a pair of i4 (4-bit signed integer, -8 to 7) nibbles into signed integers. */
|
|
1724
|
+
NK_PUBLIC void nk_i4x2_to_i8x2_serial(nk_i4x2_t const *src, nk_i8_t *dest) {
|
|
1725
|
+
nk_u8_t byte = *(nk_u8_t const *)src;
|
|
1726
|
+
nk_u8_t high_nibble = byte >> 4;
|
|
1727
|
+
nk_u8_t low_nibble = byte & 0x0F;
|
|
1728
|
+
// Sign extend: 0-7 → 0-7, 8-15 → -8 to -1
|
|
1729
|
+
dest[0] = (nk_i8_t)((high_nibble ^ 8) - 8);
|
|
1730
|
+
dest[1] = (nk_i8_t)((low_nibble ^ 8) - 8);
|
|
1604
1731
|
}
|
|
1605
1732
|
|
|
1606
|
-
/** @brief
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
|
|
1611
|
-
for (nk_size_t i = 0; i < n && i < 16; ++i) dst->u8s[i] = s[i * stride_elements];
|
|
1733
|
+
/** @brief Convert a pair of u4 (4-bit unsigned integer, 0 to 15) nibbles into unsigned integers. */
|
|
1734
|
+
NK_PUBLIC void nk_u4x2_to_u8x2_serial(nk_u4x2_t const *src, nk_u8_t *dest) {
|
|
1735
|
+
nk_u8_t byte = *(nk_u8_t const *)src;
|
|
1736
|
+
dest[0] = byte >> 4;
|
|
1737
|
+
dest[1] = byte & 0x0F;
|
|
1612
1738
|
}
|
|
1613
1739
|
|
|
1614
1740
|
/**
|
|
@@ -1619,7 +1745,7 @@ NK_INTERNAL void nk_strided_load_b8x16_serial_(void const *src, nk_size_t stride
|
|
|
1619
1745
|
* The caller fills the appropriate union member based on the target dtype,
|
|
1620
1746
|
* then passes the union address as `void const *` to kernel functions.
|
|
1621
1747
|
*/
|
|
1622
|
-
typedef union nk_scalar_buffer_t {
|
|
1748
|
+
typedef union NK_MAY_ALIAS_ nk_scalar_buffer_t {
|
|
1623
1749
|
nk_u8_t bytes[16];
|
|
1624
1750
|
nk_f64_t f64;
|
|
1625
1751
|
nk_f32_t f32;
|
|
@@ -1639,115 +1765,78 @@ typedef union nk_scalar_buffer_t {
|
|
|
1639
1765
|
nk_u8_t u8;
|
|
1640
1766
|
} nk_scalar_buffer_t;
|
|
1641
1767
|
|
|
1768
|
+
/** @brief Reads a typed scalar from @p buf and writes the widened f64c into @p result.
|
|
1769
|
+
* Real types set `.imag = 0`. Safe when @p result aliases @p buf (in-place conversion).
|
|
1770
|
+
* @return 1 on success, 0 for unsupported types (sub-byte, unknown). */
|
|
1771
|
+
NK_INTERNAL int nk_scalar_buffer_to_f64c(nk_scalar_buffer_t const *buf, nk_dtype_t dtype, nk_f64c_t *result) {
|
|
1772
|
+
// Snapshot input so `result` may alias `buf` (e.g. in-place conversion within a union).
|
|
1773
|
+
nk_scalar_buffer_t local;
|
|
1774
|
+
local.f64c = buf->f64c;
|
|
1775
|
+
result->real = 0, result->imag = 0;
|
|
1776
|
+
switch (dtype) {
|
|
1777
|
+
case nk_f64_k: result->real = local.f64; break;
|
|
1778
|
+
case nk_f32_k: result->real = (nk_f64_t)local.f32; break;
|
|
1779
|
+
case nk_f16_k:
|
|
1780
|
+
nk_f16_to_f32_serial(&local.f16, &local.f32);
|
|
1781
|
+
result->real = (nk_f64_t)local.f32;
|
|
1782
|
+
break;
|
|
1783
|
+
case nk_bf16_k:
|
|
1784
|
+
nk_bf16_to_f32_serial(&local.bf16, &local.f32);
|
|
1785
|
+
result->real = (nk_f64_t)local.f32;
|
|
1786
|
+
break;
|
|
1787
|
+
case nk_f64c_k: result->real = local.f64c.real, result->imag = local.f64c.imag; break;
|
|
1788
|
+
case nk_f32c_k: result->real = (nk_f64_t)local.f32c.real, result->imag = (nk_f64_t)local.f32c.imag; break;
|
|
1789
|
+
case nk_f16c_k:
|
|
1790
|
+
nk_f16_to_f32_serial(&local.f16c.real, &local.f32);
|
|
1791
|
+
result->real = (nk_f64_t)local.f32;
|
|
1792
|
+
nk_f16_to_f32_serial(&local.f16c.imag, &local.f32);
|
|
1793
|
+
result->imag = (nk_f64_t)local.f32;
|
|
1794
|
+
break;
|
|
1795
|
+
case nk_bf16c_k:
|
|
1796
|
+
nk_bf16_to_f32_serial(&local.bf16c.real, &local.f32);
|
|
1797
|
+
result->real = (nk_f64_t)local.f32;
|
|
1798
|
+
nk_bf16_to_f32_serial(&local.bf16c.imag, &local.f32);
|
|
1799
|
+
result->imag = (nk_f64_t)local.f32;
|
|
1800
|
+
break;
|
|
1801
|
+
case nk_i64_k: result->real = (nk_f64_t)local.i64; break;
|
|
1802
|
+
case nk_u64_k: result->real = (nk_f64_t)local.u64; break;
|
|
1803
|
+
case nk_i32_k: result->real = (nk_f64_t)local.i32; break;
|
|
1804
|
+
case nk_u32_k: result->real = (nk_f64_t)local.u32; break;
|
|
1805
|
+
case nk_i16_k: result->real = (nk_f64_t)local.i16; break;
|
|
1806
|
+
case nk_u16_k: result->real = (nk_f64_t)local.u16; break;
|
|
1807
|
+
case nk_i8_k: result->real = (nk_f64_t)local.i8; break;
|
|
1808
|
+
case nk_u8_k: result->real = (nk_f64_t)local.u8; break;
|
|
1809
|
+
case nk_e4m3_k:
|
|
1810
|
+
nk_e4m3_to_f32_serial(&local.u8, &local.f32);
|
|
1811
|
+
result->real = (nk_f64_t)local.f32;
|
|
1812
|
+
break;
|
|
1813
|
+
case nk_e5m2_k:
|
|
1814
|
+
nk_e5m2_to_f32_serial(&local.u8, &local.f32);
|
|
1815
|
+
result->real = (nk_f64_t)local.f32;
|
|
1816
|
+
break;
|
|
1817
|
+
case nk_e2m3_k:
|
|
1818
|
+
nk_e2m3_to_f32_serial(&local.u8, &local.f32);
|
|
1819
|
+
result->real = (nk_f64_t)local.f32;
|
|
1820
|
+
break;
|
|
1821
|
+
case nk_e3m2_k:
|
|
1822
|
+
nk_e3m2_to_f32_serial(&local.u8, &local.f32);
|
|
1823
|
+
result->real = (nk_f64_t)local.f32;
|
|
1824
|
+
break;
|
|
1825
|
+
default: return 0;
|
|
1826
|
+
}
|
|
1827
|
+
return 1;
|
|
1828
|
+
}
|
|
1829
|
+
|
|
1642
1830
|
/**
|
|
1643
1831
|
* @brief Converts up to 8x values from `from_ptr` buffer into 8x puned buffer objects
|
|
1644
1832
|
* into a complex 64-bit floating point representation.
|
|
1645
1833
|
*/
|
|
1646
|
-
NK_INTERNAL void
|
|
1834
|
+
NK_INTERNAL void nk_scalar_buffers_to_f64c_( //
|
|
1647
1835
|
void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
|
|
1648
1836
|
nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) {
|
|
1649
1837
|
|
|
1650
|
-
nk_f32_t temporary_f32;
|
|
1651
1838
|
nk_size_t i;
|
|
1652
1839
|
switch (from_dtype) {
|
|
1653
|
-
case nk_f64_k: {
|
|
1654
|
-
nk_f64_t const *p = (nk_f64_t const *)from_ptr;
|
|
1655
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1656
|
-
} break;
|
|
1657
|
-
case nk_f32_k: {
|
|
1658
|
-
nk_f32_t const *p = (nk_f32_t const *)from_ptr;
|
|
1659
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1660
|
-
} break;
|
|
1661
|
-
case nk_f16_k: {
|
|
1662
|
-
nk_f16_t const *p = (nk_f16_t const *)from_ptr;
|
|
1663
|
-
for (i = 0; i < from_count; ++i)
|
|
1664
|
-
nk_f16_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1665
|
-
to_buffers[i].f64c.imag = 0;
|
|
1666
|
-
} break;
|
|
1667
|
-
case nk_bf16_k: {
|
|
1668
|
-
nk_bf16_t const *p = (nk_bf16_t const *)from_ptr;
|
|
1669
|
-
for (i = 0; i < from_count; ++i)
|
|
1670
|
-
nk_bf16_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1671
|
-
to_buffers[i].f64c.imag = 0;
|
|
1672
|
-
} break;
|
|
1673
|
-
case nk_e4m3_k: {
|
|
1674
|
-
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1675
|
-
for (i = 0; i < from_count; ++i)
|
|
1676
|
-
nk_e4m3_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1677
|
-
to_buffers[i].f64c.imag = 0;
|
|
1678
|
-
} break;
|
|
1679
|
-
case nk_e5m2_k: {
|
|
1680
|
-
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1681
|
-
for (i = 0; i < from_count; ++i)
|
|
1682
|
-
nk_e5m2_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1683
|
-
to_buffers[i].f64c.imag = 0;
|
|
1684
|
-
} break;
|
|
1685
|
-
case nk_e2m3_k: {
|
|
1686
|
-
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1687
|
-
for (i = 0; i < from_count; ++i)
|
|
1688
|
-
nk_e2m3_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1689
|
-
to_buffers[i].f64c.imag = 0;
|
|
1690
|
-
} break;
|
|
1691
|
-
case nk_e3m2_k: {
|
|
1692
|
-
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1693
|
-
for (i = 0; i < from_count; ++i)
|
|
1694
|
-
nk_e3m2_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1695
|
-
to_buffers[i].f64c.imag = 0;
|
|
1696
|
-
} break;
|
|
1697
|
-
case nk_i64_k: {
|
|
1698
|
-
nk_i64_t const *p = (nk_i64_t const *)from_ptr;
|
|
1699
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = (nk_f64_t)p[i], to_buffers[i].f64c.imag = 0;
|
|
1700
|
-
} break;
|
|
1701
|
-
case nk_i32_k: {
|
|
1702
|
-
nk_i32_t const *p = (nk_i32_t const *)from_ptr;
|
|
1703
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1704
|
-
} break;
|
|
1705
|
-
case nk_i16_k: {
|
|
1706
|
-
nk_i16_t const *p = (nk_i16_t const *)from_ptr;
|
|
1707
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1708
|
-
} break;
|
|
1709
|
-
case nk_i8_k: {
|
|
1710
|
-
nk_i8_t const *p = (nk_i8_t const *)from_ptr;
|
|
1711
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1712
|
-
} break;
|
|
1713
|
-
case nk_u64_k: {
|
|
1714
|
-
nk_u64_t const *p = (nk_u64_t const *)from_ptr;
|
|
1715
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = (nk_f64_t)p[i], to_buffers[i].f64c.imag = 0;
|
|
1716
|
-
} break;
|
|
1717
|
-
case nk_u32_k: {
|
|
1718
|
-
nk_u32_t const *p = (nk_u32_t const *)from_ptr;
|
|
1719
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1720
|
-
} break;
|
|
1721
|
-
case nk_u16_k: {
|
|
1722
|
-
nk_u16_t const *p = (nk_u16_t const *)from_ptr;
|
|
1723
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1724
|
-
} break;
|
|
1725
|
-
case nk_u8_k: {
|
|
1726
|
-
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1727
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1728
|
-
} break;
|
|
1729
|
-
case nk_f64c_k: {
|
|
1730
|
-
nk_f64c_t const *p = (nk_f64c_t const *)from_ptr;
|
|
1731
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c = p[i];
|
|
1732
|
-
} break;
|
|
1733
|
-
case nk_f32c_k: {
|
|
1734
|
-
nk_f32c_t const *p = (nk_f32c_t const *)from_ptr;
|
|
1735
|
-
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i].real, to_buffers[i].f64c.imag = p[i].imag;
|
|
1736
|
-
} break;
|
|
1737
|
-
case nk_f16c_k: {
|
|
1738
|
-
nk_f16c_t const *p = (nk_f16c_t const *)from_ptr;
|
|
1739
|
-
for (i = 0; i < from_count; ++i) {
|
|
1740
|
-
nk_f16_to_f32_serial(&p[i].real, &temporary_f32), to_buffers[i].f64c.real = temporary_f32;
|
|
1741
|
-
nk_f16_to_f32_serial(&p[i].imag, &temporary_f32), to_buffers[i].f64c.imag = temporary_f32;
|
|
1742
|
-
}
|
|
1743
|
-
} break;
|
|
1744
|
-
case nk_bf16c_k: {
|
|
1745
|
-
nk_bf16c_t const *p = (nk_bf16c_t const *)from_ptr;
|
|
1746
|
-
for (i = 0; i < from_count; ++i) {
|
|
1747
|
-
nk_bf16_to_f32_serial(&p[i].real, &temporary_f32), to_buffers[i].f64c.real = temporary_f32;
|
|
1748
|
-
nk_bf16_to_f32_serial(&p[i].imag, &temporary_f32), to_buffers[i].f64c.imag = temporary_f32;
|
|
1749
|
-
}
|
|
1750
|
-
} break;
|
|
1751
1840
|
// Sub-byte: u1 - 8 bits from 1 byte, MSB-first
|
|
1752
1841
|
case nk_u1_k: {
|
|
1753
1842
|
nk_u8_t byte = *(nk_u8_t const *)from_ptr;
|
|
@@ -1755,130 +1844,117 @@ NK_INTERNAL void nk_scalar_buffers_fill_f64c_( //
|
|
|
1755
1844
|
} break;
|
|
1756
1845
|
// Sub-byte: i4 - 8 nibbles from 4 bytes, high nibble = even index, sign-extended
|
|
1757
1846
|
case nk_i4_k: {
|
|
1758
|
-
|
|
1847
|
+
nk_i4x2_t const *pairs = (nk_i4x2_t const *)from_ptr;
|
|
1848
|
+
nk_i8_t unpacked[2];
|
|
1759
1849
|
for (i = 0; i < 4; ++i) {
|
|
1760
|
-
|
|
1761
|
-
to_buffers[i * 2].f64c.real =
|
|
1762
|
-
to_buffers[i * 2 + 1].f64c.real =
|
|
1850
|
+
nk_i4x2_to_i8x2_serial(&pairs[i], unpacked);
|
|
1851
|
+
to_buffers[i * 2].f64c.real = unpacked[0], to_buffers[i * 2].f64c.imag = 0;
|
|
1852
|
+
to_buffers[i * 2 + 1].f64c.real = unpacked[1], to_buffers[i * 2 + 1].f64c.imag = 0;
|
|
1763
1853
|
}
|
|
1764
1854
|
} break;
|
|
1765
1855
|
// Sub-byte: u4 - 8 nibbles from 4 bytes, high nibble = even index
|
|
1766
1856
|
case nk_u4_k: {
|
|
1767
|
-
|
|
1857
|
+
nk_u4x2_t const *pairs = (nk_u4x2_t const *)from_ptr;
|
|
1858
|
+
nk_u8_t unpacked[2];
|
|
1768
1859
|
for (i = 0; i < 4; ++i) {
|
|
1769
|
-
|
|
1770
|
-
to_buffers[i * 2
|
|
1860
|
+
nk_u4x2_to_u8x2_serial(&pairs[i], unpacked);
|
|
1861
|
+
to_buffers[i * 2].f64c.real = unpacked[0], to_buffers[i * 2].f64c.imag = 0;
|
|
1862
|
+
to_buffers[i * 2 + 1].f64c.real = unpacked[1], to_buffers[i * 2 + 1].f64c.imag = 0;
|
|
1771
1863
|
}
|
|
1772
1864
|
} break;
|
|
1773
|
-
|
|
1774
|
-
|
|
1865
|
+
// All byte-or-larger types: stage through a separate buffer to avoid
|
|
1866
|
+
// variable-length memcpy and type-punned read on the same union —
|
|
1867
|
+
// a pattern that triggers an ICE in MSVC's ARM64 optimizer (C1001).
|
|
1868
|
+
default: {
|
|
1869
|
+
nk_size_t stride = nk_dtype_bits(from_dtype) / NK_BITS_PER_BYTE;
|
|
1870
|
+
nk_scalar_buffer_t staged;
|
|
1871
|
+
for (i = 0; i < from_count; ++i) {
|
|
1872
|
+
staged.u64 = 0;
|
|
1873
|
+
nk_copy_bytes_(&staged, (char const *)from_ptr + i * stride, stride);
|
|
1874
|
+
nk_scalar_buffer_to_f64c(&staged, from_dtype, &to_buffers[i].f64c);
|
|
1875
|
+
}
|
|
1876
|
+
} break;
|
|
1877
|
+
}
|
|
1878
|
+
}
|
|
1879
|
+
|
|
1880
|
+
/** @brief Narrows an f64c @p value into the appropriate typed member of @p buf.
|
|
1881
|
+
* Real types use only `.real`; complex types use both components.
|
|
1882
|
+
* Safe when @p value aliases @p buf (in-place conversion).
|
|
1883
|
+
* @note Integer targets (i64, i32, ...) go through f64 rounding — values beyond 2^53 may lose precision.
|
|
1884
|
+
* @return 1 on success, 0 for unsupported types (sub-byte, unknown). */
|
|
1885
|
+
NK_INTERNAL int nk_scalar_buffer_from_f64c(nk_f64c_t const *value, nk_scalar_buffer_t *buf, nk_dtype_t dtype) {
|
|
1886
|
+
// Snapshot input so `value` may point into `buf` (e.g. in-place conversion within a union).
|
|
1887
|
+
nk_f64c_t local = *value;
|
|
1888
|
+
nk_f32_t temporary_f32;
|
|
1889
|
+
switch (dtype) {
|
|
1890
|
+
case nk_f64_k: buf->f64 = local.real; break;
|
|
1891
|
+
case nk_f32_k: buf->f32 = (nk_f32_t)local.real; break;
|
|
1892
|
+
case nk_f16_k:
|
|
1893
|
+
temporary_f32 = (nk_f32_t)local.real;
|
|
1894
|
+
nk_f32_to_f16_serial(&temporary_f32, &buf->f16);
|
|
1895
|
+
break;
|
|
1896
|
+
case nk_bf16_k:
|
|
1897
|
+
temporary_f32 = (nk_f32_t)local.real;
|
|
1898
|
+
nk_f32_to_bf16_serial(&temporary_f32, &buf->bf16);
|
|
1899
|
+
break;
|
|
1900
|
+
case nk_f64c_k:
|
|
1901
|
+
buf->f64c.real = local.real;
|
|
1902
|
+
buf->f64c.imag = local.imag;
|
|
1903
|
+
break;
|
|
1904
|
+
case nk_f32c_k:
|
|
1905
|
+
buf->f32c.real = (nk_f32_t)local.real;
|
|
1906
|
+
buf->f32c.imag = (nk_f32_t)local.imag;
|
|
1907
|
+
break;
|
|
1908
|
+
case nk_f16c_k:
|
|
1909
|
+
temporary_f32 = (nk_f32_t)local.real;
|
|
1910
|
+
nk_f32_to_f16_serial(&temporary_f32, &buf->f16c.real);
|
|
1911
|
+
temporary_f32 = (nk_f32_t)local.imag;
|
|
1912
|
+
nk_f32_to_f16_serial(&temporary_f32, &buf->f16c.imag);
|
|
1913
|
+
break;
|
|
1914
|
+
case nk_bf16c_k:
|
|
1915
|
+
temporary_f32 = (nk_f32_t)local.real;
|
|
1916
|
+
nk_f32_to_bf16_serial(&temporary_f32, &buf->bf16c.real);
|
|
1917
|
+
temporary_f32 = (nk_f32_t)local.imag;
|
|
1918
|
+
nk_f32_to_bf16_serial(&temporary_f32, &buf->bf16c.imag);
|
|
1775
1919
|
break;
|
|
1920
|
+
case nk_i64_k: nk_f64_to_i64_serial(&local.real, &buf->i64); break;
|
|
1921
|
+
case nk_u64_k: nk_f64_to_u64_serial(&local.real, &buf->u64); break;
|
|
1922
|
+
case nk_i32_k: nk_f64_to_i32_serial(&local.real, &buf->i32); break;
|
|
1923
|
+
case nk_u32_k: nk_f64_to_u32_serial(&local.real, &buf->u32); break;
|
|
1924
|
+
case nk_i16_k: nk_f64_to_i16_serial(&local.real, &buf->i16); break;
|
|
1925
|
+
case nk_u16_k: nk_f64_to_u16_serial(&local.real, &buf->u16); break;
|
|
1926
|
+
case nk_i8_k: nk_f64_to_i8_serial(&local.real, &buf->i8); break;
|
|
1927
|
+
case nk_u8_k: nk_f64_to_u8_serial(&local.real, &buf->u8); break;
|
|
1928
|
+
case nk_e4m3_k:
|
|
1929
|
+
temporary_f32 = (nk_f32_t)local.real;
|
|
1930
|
+
nk_f32_to_e4m3_serial(&temporary_f32, &buf->u8);
|
|
1931
|
+
break;
|
|
1932
|
+
case nk_e5m2_k:
|
|
1933
|
+
temporary_f32 = (nk_f32_t)local.real;
|
|
1934
|
+
nk_f32_to_e5m2_serial(&temporary_f32, &buf->u8);
|
|
1935
|
+
break;
|
|
1936
|
+
case nk_e2m3_k:
|
|
1937
|
+
temporary_f32 = (nk_f32_t)local.real;
|
|
1938
|
+
nk_f32_to_e2m3_serial(&temporary_f32, &buf->u8);
|
|
1939
|
+
break;
|
|
1940
|
+
case nk_e3m2_k:
|
|
1941
|
+
temporary_f32 = (nk_f32_t)local.real;
|
|
1942
|
+
nk_f32_to_e3m2_serial(&temporary_f32, &buf->u8);
|
|
1943
|
+
break;
|
|
1944
|
+
default: return 0;
|
|
1776
1945
|
}
|
|
1946
|
+
return 1;
|
|
1777
1947
|
}
|
|
1778
1948
|
|
|
1779
1949
|
/**
|
|
1780
1950
|
* @brief Converts up to 8x values from `from_buffers` buffer into 8x typed scalars.
|
|
1781
1951
|
*/
|
|
1782
|
-
NK_INTERNAL void
|
|
1952
|
+
NK_INTERNAL void nk_scalar_buffers_from_f64c_( //
|
|
1783
1953
|
nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
|
|
1784
1954
|
void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) {
|
|
1785
1955
|
|
|
1786
|
-
nk_f32_t temporary_f32;
|
|
1787
1956
|
nk_size_t i;
|
|
1788
1957
|
switch (to_dtype) {
|
|
1789
|
-
case nk_f64_k: {
|
|
1790
|
-
nk_f64_t *p = (nk_f64_t *)to_ptr;
|
|
1791
|
-
for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].f64c.real;
|
|
1792
|
-
} break;
|
|
1793
|
-
case nk_f32_k: {
|
|
1794
|
-
nk_f32_t *p = (nk_f32_t *)to_ptr;
|
|
1795
|
-
for (i = 0; i < to_count; ++i) p[i] = (nk_f32_t)from_buffers[i].f64c.real;
|
|
1796
|
-
} break;
|
|
1797
|
-
case nk_f16_k: {
|
|
1798
|
-
nk_f16_t *p = (nk_f16_t *)to_ptr;
|
|
1799
|
-
for (i = 0; i < to_count; ++i)
|
|
1800
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_f16_serial(&temporary_f32, &p[i]);
|
|
1801
|
-
} break;
|
|
1802
|
-
case nk_bf16_k: {
|
|
1803
|
-
nk_bf16_t *p = (nk_bf16_t *)to_ptr;
|
|
1804
|
-
for (i = 0; i < to_count; ++i)
|
|
1805
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_bf16_serial(&temporary_f32, &p[i]);
|
|
1806
|
-
} break;
|
|
1807
|
-
case nk_e4m3_k: {
|
|
1808
|
-
nk_u8_t *p = (nk_u8_t *)to_ptr;
|
|
1809
|
-
for (i = 0; i < to_count; ++i)
|
|
1810
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e4m3_serial(&temporary_f32, &p[i]);
|
|
1811
|
-
} break;
|
|
1812
|
-
case nk_e5m2_k: {
|
|
1813
|
-
nk_u8_t *p = (nk_u8_t *)to_ptr;
|
|
1814
|
-
for (i = 0; i < to_count; ++i)
|
|
1815
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e5m2_serial(&temporary_f32, &p[i]);
|
|
1816
|
-
} break;
|
|
1817
|
-
case nk_e2m3_k: {
|
|
1818
|
-
nk_u8_t *p = (nk_u8_t *)to_ptr;
|
|
1819
|
-
for (i = 0; i < to_count; ++i)
|
|
1820
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e2m3_serial(&temporary_f32, &p[i]);
|
|
1821
|
-
} break;
|
|
1822
|
-
case nk_e3m2_k: {
|
|
1823
|
-
nk_u8_t *p = (nk_u8_t *)to_ptr;
|
|
1824
|
-
for (i = 0; i < to_count; ++i)
|
|
1825
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e3m2_serial(&temporary_f32, &p[i]);
|
|
1826
|
-
} break;
|
|
1827
|
-
case nk_i64_k: {
|
|
1828
|
-
nk_i64_t *p = (nk_i64_t *)to_ptr;
|
|
1829
|
-
for (i = 0; i < to_count; ++i) nk_f64_to_i64_serial(&from_buffers[i].f64c.real, &p[i]);
|
|
1830
|
-
} break;
|
|
1831
|
-
case nk_i32_k: {
|
|
1832
|
-
nk_i32_t *p = (nk_i32_t *)to_ptr;
|
|
1833
|
-
for (i = 0; i < to_count; ++i) nk_f64_to_i32_serial(&from_buffers[i].f64c.real, &p[i]);
|
|
1834
|
-
} break;
|
|
1835
|
-
case nk_i16_k: {
|
|
1836
|
-
nk_i16_t *p = (nk_i16_t *)to_ptr;
|
|
1837
|
-
for (i = 0; i < to_count; ++i) nk_f64_to_i16_serial(&from_buffers[i].f64c.real, &p[i]);
|
|
1838
|
-
} break;
|
|
1839
|
-
case nk_i8_k: {
|
|
1840
|
-
nk_i8_t *p = (nk_i8_t *)to_ptr;
|
|
1841
|
-
for (i = 0; i < to_count; ++i) nk_f64_to_i8_serial(&from_buffers[i].f64c.real, &p[i]);
|
|
1842
|
-
} break;
|
|
1843
|
-
case nk_u64_k: {
|
|
1844
|
-
nk_u64_t *p = (nk_u64_t *)to_ptr;
|
|
1845
|
-
for (i = 0; i < to_count; ++i) nk_f64_to_u64_serial(&from_buffers[i].f64c.real, &p[i]);
|
|
1846
|
-
} break;
|
|
1847
|
-
case nk_u32_k: {
|
|
1848
|
-
nk_u32_t *p = (nk_u32_t *)to_ptr;
|
|
1849
|
-
for (i = 0; i < to_count; ++i) nk_f64_to_u32_serial(&from_buffers[i].f64c.real, &p[i]);
|
|
1850
|
-
} break;
|
|
1851
|
-
case nk_u16_k: {
|
|
1852
|
-
nk_u16_t *p = (nk_u16_t *)to_ptr;
|
|
1853
|
-
for (i = 0; i < to_count; ++i) nk_f64_to_u16_serial(&from_buffers[i].f64c.real, &p[i]);
|
|
1854
|
-
} break;
|
|
1855
|
-
case nk_u8_k: {
|
|
1856
|
-
nk_u8_t *p = (nk_u8_t *)to_ptr;
|
|
1857
|
-
for (i = 0; i < to_count; ++i) nk_f64_to_u8_serial(&from_buffers[i].f64c.real, &p[i]);
|
|
1858
|
-
} break;
|
|
1859
|
-
case nk_f64c_k: {
|
|
1860
|
-
nk_f64c_t *p = (nk_f64c_t *)to_ptr;
|
|
1861
|
-
for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].f64c;
|
|
1862
|
-
} break;
|
|
1863
|
-
case nk_f32c_k: {
|
|
1864
|
-
nk_f32c_t *p = (nk_f32c_t *)to_ptr;
|
|
1865
|
-
for (i = 0; i < to_count; ++i)
|
|
1866
|
-
p[i].real = (nk_f32_t)from_buffers[i].f64c.real, p[i].imag = (nk_f32_t)from_buffers[i].f64c.imag;
|
|
1867
|
-
} break;
|
|
1868
|
-
case nk_f16c_k: {
|
|
1869
|
-
nk_f16c_t *p = (nk_f16c_t *)to_ptr;
|
|
1870
|
-
for (i = 0; i < to_count; ++i) {
|
|
1871
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_f16_serial(&temporary_f32, &p[i].real);
|
|
1872
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.imag, nk_f32_to_f16_serial(&temporary_f32, &p[i].imag);
|
|
1873
|
-
}
|
|
1874
|
-
} break;
|
|
1875
|
-
case nk_bf16c_k: {
|
|
1876
|
-
nk_bf16c_t *p = (nk_bf16c_t *)to_ptr;
|
|
1877
|
-
for (i = 0; i < to_count; ++i) {
|
|
1878
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_bf16_serial(&temporary_f32, &p[i].real);
|
|
1879
|
-
temporary_f32 = (nk_f32_t)from_buffers[i].f64c.imag, nk_f32_to_bf16_serial(&temporary_f32, &p[i].imag);
|
|
1880
|
-
}
|
|
1881
|
-
} break;
|
|
1882
1958
|
// Sub-byte: u1 - 8 bits to 1 byte, MSB-first, non-zero → 1
|
|
1883
1959
|
case nk_u1_k: {
|
|
1884
1960
|
nk_u8_t *p = (nk_u8_t *)to_ptr;
|
|
@@ -1890,32 +1966,38 @@ NK_INTERNAL void nk_scalar_buffers_export_f64c_( //
|
|
|
1890
1966
|
case nk_i4_k: {
|
|
1891
1967
|
nk_u8_t *p = (nk_u8_t *)to_ptr;
|
|
1892
1968
|
for (i = 0; i < 4; ++i) {
|
|
1893
|
-
|
|
1894
|
-
|
|
1895
|
-
|
|
1896
|
-
|
|
1897
|
-
p[i] = (nk_u8_t)(((hi & 0xF) << 4) | (lo & 0xF));
|
|
1969
|
+
nk_f64_t high = from_buffers[i * 2].f64c.real, low = from_buffers[i * 2 + 1].f64c.real;
|
|
1970
|
+
high = high > 7 ? 7 : (high < -8 ? -8 : high);
|
|
1971
|
+
low = low > 7 ? 7 : (low < -8 ? -8 : low);
|
|
1972
|
+
p[i] = (nk_u8_t)((((nk_i8_t)high & 0x0F) << 4) | ((nk_i8_t)low & 0x0F));
|
|
1898
1973
|
}
|
|
1899
1974
|
} break;
|
|
1900
1975
|
// Sub-byte: u4 - 8 nibbles to 4 bytes, high nibble = even index
|
|
1901
1976
|
case nk_u4_k: {
|
|
1902
1977
|
nk_u8_t *p = (nk_u8_t *)to_ptr;
|
|
1903
1978
|
for (i = 0; i < 4; ++i) {
|
|
1904
|
-
|
|
1905
|
-
|
|
1906
|
-
|
|
1907
|
-
|
|
1908
|
-
|
|
1979
|
+
nk_f64_t high = from_buffers[i * 2].f64c.real, low = from_buffers[i * 2 + 1].f64c.real;
|
|
1980
|
+
high = high > 15 ? 15 : (high < 0 ? 0 : high);
|
|
1981
|
+
low = low > 15 ? 15 : (low < 0 ? 0 : low);
|
|
1982
|
+
p[i] = (nk_u8_t)(((nk_u8_t)high << 4) | (nk_u8_t)low);
|
|
1983
|
+
}
|
|
1984
|
+
} break;
|
|
1985
|
+
// All byte-or-larger types: convert, then store relevant bytes
|
|
1986
|
+
default: {
|
|
1987
|
+
nk_size_t stride = nk_dtype_bits(to_dtype) / NK_BITS_PER_BYTE;
|
|
1988
|
+
nk_scalar_buffer_t tmp;
|
|
1989
|
+
for (i = 0; i < to_count; ++i) {
|
|
1990
|
+
nk_scalar_buffer_from_f64c(&from_buffers[i].f64c, &tmp, to_dtype);
|
|
1991
|
+
nk_copy_bytes_((char *)to_ptr + i * stride, &tmp, stride);
|
|
1909
1992
|
}
|
|
1910
1993
|
} break;
|
|
1911
|
-
default: break;
|
|
1912
1994
|
}
|
|
1913
1995
|
}
|
|
1914
1996
|
|
|
1915
1997
|
/**
|
|
1916
1998
|
* @brief Load 8 values from typed buffer into `buf[i].i64` (lossless widening for signed integers).
|
|
1917
1999
|
*/
|
|
1918
|
-
NK_INTERNAL void
|
|
2000
|
+
NK_INTERNAL void nk_scalar_buffers_to_i64_( //
|
|
1919
2001
|
void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
|
|
1920
2002
|
nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) { //
|
|
1921
2003
|
nk_size_t i;
|
|
@@ -1938,11 +2020,12 @@ NK_INTERNAL void nk_scalar_buffers_fill_i64_( //
|
|
|
1938
2020
|
} break;
|
|
1939
2021
|
// Sub-byte: i4 - 4 bytes to 8 nibbles, sign-extend each nibble
|
|
1940
2022
|
case nk_i4_k: {
|
|
1941
|
-
|
|
2023
|
+
nk_i4x2_t const *pairs = (nk_i4x2_t const *)from_ptr;
|
|
1942
2024
|
for (i = 0; i < 4; ++i) {
|
|
1943
|
-
nk_i8_t
|
|
1944
|
-
|
|
1945
|
-
to_buffers[i * 2
|
|
2025
|
+
nk_i8_t unpacked[2];
|
|
2026
|
+
nk_i4x2_to_i8x2_serial(&pairs[i], unpacked);
|
|
2027
|
+
to_buffers[i * 2].i64 = unpacked[0];
|
|
2028
|
+
to_buffers[i * 2 + 1].i64 = unpacked[1];
|
|
1946
2029
|
}
|
|
1947
2030
|
} break;
|
|
1948
2031
|
case nk_u64_k: {
|
|
@@ -1974,8 +2057,9 @@ NK_INTERNAL void nk_scalar_buffers_fill_i64_( //
|
|
|
1974
2057
|
|
|
1975
2058
|
/**
|
|
1976
2059
|
* @brief Export 8 `buf[i].i64` values to typed buffer with saturation on downcast.
|
|
2060
|
+
* @note Only handles integer and sub-byte targets. Float/complex targets are silently skipped.
|
|
1977
2061
|
*/
|
|
1978
|
-
NK_INTERNAL void
|
|
2062
|
+
NK_INTERNAL void nk_scalar_buffers_from_i64_( //
|
|
1979
2063
|
nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
|
|
1980
2064
|
void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) { //
|
|
1981
2065
|
nk_size_t i;
|
|
@@ -2015,12 +2099,12 @@ NK_INTERNAL void nk_scalar_buffers_export_i64_( //
|
|
|
2015
2099
|
} break;
|
|
2016
2100
|
// Sub-byte: i4 - 8 nibbles to 4 bytes, clamp [-8,7]
|
|
2017
2101
|
case nk_i4_k: {
|
|
2018
|
-
|
|
2102
|
+
nk_i4x2_t *p = (nk_i4x2_t *)to_ptr;
|
|
2019
2103
|
for (i = 0; i < 4; ++i) {
|
|
2020
|
-
nk_i64_t
|
|
2021
|
-
|
|
2022
|
-
|
|
2023
|
-
p[i] = (nk_u8_t)(((
|
|
2104
|
+
nk_i64_t high = from_buffers[i * 2].i64, low = from_buffers[i * 2 + 1].i64;
|
|
2105
|
+
high = high > 7 ? 7 : (high < -8 ? -8 : high);
|
|
2106
|
+
low = low > 7 ? 7 : (low < -8 ? -8 : low);
|
|
2107
|
+
p[i] = (nk_u8_t)(((high & 0xF) << 4) | (low & 0xF));
|
|
2024
2108
|
}
|
|
2025
2109
|
} break;
|
|
2026
2110
|
default: break;
|
|
@@ -2030,7 +2114,7 @@ NK_INTERNAL void nk_scalar_buffers_export_i64_( //
|
|
|
2030
2114
|
/**
|
|
2031
2115
|
* @brief Load 8 values from typed buffer into `buf[i].u64` (lossless widening for unsigned integers).
|
|
2032
2116
|
*/
|
|
2033
|
-
NK_INTERNAL void
|
|
2117
|
+
NK_INTERNAL void nk_scalar_buffers_to_u64_( //
|
|
2034
2118
|
void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
|
|
2035
2119
|
nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) { //
|
|
2036
2120
|
nk_size_t i;
|
|
@@ -2053,10 +2137,12 @@ NK_INTERNAL void nk_scalar_buffers_fill_u64_( //
|
|
|
2053
2137
|
} break;
|
|
2054
2138
|
// Sub-byte: u4 - 4 bytes to 8 nibbles, zero-extend
|
|
2055
2139
|
case nk_u4_k: {
|
|
2056
|
-
|
|
2140
|
+
nk_u4x2_t const *pairs = (nk_u4x2_t const *)from_ptr;
|
|
2057
2141
|
for (i = 0; i < 4; ++i) {
|
|
2058
|
-
|
|
2059
|
-
|
|
2142
|
+
nk_u8_t unpacked[2];
|
|
2143
|
+
nk_u4x2_to_u8x2_serial(&pairs[i], unpacked);
|
|
2144
|
+
to_buffers[i * 2].u64 = unpacked[0];
|
|
2145
|
+
to_buffers[i * 2 + 1].u64 = unpacked[1];
|
|
2060
2146
|
}
|
|
2061
2147
|
} break;
|
|
2062
2148
|
// Sub-byte: u1 - 1 byte to 8 bits, MSB-first
|
|
@@ -2070,8 +2156,9 @@ NK_INTERNAL void nk_scalar_buffers_fill_u64_( //
|
|
|
2070
2156
|
|
|
2071
2157
|
/**
|
|
2072
2158
|
* @brief Export 8 `buf[i].u64` values to typed buffer with saturation on downcast.
|
|
2159
|
+
* @note Only handles integer and sub-byte targets. Float/complex targets are silently skipped.
|
|
2073
2160
|
*/
|
|
2074
|
-
NK_INTERNAL void
|
|
2161
|
+
NK_INTERNAL void nk_scalar_buffers_from_u64_( //
|
|
2075
2162
|
nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
|
|
2076
2163
|
void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) { //
|
|
2077
2164
|
nk_size_t i;
|
|
@@ -2111,12 +2198,12 @@ NK_INTERNAL void nk_scalar_buffers_export_u64_( //
|
|
|
2111
2198
|
} break;
|
|
2112
2199
|
// Sub-byte: u4 - 8 nibbles to 4 bytes, clamp [0,15]
|
|
2113
2200
|
case nk_u4_k: {
|
|
2114
|
-
|
|
2201
|
+
nk_u4x2_t *p = (nk_u4x2_t *)to_ptr;
|
|
2115
2202
|
for (i = 0; i < 4; ++i) {
|
|
2116
|
-
nk_u64_t
|
|
2117
|
-
|
|
2118
|
-
|
|
2119
|
-
p[i] = (nk_u8_t)((
|
|
2203
|
+
nk_u64_t high = from_buffers[i * 2].u64, low = from_buffers[i * 2 + 1].u64;
|
|
2204
|
+
high = high > 15 ? 15 : high;
|
|
2205
|
+
low = low > 15 ? 15 : low;
|
|
2206
|
+
p[i] = (nk_u8_t)((high << 4) | low);
|
|
2120
2207
|
}
|
|
2121
2208
|
} break;
|
|
2122
2209
|
// Sub-byte: u1 - 8 bits to 1 byte, MSB-first, non-zero becomes 1
|
|
@@ -2130,9 +2217,24 @@ NK_INTERNAL void nk_scalar_buffers_export_u64_( //
|
|
|
2130
2217
|
}
|
|
2131
2218
|
}
|
|
2132
2219
|
|
|
2133
|
-
|
|
2220
|
+
/** @brief Widens a typed scalar from @p buf into @p result as f64 (discards imaginary part).
|
|
2221
|
+
* Safe when @p result aliases @p buf (in-place conversion). */
|
|
2222
|
+
NK_INTERNAL int nk_scalar_buffer_to_f64(nk_scalar_buffer_t const *buf, nk_dtype_t dtype, nk_f64_t *result) {
|
|
2223
|
+
nk_f64c_t temporary_f64c;
|
|
2224
|
+
int ok = nk_scalar_buffer_to_f64c(buf, dtype, &temporary_f64c);
|
|
2225
|
+
*result = temporary_f64c.real;
|
|
2226
|
+
return ok;
|
|
2227
|
+
}
|
|
2228
|
+
|
|
2229
|
+
/** @brief Narrows an f64 @p value into the appropriate typed member of @p buf.
|
|
2230
|
+
* Safe when @p value aliases @p buf (in-place: `buf->f64 = x; from_f64(&buf->f64, buf, dtype)`).
|
|
2231
|
+
* @note Integer targets go through f64 rounding — values beyond 2^53 may lose precision. */
|
|
2232
|
+
NK_INTERNAL int nk_scalar_buffer_from_f64(nk_f64_t const *value, nk_scalar_buffer_t *buf, nk_dtype_t dtype) {
|
|
2233
|
+
nk_f64c_t temporary_f64c = {*value, 0};
|
|
2234
|
+
return nk_scalar_buffer_from_f64c(&temporary_f64c, buf, dtype);
|
|
2235
|
+
}
|
|
2134
2236
|
|
|
2135
|
-
#pragma region
|
|
2237
|
+
#pragma region Public API
|
|
2136
2238
|
|
|
2137
2239
|
NK_PUBLIC void nk_cast_serial(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
|
|
2138
2240
|
if (from_type == to_type) {
|
|
@@ -2162,12 +2264,12 @@ NK_PUBLIC void nk_cast_serial(void const *from, nk_dtype_t from_type, nk_size_t
|
|
|
2162
2264
|
// Both unsigned: u64 hub
|
|
2163
2265
|
if (from_family == nk_dtype_family_uint_k && to_family == nk_dtype_family_uint_k) {
|
|
2164
2266
|
for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
|
|
2165
|
-
|
|
2166
|
-
|
|
2267
|
+
nk_scalar_buffers_to_u64_(src, from_type, NK_BITS_PER_BYTE, bufs);
|
|
2268
|
+
nk_scalar_buffers_from_u64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
|
|
2167
2269
|
}
|
|
2168
2270
|
if (tail) {
|
|
2169
|
-
|
|
2170
|
-
|
|
2271
|
+
nk_scalar_buffers_to_u64_(src, from_type, tail, bufs);
|
|
2272
|
+
nk_scalar_buffers_from_u64_(bufs, dst, to_type, tail);
|
|
2171
2273
|
}
|
|
2172
2274
|
return;
|
|
2173
2275
|
}
|
|
@@ -2176,24 +2278,24 @@ NK_PUBLIC void nk_cast_serial(void const *from, nk_dtype_t from_type, nk_size_t
|
|
|
2176
2278
|
if ((from_family == nk_dtype_family_int_k || from_family == nk_dtype_family_uint_k) &&
|
|
2177
2279
|
(to_family == nk_dtype_family_int_k || to_family == nk_dtype_family_uint_k)) {
|
|
2178
2280
|
for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
|
|
2179
|
-
|
|
2180
|
-
|
|
2281
|
+
nk_scalar_buffers_to_i64_(src, from_type, NK_BITS_PER_BYTE, bufs);
|
|
2282
|
+
nk_scalar_buffers_from_i64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
|
|
2181
2283
|
}
|
|
2182
2284
|
if (tail) {
|
|
2183
|
-
|
|
2184
|
-
|
|
2285
|
+
nk_scalar_buffers_to_i64_(src, from_type, tail, bufs);
|
|
2286
|
+
nk_scalar_buffers_from_i64_(bufs, dst, to_type, tail);
|
|
2185
2287
|
}
|
|
2186
2288
|
return;
|
|
2187
2289
|
}
|
|
2188
2290
|
|
|
2189
2291
|
// Everything else: f64c hub (floats, complex, cross-category)
|
|
2190
2292
|
for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
|
|
2191
|
-
|
|
2192
|
-
|
|
2293
|
+
nk_scalar_buffers_to_f64c_(src, from_type, NK_BITS_PER_BYTE, bufs);
|
|
2294
|
+
nk_scalar_buffers_from_f64c_(bufs, dst, to_type, NK_BITS_PER_BYTE);
|
|
2193
2295
|
}
|
|
2194
2296
|
if (tail) {
|
|
2195
|
-
|
|
2196
|
-
|
|
2297
|
+
nk_scalar_buffers_to_f64c_(src, from_type, tail, bufs);
|
|
2298
|
+
nk_scalar_buffers_from_f64c_(bufs, dst, to_type, tail);
|
|
2197
2299
|
}
|
|
2198
2300
|
}
|
|
2199
2301
|
|
|
@@ -2225,35 +2327,7 @@ NK_PUBLIC void nk_e3m2_to_bf16(nk_e3m2_t const *src, nk_bf16_t *dest) {
|
|
|
2225
2327
|
nk_f32_to_bf16_serial(&temp, dest);
|
|
2226
2328
|
}
|
|
2227
2329
|
|
|
2228
|
-
|
|
2229
|
-
* @brief Convert i4 (4-bit signed integer, -8 to 7) to i8.
|
|
2230
|
-
*
|
|
2231
|
-
* Nibbles are packed: low nibble in bits [0:3], high nibble in bits [4:7].
|
|
2232
|
-
* Sign extension: XOR with 8 then subtract 8 converts unsigned nibble to signed.
|
|
2233
|
-
*/
|
|
2234
|
-
NK_PUBLIC void nk_i4_to_i8_serial_(nk_i4x2_t const *src, nk_i8_t *dest, nk_size_t count) {
|
|
2235
|
-
nk_u8_t const *bytes = (nk_u8_t const *)src;
|
|
2236
|
-
for (nk_size_t i = 0; i < count; ++i) {
|
|
2237
|
-
nk_u8_t byte = bytes[i / 2];
|
|
2238
|
-
nk_u8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
|
|
2239
|
-
dest[i] = (nk_i8_t)((nibble ^ 8) - 8); // Sign extend: 0-7 → 0-7, 8-15 → -8 to -1
|
|
2240
|
-
}
|
|
2241
|
-
}
|
|
2242
|
-
|
|
2243
|
-
/**
|
|
2244
|
-
* @brief Convert u4 (4-bit unsigned integer, 0 to 15) to u8.
|
|
2245
|
-
*
|
|
2246
|
-
* Nibbles are packed: low nibble in bits [0:3], high nibble in bits [4:7].
|
|
2247
|
-
*/
|
|
2248
|
-
NK_PUBLIC void nk_u4_to_u8_serial_(nk_u4x2_t const *src, nk_u8_t *dest, nk_size_t count) {
|
|
2249
|
-
nk_u8_t const *bytes = (nk_u8_t const *)src;
|
|
2250
|
-
for (nk_size_t i = 0; i < count; ++i) {
|
|
2251
|
-
nk_u8_t byte = bytes[i / 2];
|
|
2252
|
-
dest[i] = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
|
|
2253
|
-
}
|
|
2254
|
-
}
|
|
2255
|
-
|
|
2256
|
-
#pragma endregion - Public API
|
|
2330
|
+
#pragma endregion Public API
|
|
2257
2331
|
|
|
2258
2332
|
#if defined(__cplusplus)
|
|
2259
2333
|
} // extern "C"
|