numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,2262 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SWAR-accelerated Type Conversions for SIMD-free CPUs.
|
|
3
|
+
* @file include/numkong/cast/serial.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date January 2, 2026
|
|
6
|
+
*/
|
|
7
|
+
#ifndef NK_CAST_SERIAL_H
|
|
8
|
+
#define NK_CAST_SERIAL_H
|
|
9
|
+
|
|
10
|
+
#include "numkong/types.h"
|
|
11
|
+
|
|
12
|
+
#if defined(__cplusplus)
|
|
13
|
+
extern "C" {
|
|
14
|
+
#endif
|
|
15
|
+
|
|
16
|
+
#pragma region - Type Punned Loads and Stores
|
|
17
|
+
|
|
18
|
+
/** @brief Type-agnostic 32-bit full load (scalar). */
|
|
19
|
+
NK_INTERNAL void nk_load_b32_serial_(void const *src, nk_b32_vec_t *dst) { dst->u32 = *(nk_u32_t const *)src; }
|
|
20
|
+
|
|
21
|
+
/** @brief Type-agnostic 32-bit full store (scalar). */
|
|
22
|
+
NK_INTERNAL void nk_store_b32_serial_(nk_b32_vec_t const *src, void *dst) { *(nk_u32_t *)dst = src->u32; }
|
|
23
|
+
|
|
24
|
+
/** @brief Type-agnostic 128-bit store (serial, word-by-word). */
|
|
25
|
+
NK_INTERNAL void nk_store_b128_serial_(nk_b128_vec_t const *src, void *dst) {
|
|
26
|
+
nk_u64_t *d = (nk_u64_t *)dst;
|
|
27
|
+
d[0] = src->u64s[0];
|
|
28
|
+
d[1] = src->u64s[1];
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
/** @brief Type-agnostic 256-bit store (serial, word-by-word). */
|
|
32
|
+
NK_INTERNAL void nk_store_b256_serial_(nk_b256_vec_t const *src, void *dst) {
|
|
33
|
+
nk_u64_t *d = (nk_u64_t *)dst;
|
|
34
|
+
d[0] = src->u64s[0];
|
|
35
|
+
d[1] = src->u64s[1];
|
|
36
|
+
d[2] = src->u64s[2];
|
|
37
|
+
d[3] = src->u64s[3];
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
#pragma endregion - Type Punned Loads and Stores
|
|
41
|
+
|
|
42
|
+
/**
|
|
43
|
+
* @brief Expands an `f16` (IEEE-754 16-bit) to a `float`.
|
|
44
|
+
*
|
|
45
|
+
* Handles all IEEE-754 edge cases:
|
|
46
|
+
*
|
|
47
|
+
* Input F16 Hex F32 Hex Description
|
|
48
|
+
* +0 0x0000 0x00000000 Positive zero
|
|
49
|
+
* -0 0x8000 0x80000000 Negative zero
|
|
50
|
+
* +inf 0x7C00 0x7F800000 Positive infinity
|
|
51
|
+
* -inf 0xFC00 0xFF800000 Negative infinity
|
|
52
|
+
* NaN 0x7E00 0x7FC00000 Quiet NaN (payload preserved)
|
|
53
|
+
* Min normal 0x0400 0x38800000 2⁻¹⁴
|
|
54
|
+
* Max normal 0x7BFF 0x477FE000 65504
|
|
55
|
+
* Min denorm 0x0001 0x33800000 2⁻²⁴
|
|
56
|
+
* Max denorm 0x03FF 0x387FC000 2⁻¹⁴ - 2⁻²⁴
|
|
57
|
+
*
|
|
58
|
+
* https://stackoverflow.com/a/60047308
|
|
59
|
+
* https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
|
|
60
|
+
* https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
|
|
61
|
+
*/
|
|
62
|
+
NK_PUBLIC void nk_f16_to_f32_serial(nk_f16_t const *src, nk_f32_t *dest) {
|
|
63
|
+
#if NK_NATIVE_F16
|
|
64
|
+
*dest = (nk_f32_t)(*src);
|
|
65
|
+
#else
|
|
66
|
+
unsigned short x;
|
|
67
|
+
nk_copy_bytes_(&x, src, 2);
|
|
68
|
+
|
|
69
|
+
unsigned int sign = (x >> 15) & 1;
|
|
70
|
+
unsigned int exponent = (x >> 10) & 0x1F;
|
|
71
|
+
unsigned int mantissa = x & 0x03FF;
|
|
72
|
+
|
|
73
|
+
nk_fui32_t conv;
|
|
74
|
+
|
|
75
|
+
if (exponent == 0) {
|
|
76
|
+
if (mantissa == 0) {
|
|
77
|
+
// Zero (preserve sign)
|
|
78
|
+
conv.u = sign << 31;
|
|
79
|
+
}
|
|
80
|
+
else {
|
|
81
|
+
// Denormal: value = mantissa × 2⁻²⁴
|
|
82
|
+
// Use FPU normalization, then subtract 24 from exponent
|
|
83
|
+
nk_fui32_t temp;
|
|
84
|
+
temp.f = (float)mantissa;
|
|
85
|
+
conv.u = (sign << 31) | (temp.u - 0x0C000000);
|
|
86
|
+
}
|
|
87
|
+
}
|
|
88
|
+
else if (exponent == 31) {
|
|
89
|
+
// Infinity (mantissa=0) or NaN (mantissa!=0)
|
|
90
|
+
conv.u = (sign << 31) | 0x7F800000 | (mantissa << 13);
|
|
91
|
+
}
|
|
92
|
+
else {
|
|
93
|
+
// Normal: rebias exponent (127-15=112), shift mantissa
|
|
94
|
+
conv.u = (sign << 31) | ((exponent + 112) << 23) | (mantissa << 13);
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
*dest = conv.f;
|
|
98
|
+
#endif
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
/**
|
|
102
|
+
* @brief Compresses a `float` to an `f16` (IEEE-754 16-bit).
|
|
103
|
+
*
|
|
104
|
+
* Handles all IEEE-754 edge cases with round-to-nearest:
|
|
105
|
+
*
|
|
106
|
+
* Input F32 Hex F16 Hex Description
|
|
107
|
+
* +0 0x00000000 0x0000 Positive zero
|
|
108
|
+
* -0 0x80000000 0x8000 Negative zero
|
|
109
|
+
* +inf 0x7F800000 0x7C00 Positive infinity
|
|
110
|
+
* -inf 0xFF800000 0xFC00 Negative infinity
|
|
111
|
+
* NaN 0x7FC00000 0x7E00 Quiet NaN (payload truncated)
|
|
112
|
+
* 1.0 0x3F800000 0x3C00 Normal number
|
|
113
|
+
* 65504 0x477FE000 0x7BFF Max f16 normal
|
|
114
|
+
* 65520+ >0x477FE000 0x7C00 Overflow → infinity
|
|
115
|
+
* 2⁻¹⁴ 0x38800000 0x0400 Min f16 normal
|
|
116
|
+
* 2⁻²⁴ 0x33800000 0x0001 Min f16 denormal
|
|
117
|
+
* <2⁻²⁵ <0x33000000 0x0000 Underflow → zero
|
|
118
|
+
*
|
|
119
|
+
* https://stackoverflow.com/a/60047308
|
|
120
|
+
* https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233
|
|
121
|
+
* https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834
|
|
122
|
+
*/
|
|
123
|
+
NK_PUBLIC void nk_f32_to_f16_serial(nk_f32_t const *src, nk_f16_t *dest) {
|
|
124
|
+
#if NK_NATIVE_F16
|
|
125
|
+
*dest = (nk_f16_t)(*src);
|
|
126
|
+
#else
|
|
127
|
+
nk_fui32_t conv;
|
|
128
|
+
conv.f = *src;
|
|
129
|
+
|
|
130
|
+
unsigned int sign = (conv.u >> 31) & 1;
|
|
131
|
+
unsigned int exponent = (conv.u >> 23) & 0xFF;
|
|
132
|
+
unsigned int mantissa = conv.u & 0x007FFFFF;
|
|
133
|
+
|
|
134
|
+
unsigned short result;
|
|
135
|
+
|
|
136
|
+
if (exponent == 0) {
|
|
137
|
+
// Zero or f32 denormal → f16 zero
|
|
138
|
+
result = (unsigned short)(sign << 15);
|
|
139
|
+
}
|
|
140
|
+
else if (exponent == 255) {
|
|
141
|
+
// Infinity or NaN
|
|
142
|
+
unsigned short payload = (unsigned short)(mantissa >> 13);
|
|
143
|
+
if (mantissa != 0 && payload == 0) payload = 1; // Preserve NaN-ness
|
|
144
|
+
result = (unsigned short)((sign << 15) | 0x7C00 | payload);
|
|
145
|
+
}
|
|
146
|
+
else if (exponent <= 102) {
|
|
147
|
+
// Below or at f16 denormal threshold
|
|
148
|
+
// exp=102 with mant=0 is exactly 2^-25 (tie point, rounds to 0 per round-to-even)
|
|
149
|
+
// exp=102 with mant>0 is above tie point (rounds to smallest denormal 0x0001)
|
|
150
|
+
if (exponent == 102 && mantissa > 0) result = (unsigned short)((sign << 15) | 0x0001);
|
|
151
|
+
else result = (unsigned short)(sign << 15);
|
|
152
|
+
}
|
|
153
|
+
else if (exponent < 113) {
|
|
154
|
+
// F16 denormal range (exp 103-112) with IEEE 754 round-to-nearest-even
|
|
155
|
+
unsigned int shift = 113 - exponent;
|
|
156
|
+
unsigned int shift_amount = shift + 13;
|
|
157
|
+
unsigned long long full_mant = 0x00800000ULL | mantissa;
|
|
158
|
+
|
|
159
|
+
// Extract result before rounding
|
|
160
|
+
unsigned int mant = (unsigned int)(full_mant >> shift_amount);
|
|
161
|
+
|
|
162
|
+
// IEEE 754 round-to-nearest-even: round up if round_bit is set AND
|
|
163
|
+
// (sticky_bits are nonzero OR result is odd)
|
|
164
|
+
unsigned int round_bit = (full_mant >> (shift_amount - 1)) & 1;
|
|
165
|
+
unsigned long long sticky_bits = full_mant & ((1ULL << (shift_amount - 1)) - 1);
|
|
166
|
+
|
|
167
|
+
if (round_bit && (sticky_bits || (mant & 1))) mant++;
|
|
168
|
+
|
|
169
|
+
result = (unsigned short)((sign << 15) | mant);
|
|
170
|
+
}
|
|
171
|
+
else if (exponent < 143) {
|
|
172
|
+
// Normal f16 range with IEEE 754 round-to-nearest-even
|
|
173
|
+
unsigned int f16_exp = exponent - 112;
|
|
174
|
+
unsigned int f16_mant = mantissa >> 13;
|
|
175
|
+
|
|
176
|
+
// IEEE 754 rounding: check round bit (bit 12) and sticky bits (bits 0-11)
|
|
177
|
+
unsigned int round_bit = (mantissa >> 12) & 1;
|
|
178
|
+
unsigned int sticky_bits = mantissa & 0xFFF;
|
|
179
|
+
|
|
180
|
+
if (round_bit && (sticky_bits || (f16_mant & 1))) {
|
|
181
|
+
f16_mant++;
|
|
182
|
+
if (f16_mant > 0x3FF) f16_mant = 0, f16_exp++;
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
if (f16_exp > 30) result = (unsigned short)((sign << 15) | 0x7C00);
|
|
186
|
+
else result = (unsigned short)((sign << 15) | (f16_exp << 10) | f16_mant);
|
|
187
|
+
}
|
|
188
|
+
else {
|
|
189
|
+
// Overflow → infinity
|
|
190
|
+
result = (unsigned short)((sign << 15) | 0x7C00);
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
nk_copy_bytes_(dest, &result, 2);
|
|
194
|
+
#endif
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
/**
|
|
198
|
+
* @brief For compilers that don't natively support the `__bf16` type,
|
|
199
|
+
* upcasts contents into a more conventional `float`.
|
|
200
|
+
*
|
|
201
|
+
* https://stackoverflow.com/questions/55253233/convert-fp32-to-bfloat16-in-c/55254307#55254307
|
|
202
|
+
* https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus
|
|
203
|
+
*/
|
|
204
|
+
NK_PUBLIC void nk_bf16_to_f32_serial(nk_bf16_t const *src, nk_f32_t *dest) {
|
|
205
|
+
#if NK_NATIVE_BF16
|
|
206
|
+
*dest = (nk_f32_t)(*src);
|
|
207
|
+
#else
|
|
208
|
+
unsigned short x;
|
|
209
|
+
nk_copy_bytes_(&x, src, 2);
|
|
210
|
+
nk_fui32_t conv;
|
|
211
|
+
conv.u = x << 16; // Zero extends the mantissa
|
|
212
|
+
*dest = conv.f;
|
|
213
|
+
#endif
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
/**
|
|
217
|
+
* @brief Compresses a `float` to a `bf16` representation.
|
|
218
|
+
*
|
|
219
|
+
* https://stackoverflow.com/questions/55253233/convert-fp32-to-bfloat16-in-c/55254307#55254307
|
|
220
|
+
* https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus
|
|
221
|
+
*/
|
|
222
|
+
NK_PUBLIC void nk_f32_to_bf16_serial(nk_f32_t const *src, nk_bf16_t *dest) {
|
|
223
|
+
#if NK_NATIVE_BF16
|
|
224
|
+
*dest = (nk_bf16_t)(*src);
|
|
225
|
+
#else
|
|
226
|
+
nk_fui32_t conv;
|
|
227
|
+
conv.f = *src;
|
|
228
|
+
// IEEE 754 round-to-nearest-even: add (0x7FFF + LSB)
|
|
229
|
+
unsigned int lsb = (conv.u >> 16) & 1;
|
|
230
|
+
conv.u += 0x7FFF + lsb;
|
|
231
|
+
conv.u >>= 16;
|
|
232
|
+
// Use an intermediate variable to ensure correct behavior on big-endian systems.
|
|
233
|
+
// Copying directly from `&conv.u` would copy the wrong bytes on big-endian,
|
|
234
|
+
// since the lower 16 bits are at offset 2, not offset 0.
|
|
235
|
+
unsigned short result = (unsigned short)conv.u;
|
|
236
|
+
nk_copy_bytes_(dest, &result, 2);
|
|
237
|
+
#endif
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
/**
|
|
241
|
+
* @brief Convert FP8 E4M3 to IEEE 754 single-precision float.
|
|
242
|
+
*
|
|
243
|
+
* E4M3 (FP8) format: 1 sign bit, 4 exponent bits (bias=7), 3 mantissa bits.
|
|
244
|
+
* Range: [-448, +448], no ∞, only two NaN encodings (0x7F, 0xFF).
|
|
245
|
+
* Subnormal values: (-1)ˢ × mantissa × 2⁻⁹ = mantissa / 512.
|
|
246
|
+
*
|
|
247
|
+
* Special value mappings (E4M3 → F32):
|
|
248
|
+
* Input E4M3 Hex F32 Hex Description
|
|
249
|
+
* +0 0x00 0x00000000 Positive zero
|
|
250
|
+
* -0 0x80 0x80000000 Negative zero
|
|
251
|
+
* +NaN 0x7F 0x7FC00000 Quiet NaN (exp=15, mant!=0)
|
|
252
|
+
* -NaN 0xFF 0xFFC00000 Quiet NaN (signed)
|
|
253
|
+
* +448 (max) 0x7E 0x43E00000 Max normal = 448
|
|
254
|
+
* -448 0xFE 0xC3E00000 Min normal = -448
|
|
255
|
+
* 1.0 0x38 0x3F800000 Normal (exp=7, mant=0)
|
|
256
|
+
* Min denorm 0x01 0x3B000000 1/512 = 2⁻⁹
|
|
257
|
+
* Max denorm 0x07 0x3BE00000 7/512 = 7 × 2⁻⁹
|
|
258
|
+
*
|
|
259
|
+
* References:
|
|
260
|
+
* https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
|
|
261
|
+
* https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
|
|
262
|
+
* https://onnx.ai/onnx/technical/float8.html
|
|
263
|
+
*/
|
|
264
|
+
NK_PUBLIC void nk_e4m3_to_f32_serial(nk_e4m3_t const *src, nk_f32_t *dest) {
|
|
265
|
+
nk_u8_t raw = *src;
|
|
266
|
+
nk_u32_t sign = (nk_u32_t)(raw & 0x80) << 24;
|
|
267
|
+
nk_u32_t exponent = (raw >> 3) & 0x0Fu;
|
|
268
|
+
nk_u32_t mantissa = raw & 0x07u;
|
|
269
|
+
nk_fui32_t conv;
|
|
270
|
+
|
|
271
|
+
if (exponent == 0) {
|
|
272
|
+
if (mantissa == 0) {
|
|
273
|
+
conv.u = sign;
|
|
274
|
+
*dest = conv.f;
|
|
275
|
+
return;
|
|
276
|
+
}
|
|
277
|
+
nk_f32_t value = (nk_f32_t)mantissa * (1.0f / 512.0f);
|
|
278
|
+
*dest = sign ? -value : value;
|
|
279
|
+
return;
|
|
280
|
+
}
|
|
281
|
+
// E4M3FN has no ∞. Only exp=15 && mant=7 is NaN.
|
|
282
|
+
// exp=15 && mant=0..6 are normal values (256, 288, 320, 352, 384, 416, 448).
|
|
283
|
+
if (exponent == 0x0Fu && mantissa == 7) {
|
|
284
|
+
conv.u = sign | 0x7FC00000u; // F32 quiet NaN
|
|
285
|
+
*dest = conv.f;
|
|
286
|
+
return;
|
|
287
|
+
}
|
|
288
|
+
|
|
289
|
+
nk_u32_t f32_exponent = (exponent + 120u) << 23;
|
|
290
|
+
nk_u32_t f32_mantissa = mantissa << 20;
|
|
291
|
+
conv.u = sign | f32_exponent | f32_mantissa;
|
|
292
|
+
*dest = conv.f;
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
/**
|
|
296
|
+
* @brief Convert IEEE 754 single-precision float to FP8 E4M3.
|
|
297
|
+
*
|
|
298
|
+
* E4M3 (FP8) format: 1 sign bit, 4 exponent bits (bias=7), 3 mantissa bits.
|
|
299
|
+
* Range: [-448, +448], no ∞, only two NaN encodings.
|
|
300
|
+
* Rounding: RNE (Round to Nearest Even) per IEEE 754 / OCP FP8 spec.
|
|
301
|
+
* Subnormal threshold: values with |x| < 2⁻⁶ use subnormal encoding.
|
|
302
|
+
*
|
|
303
|
+
* Special value mappings (F32 → E4M3):
|
|
304
|
+
* Input F32 Hex E4M3 Hex Description
|
|
305
|
+
* +0 0x00000000 0x00 Positive zero
|
|
306
|
+
* -0 0x80000000 0x80 Negative zero
|
|
307
|
+
* +inf 0x7F800000 0x7E Saturates to max (+448)
|
|
308
|
+
* -inf 0xFF800000 0xFE Saturates to min (-448)
|
|
309
|
+
* NaN 0x7FC00000 0x7F Quiet NaN
|
|
310
|
+
* 1.0 0x3F800000 0x38 Normal (exp=7, mant=0)
|
|
311
|
+
* 448+ >0x43E00000 0x7E Overflow → max
|
|
312
|
+
* 2⁻⁶ 0x3E800000 0x08 Min normal
|
|
313
|
+
* <2⁻¹² × ⁵ <0x39800000 0x00 Underflow → zero (RNE boundary)
|
|
314
|
+
*
|
|
315
|
+
* References:
|
|
316
|
+
* https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
|
|
317
|
+
* https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
|
|
318
|
+
* https://onnx.ai/onnx/technical/float8.html
|
|
319
|
+
*/
|
|
320
|
+
NK_PUBLIC void nk_f32_to_e4m3_serial(nk_f32_t const *src, nk_e4m3_t *dest) {
|
|
321
|
+
nk_f32_t x = *src;
|
|
322
|
+
nk_fui32_t conv;
|
|
323
|
+
conv.f = x;
|
|
324
|
+
nk_u32_t sign_bit = conv.u >> 31;
|
|
325
|
+
nk_u32_t abs_bits = conv.u & 0x7FFFFFFFu;
|
|
326
|
+
nk_u8_t sign = (nk_u8_t)(sign_bit << 7);
|
|
327
|
+
|
|
328
|
+
// NaN → E4M3FN NaN (0x7F or 0xFF)
|
|
329
|
+
if (abs_bits > 0x7F800000u) {
|
|
330
|
+
*dest = (nk_e4m3_t)(sign | 0x7Fu);
|
|
331
|
+
return;
|
|
332
|
+
}
|
|
333
|
+
// Infinity → saturate to max (0x7E or 0xFE), E4M3FN has no ∞
|
|
334
|
+
if (abs_bits == 0x7F800000u) {
|
|
335
|
+
*dest = (nk_e4m3_t)(sign | 0x7Eu);
|
|
336
|
+
return;
|
|
337
|
+
}
|
|
338
|
+
|
|
339
|
+
if (abs_bits == 0) {
|
|
340
|
+
*dest = (nk_e4m3_t)sign;
|
|
341
|
+
return;
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
nk_f32_t abs_x = sign_bit ? -x : x;
|
|
345
|
+
|
|
346
|
+
// Subnormal range: [0, 1/64). Use RNE rounding via scaled * 512.
|
|
347
|
+
// The RNE boundary between 0 and 1/512 is at 0.5/512, not 1/512.
|
|
348
|
+
if (abs_x < (1.0f / 64.0f)) {
|
|
349
|
+
nk_f32_t scaled = abs_x * 512.0f;
|
|
350
|
+
nk_i32_t mant = (nk_i32_t)scaled;
|
|
351
|
+
nk_f32_t frac = scaled - (nk_f32_t)mant;
|
|
352
|
+
if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
|
|
353
|
+
// If rounds to 8, promote to first normal (exp_field=1, mantissa=0)
|
|
354
|
+
if (mant > 7) {
|
|
355
|
+
*dest = (nk_e4m3_t)(sign | 0x08u);
|
|
356
|
+
return;
|
|
357
|
+
}
|
|
358
|
+
if (mant == 0) { *dest = (nk_e4m3_t)sign; }
|
|
359
|
+
else { *dest = (nk_e4m3_t)(sign | (nk_u8_t)mant); }
|
|
360
|
+
return;
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
nk_i32_t exp = (nk_i32_t)((abs_bits >> 23) & 0xFFu) - 127;
|
|
364
|
+
nk_u32_t mantissa = abs_bits & 0x7FFFFFu;
|
|
365
|
+
nk_u32_t significand = (1u << 23) | mantissa;
|
|
366
|
+
nk_i32_t shift = 23 - 3;
|
|
367
|
+
nk_u32_t remainder_mask = (1u << shift) - 1;
|
|
368
|
+
nk_u32_t remainder = significand & remainder_mask;
|
|
369
|
+
nk_u32_t halfway = 1u << (shift - 1);
|
|
370
|
+
nk_u32_t significand_rounded = significand >> shift;
|
|
371
|
+
if (remainder > halfway || (remainder == halfway && (significand_rounded & 1))) { ++significand_rounded; }
|
|
372
|
+
if (significand_rounded == (1u << (3 + 1))) {
|
|
373
|
+
significand_rounded >>= 1;
|
|
374
|
+
++exp;
|
|
375
|
+
}
|
|
376
|
+
if (exp > 8) {
|
|
377
|
+
// Saturate to max value 448 = 0x7E (exp=15, mantissa=6). Note: 0x7F is NaN in e4m3FN.
|
|
378
|
+
*dest = (nk_e4m3_t)(sign | 0x7Eu);
|
|
379
|
+
return;
|
|
380
|
+
}
|
|
381
|
+
if (exp < -6) {
|
|
382
|
+
nk_f32_t scaled = abs_x * 512.0f;
|
|
383
|
+
nk_i32_t mant = (nk_i32_t)scaled;
|
|
384
|
+
nk_f32_t frac = scaled - (nk_f32_t)mant;
|
|
385
|
+
if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
|
|
386
|
+
// If rounds to 8, promote to first normal (exp_field=1, mantissa=0)
|
|
387
|
+
if (mant > 7) {
|
|
388
|
+
*dest = (nk_e4m3_t)(sign | 0x08u);
|
|
389
|
+
return;
|
|
390
|
+
}
|
|
391
|
+
if (mant == 0) { *dest = (nk_e4m3_t)sign; }
|
|
392
|
+
else { *dest = (nk_e4m3_t)(sign | (nk_u8_t)mant); }
|
|
393
|
+
return;
|
|
394
|
+
}
|
|
395
|
+
|
|
396
|
+
nk_u8_t exp_field = (nk_u8_t)(exp + 7);
|
|
397
|
+
nk_u8_t mant_field = (nk_u8_t)(significand_rounded & 0x07u);
|
|
398
|
+
// For exp_field=15, clamp mantissa to 6 to avoid NaN encoding (0x7F in e4m3FN)
|
|
399
|
+
if (exp_field == 15 && mant_field > 6) { mant_field = 6; }
|
|
400
|
+
*dest = (nk_e4m3_t)(sign | (exp_field << 3) | mant_field);
|
|
401
|
+
}
|
|
402
|
+
|
|
403
|
+
/**
|
|
404
|
+
* @brief Convert FP8 E4M3 to IEEE 754 half-precision float.
|
|
405
|
+
*
|
|
406
|
+
* E4M3 format: 1 sign bit, 4 exponent bits (bias=7), 3 mantissa bits.
|
|
407
|
+
* F16 format: 1 sign bit, 5 exponent bits (bias=15), 10 mantissa bits.
|
|
408
|
+
*
|
|
409
|
+
* Conversion notes:
|
|
410
|
+
* - Normal values: F16_exp = E4M3_exp + 8, mantissa shifted left by 7 bits
|
|
411
|
+
* - Subnormals: mant × 2⁻⁹ (where 2⁻⁹ = 0x1800 in F16)
|
|
412
|
+
* - NaN (0x7F): maps to F16 quiet NaN (0x7E00)
|
|
413
|
+
*/
|
|
414
|
+
NK_INTERNAL void nk_e4m3_to_f16_serial(nk_e4m3_t const *src, nk_f16_t *dest) {
|
|
415
|
+
nk_u8_t raw = *src;
|
|
416
|
+
nk_u16_t sign = ((nk_u16_t)(raw & 0x80)) << 8;
|
|
417
|
+
nk_u16_t mag = raw & 0x7F;
|
|
418
|
+
nk_u16_t mant = raw & 0x07;
|
|
419
|
+
nk_u16_t exp = (raw >> 3) & 0x0F;
|
|
420
|
+
nk_fui16_t result;
|
|
421
|
+
|
|
422
|
+
if (mag == 0x7F) {
|
|
423
|
+
result.u = sign | 0x7E00; // NaN
|
|
424
|
+
}
|
|
425
|
+
else if (exp == 0) {
|
|
426
|
+
// Subnormal: mant × 2⁻⁹, where 2⁻⁹ = 0x1800 in F16
|
|
427
|
+
nk_fui16_t scale;
|
|
428
|
+
scale.u = 0x1800;
|
|
429
|
+
nk_fui16_t mant_f16;
|
|
430
|
+
mant_f16.f = (nk_f16_t)mant;
|
|
431
|
+
result.f = mant_f16.f * scale.f;
|
|
432
|
+
result.u |= sign;
|
|
433
|
+
}
|
|
434
|
+
else {
|
|
435
|
+
// Normal: F16 = sign | ((mag << 7) + 0x2000)
|
|
436
|
+
result.u = sign | ((mag << 7) + 0x2000);
|
|
437
|
+
}
|
|
438
|
+
*dest = result.f;
|
|
439
|
+
}
|
|
440
|
+
|
|
441
|
+
/**
|
|
442
|
+
* @brief Convert FP8 E5M2 to IEEE 754 single-precision float.
|
|
443
|
+
*
|
|
444
|
+
* E5M2 (FP8) format: 1 sign bit, 5 exponent bits (bias=15), 2 mantissa bits.
|
|
445
|
+
* Range: [-57344, +57344], supports infinity and NaN (IEEE 754 compatible).
|
|
446
|
+
* Subnormal values: (-1)ˢ × mantissa × 2⁻¹⁶ = mantissa / 65536.
|
|
447
|
+
*
|
|
448
|
+
* Special value mappings (E5M2 → F32):
|
|
449
|
+
* Input E5M2 Hex F32 Hex Description
|
|
450
|
+
* +0 0x00 0x00000000 Positive zero
|
|
451
|
+
* -0 0x80 0x80000000 Negative zero
|
|
452
|
+
* +inf 0x7C 0x7F800000 Positive infinity
|
|
453
|
+
* -inf 0xFC 0xFF800000 Negative infinity
|
|
454
|
+
* +NaN 0x7D-7F 0x7FC00000 Quiet NaN (exp=31, mant!=0)
|
|
455
|
+
* -NaN 0xFD-FF 0xFFC00000 Quiet NaN (signed)
|
|
456
|
+
* +57344 (max) 0x7B 0x47600000 Max normal
|
|
457
|
+
* 1.0 0x3C 0x3F800000 Normal (exp=15, mant=0)
|
|
458
|
+
* Min denorm 0x01 0x37800000 1/65536 = 2⁻¹⁶
|
|
459
|
+
* Max denorm 0x03 0x38000000 3/65536 = 3 × 2⁻¹⁶
|
|
460
|
+
*
|
|
461
|
+
* References:
|
|
462
|
+
* https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
|
|
463
|
+
* https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
|
|
464
|
+
* https://onnx.ai/onnx/technical/float8.html
|
|
465
|
+
*/
|
|
466
|
+
NK_INTERNAL void nk_e5m2_to_f32_manual_(nk_e5m2_t const *src, nk_f32_t *dest) {
|
|
467
|
+
nk_u8_t raw = *src;
|
|
468
|
+
nk_u32_t sign = (nk_u32_t)(raw & 0x80) << 24;
|
|
469
|
+
nk_u32_t exponent = (raw >> 2) & 0x1Fu;
|
|
470
|
+
nk_u32_t mantissa = raw & 0x03u;
|
|
471
|
+
nk_fui32_t conv;
|
|
472
|
+
|
|
473
|
+
if (exponent == 0) {
|
|
474
|
+
if (mantissa == 0) {
|
|
475
|
+
conv.u = sign;
|
|
476
|
+
*dest = conv.f;
|
|
477
|
+
return;
|
|
478
|
+
}
|
|
479
|
+
nk_f32_t value = (nk_f32_t)mantissa * (1.0f / 65536.0f);
|
|
480
|
+
*dest = sign ? -value : value;
|
|
481
|
+
return;
|
|
482
|
+
}
|
|
483
|
+
if (exponent == 0x1Fu) {
|
|
484
|
+
if (mantissa == 0) { conv.u = sign | 0x7F800000u; }
|
|
485
|
+
else { conv.u = sign | 0x7FC00000u; }
|
|
486
|
+
*dest = conv.f;
|
|
487
|
+
return;
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
nk_u32_t f32_exponent = (exponent + 112u) << 23;
|
|
491
|
+
nk_u32_t f32_mantissa = mantissa << 21;
|
|
492
|
+
conv.u = sign | f32_exponent | f32_mantissa;
|
|
493
|
+
*dest = conv.f;
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
NK_PUBLIC void nk_e5m2_to_f32_serial(nk_e5m2_t const *src, nk_f32_t *dest) {
    // Precomputed F32 bit patterns for every non-negative E5M2 magnitude;
    // the sign bit is OR-ed in afterwards. Last row covers inf/NaN (exp=31).
    static nk_u32_t const magnitudes[128] = {
        0x00000000, 0x37800000, 0x38000000, 0x38400000, // exp=0 sub
        0x38800000, 0x38A00000, 0x38C00000, 0x38E00000, // exp=1
        0x39000000, 0x39200000, 0x39400000, 0x39600000, // exp=2
        0x39800000, 0x39A00000, 0x39C00000, 0x39E00000, // exp=3
        0x3A000000, 0x3A200000, 0x3A400000, 0x3A600000, // exp=4
        0x3A800000, 0x3AA00000, 0x3AC00000, 0x3AE00000, // exp=5
        0x3B000000, 0x3B200000, 0x3B400000, 0x3B600000, // exp=6
        0x3B800000, 0x3BA00000, 0x3BC00000, 0x3BE00000, // exp=7
        0x3C000000, 0x3C200000, 0x3C400000, 0x3C600000, // exp=8
        0x3C800000, 0x3CA00000, 0x3CC00000, 0x3CE00000, // exp=9
        0x3D000000, 0x3D200000, 0x3D400000, 0x3D600000, // exp=10
        0x3D800000, 0x3DA00000, 0x3DC00000, 0x3DE00000, // exp=11
        0x3E000000, 0x3E200000, 0x3E400000, 0x3E600000, // exp=12
        0x3E800000, 0x3EA00000, 0x3EC00000, 0x3EE00000, // exp=13
        0x3F000000, 0x3F200000, 0x3F400000, 0x3F600000, // exp=14
        0x3F800000, 0x3FA00000, 0x3FC00000, 0x3FE00000, // exp=15
        0x40000000, 0x40200000, 0x40400000, 0x40600000, // exp=16
        0x40800000, 0x40A00000, 0x40C00000, 0x40E00000, // exp=17
        0x41000000, 0x41200000, 0x41400000, 0x41600000, // exp=18
        0x41800000, 0x41A00000, 0x41C00000, 0x41E00000, // exp=19
        0x42000000, 0x42200000, 0x42400000, 0x42600000, // exp=20
        0x42800000, 0x42A00000, 0x42C00000, 0x42E00000, // exp=21
        0x43000000, 0x43200000, 0x43400000, 0x43600000, // exp=22
        0x43800000, 0x43A00000, 0x43C00000, 0x43E00000, // exp=23
        0x44000000, 0x44200000, 0x44400000, 0x44600000, // exp=24
        0x44800000, 0x44A00000, 0x44C00000, 0x44E00000, // exp=25
        0x45000000, 0x45200000, 0x45400000, 0x45600000, // exp=26
        0x45800000, 0x45A00000, 0x45C00000, 0x45E00000, // exp=27
        0x46000000, 0x46200000, 0x46400000, 0x46600000, // exp=28
        0x46800000, 0x46A00000, 0x46C00000, 0x46E00000, // exp=29
        0x47000000, 0x47200000, 0x47400000, 0x47600000, // exp=30
        0x7F800000, 0x7FC00000, 0x7FC00000, 0x7FC00000, // inf, nan
    };
    nk_u8_t bits = *src;
    nk_u32_t index = bits & 0x7F;
    nk_fui32_t out;
    out.u = ((nk_u32_t)(bits & 0x80) << 24) | magnitudes[index];
    *dest = out.f;
}
|
|
537
|
+
|
|
538
|
+
/**
|
|
539
|
+
* @brief Convert IEEE 754 single-precision float to FP8 E5M2.
|
|
540
|
+
*
|
|
541
|
+
* E5M2 (FP8) format: 1 sign bit, 5 exponent bits (bias=15), 2 mantissa bits.
|
|
542
|
+
* Range: [-57344, +57344], supports infinity and NaN (IEEE 754 compatible).
|
|
543
|
+
* Rounding: RNE (Round to Nearest Even) per IEEE 754 / OCP FP8 spec.
|
|
544
|
+
* Subnormal threshold: values with |x| < 2⁻¹⁴ use subnormal encoding.
|
|
545
|
+
*
|
|
546
|
+
* Special value mappings (F32 → E5M2):
|
|
547
|
+
* Input F32 Hex E5M2 Hex Description
|
|
548
|
+
* +0 0x00000000 0x00 Positive zero
|
|
549
|
+
* -0 0x80000000 0x80 Negative zero
|
|
550
|
+
* +inf 0x7F800000 0x7C Positive infinity
|
|
551
|
+
* -inf 0xFF800000 0xFC Negative infinity
|
|
552
|
+
* NaN 0x7FC00000 0x7D Quiet NaN
|
|
553
|
+
* 1.0 0x3F800000 0x3C Normal (exp=15, mant=0)
|
|
554
|
+
* 57344+ >0x47600000 0x7C Overflow → infinity
|
|
555
|
+
* 2⁻¹⁴ 0x38800000 0x04 Min normal
|
|
556
|
+
* <2⁻¹⁷ × ⁵ <0x36800000 0x00 Underflow → zero (RNE boundary)
|
|
557
|
+
*
|
|
558
|
+
* References:
|
|
559
|
+
* https://arxiv.org/pdf/2209.05433 (NVIDIA/Intel/Arm FP8 paper)
|
|
560
|
+
* https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1
|
|
561
|
+
* https://onnx.ai/onnx/technical/float8.html
|
|
562
|
+
*/
|
|
563
|
+
NK_PUBLIC void nk_f32_to_e5m2_serial(nk_f32_t const *src, nk_e5m2_t *dest) {
|
|
564
|
+
nk_f32_t x = *src;
|
|
565
|
+
nk_fui32_t conv;
|
|
566
|
+
conv.f = x;
|
|
567
|
+
nk_u32_t sign_bit = conv.u >> 31;
|
|
568
|
+
nk_u32_t abs_bits = conv.u & 0x7FFFFFFFu;
|
|
569
|
+
nk_u8_t sign = (nk_u8_t)(sign_bit << 7);
|
|
570
|
+
|
|
571
|
+
if (abs_bits >= 0x7F800000u) {
|
|
572
|
+
nk_u8_t mant = (abs_bits > 0x7F800000u) ? 0x01u : 0x00u;
|
|
573
|
+
*dest = (nk_e5m2_t)(sign | 0x7Cu | mant);
|
|
574
|
+
return;
|
|
575
|
+
}
|
|
576
|
+
|
|
577
|
+
if (abs_bits == 0) {
|
|
578
|
+
*dest = (nk_e5m2_t)sign;
|
|
579
|
+
return;
|
|
580
|
+
}
|
|
581
|
+
|
|
582
|
+
nk_f32_t abs_x = sign_bit ? -x : x;
|
|
583
|
+
|
|
584
|
+
// Subnormal range: [0, 1/16384). Use RNE rounding via scaled * 65536.
|
|
585
|
+
// The RNE boundary between 0 and 1/65536 is at 0.5/65536, not 1/65536.
|
|
586
|
+
if (abs_x < (1.0f / 16384.0f)) {
|
|
587
|
+
nk_f32_t scaled = abs_x * 65536.0f;
|
|
588
|
+
nk_i32_t mant = (nk_i32_t)scaled;
|
|
589
|
+
nk_f32_t frac = scaled - (nk_f32_t)mant;
|
|
590
|
+
if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
|
|
591
|
+
// If rounds to 4, promote to first normal (exp_field=1, mantissa=0)
|
|
592
|
+
if (mant > 3) {
|
|
593
|
+
*dest = (nk_e5m2_t)(sign | 0x04u);
|
|
594
|
+
return;
|
|
595
|
+
}
|
|
596
|
+
if (mant == 0) { *dest = (nk_e5m2_t)sign; }
|
|
597
|
+
else { *dest = (nk_e5m2_t)(sign | (nk_u8_t)mant); }
|
|
598
|
+
return;
|
|
599
|
+
}
|
|
600
|
+
|
|
601
|
+
nk_i32_t exp = (nk_i32_t)((abs_bits >> 23) & 0xFFu) - 127;
|
|
602
|
+
nk_u32_t mantissa = abs_bits & 0x7FFFFFu;
|
|
603
|
+
nk_u32_t significand = (1u << 23) | mantissa;
|
|
604
|
+
nk_i32_t shift = 23 - 2;
|
|
605
|
+
nk_u32_t remainder_mask = (1u << shift) - 1;
|
|
606
|
+
nk_u32_t remainder = significand & remainder_mask;
|
|
607
|
+
nk_u32_t halfway = 1u << (shift - 1);
|
|
608
|
+
nk_u32_t significand_rounded = significand >> shift;
|
|
609
|
+
if (remainder > halfway || (remainder == halfway && (significand_rounded & 1))) { ++significand_rounded; }
|
|
610
|
+
if (significand_rounded == (1u << (2 + 1))) {
|
|
611
|
+
significand_rounded >>= 1;
|
|
612
|
+
++exp;
|
|
613
|
+
}
|
|
614
|
+
if (exp > 15) {
|
|
615
|
+
*dest = (nk_e5m2_t)(sign | 0x7Cu);
|
|
616
|
+
return;
|
|
617
|
+
}
|
|
618
|
+
if (exp < -14) {
|
|
619
|
+
nk_f32_t scaled = abs_x * 65536.0f;
|
|
620
|
+
nk_i32_t mant = (nk_i32_t)scaled;
|
|
621
|
+
nk_f32_t frac = scaled - (nk_f32_t)mant;
|
|
622
|
+
if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
|
|
623
|
+
// If rounds to 4, promote to first normal (exp_field=1, mantissa=0)
|
|
624
|
+
if (mant > 3) {
|
|
625
|
+
*dest = (nk_e5m2_t)(sign | 0x04u);
|
|
626
|
+
return;
|
|
627
|
+
}
|
|
628
|
+
if (mant == 0) { *dest = (nk_e5m2_t)sign; }
|
|
629
|
+
else { *dest = (nk_e5m2_t)(sign | (nk_u8_t)mant); }
|
|
630
|
+
return;
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
nk_u8_t exp_field = (nk_u8_t)(exp + 15);
|
|
634
|
+
nk_u8_t mant_field = (nk_u8_t)(significand_rounded & 0x03u);
|
|
635
|
+
*dest = (nk_e5m2_t)(sign | (exp_field << 2) | mant_field);
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
/**
|
|
639
|
+
* @brief Convert FP8 E5M2 to IEEE 754 half-precision float.
|
|
640
|
+
*
|
|
641
|
+
* E5M2 format: 1 sign bit, 5 exponent bits (bias=15), 2 mantissa bits.
|
|
642
|
+
* F16 format: 1 sign bit, 5 exponent bits (bias=15), 10 mantissa bits.
|
|
643
|
+
*
|
|
644
|
+
* Since E5M2 and F16 share the same exponent bias (15), normal values
|
|
645
|
+
* convert by simply shifting the magnitude left by 8 bits.
|
|
646
|
+
*
|
|
647
|
+
* Conversion notes:
|
|
648
|
+
* - Normal values: F16 = sign | (mag << 8)
|
|
649
|
+
* - Subnormals: mant × 2⁻¹⁶ (where 2⁻¹⁶ = 0x0100 in F16)
|
|
650
|
+
* - Infinity (0x7C): maps to F16 infinity (0x7C00)
|
|
651
|
+
* - NaN (0x7D-0x7F): maps to F16 quiet NaN (0x7E00)
|
|
652
|
+
*/
|
|
653
|
+
NK_INTERNAL void nk_e5m2_to_f16_manual_(nk_e5m2_t const *src, nk_f16_t *dest) {
|
|
654
|
+
nk_u8_t raw = *src;
|
|
655
|
+
nk_u16_t sign = ((nk_u16_t)(raw & 0x80)) << 8;
|
|
656
|
+
nk_u16_t mag = raw & 0x7F;
|
|
657
|
+
nk_u16_t mant = raw & 0x03;
|
|
658
|
+
nk_u16_t exp = (raw >> 2) & 0x1F;
|
|
659
|
+
nk_fui16_t result;
|
|
660
|
+
|
|
661
|
+
if (exp == 0) {
|
|
662
|
+
if (mant == 0) {
|
|
663
|
+
result.u = sign; // Zero
|
|
664
|
+
}
|
|
665
|
+
else {
|
|
666
|
+
// Subnormal: mant × 2⁻¹⁶, where 2⁻¹⁶ = 0x0100 in F16
|
|
667
|
+
nk_fui16_t scale;
|
|
668
|
+
scale.u = 0x0100;
|
|
669
|
+
nk_fui16_t mant_f16;
|
|
670
|
+
mant_f16.f = (nk_f16_t)mant;
|
|
671
|
+
result.f = mant_f16.f * scale.f;
|
|
672
|
+
result.u |= sign;
|
|
673
|
+
}
|
|
674
|
+
}
|
|
675
|
+
else if (mag == 0x7C) {
|
|
676
|
+
result.u = sign | 0x7C00; // Infinity
|
|
677
|
+
}
|
|
678
|
+
else if (mag > 0x7C) {
|
|
679
|
+
result.u = sign | 0x7E00; // NaN
|
|
680
|
+
}
|
|
681
|
+
else {
|
|
682
|
+
// Normal: E5M2 and F16 have same bias (15), just shift magnitude
|
|
683
|
+
result.u = sign | ((nk_u16_t)mag << 8);
|
|
684
|
+
}
|
|
685
|
+
*dest = result.f;
|
|
686
|
+
}
|
|
687
|
+
|
|
688
|
+
NK_INTERNAL void nk_e5m2_to_f16_serial(nk_e5m2_t const *src, nk_f16_t *dest) {
    // Precomputed F16 bit patterns for every non-negative E5M2 magnitude;
    // the sign bit is OR-ed in afterwards. Last row covers inf/NaN (exp=31).
    static nk_u16_t const magnitudes[128] = {
        0x0000, 0x0100, 0x0200, 0x0300, // exp=0 sub
        0x0400, 0x0500, 0x0600, 0x0700, // exp=1
        0x0800, 0x0900, 0x0A00, 0x0B00, // exp=2
        0x0C00, 0x0D00, 0x0E00, 0x0F00, // exp=3
        0x1000, 0x1100, 0x1200, 0x1300, // exp=4
        0x1400, 0x1500, 0x1600, 0x1700, // exp=5
        0x1800, 0x1900, 0x1A00, 0x1B00, // exp=6
        0x1C00, 0x1D00, 0x1E00, 0x1F00, // exp=7
        0x2000, 0x2100, 0x2200, 0x2300, // exp=8
        0x2400, 0x2500, 0x2600, 0x2700, // exp=9
        0x2800, 0x2900, 0x2A00, 0x2B00, // exp=10
        0x2C00, 0x2D00, 0x2E00, 0x2F00, // exp=11
        0x3000, 0x3100, 0x3200, 0x3300, // exp=12
        0x3400, 0x3500, 0x3600, 0x3700, // exp=13
        0x3800, 0x3900, 0x3A00, 0x3B00, // exp=14
        0x3C00, 0x3D00, 0x3E00, 0x3F00, // exp=15
        0x4000, 0x4100, 0x4200, 0x4300, // exp=16
        0x4400, 0x4500, 0x4600, 0x4700, // exp=17
        0x4800, 0x4900, 0x4A00, 0x4B00, // exp=18
        0x4C00, 0x4D00, 0x4E00, 0x4F00, // exp=19
        0x5000, 0x5100, 0x5200, 0x5300, // exp=20
        0x5400, 0x5500, 0x5600, 0x5700, // exp=21
        0x5800, 0x5900, 0x5A00, 0x5B00, // exp=22
        0x5C00, 0x5D00, 0x5E00, 0x5F00, // exp=23
        0x6000, 0x6100, 0x6200, 0x6300, // exp=24
        0x6400, 0x6500, 0x6600, 0x6700, // exp=25
        0x6800, 0x6900, 0x6A00, 0x6B00, // exp=26
        0x6C00, 0x6D00, 0x6E00, 0x6F00, // exp=27
        0x7000, 0x7100, 0x7200, 0x7300, // exp=28
        0x7400, 0x7500, 0x7600, 0x7700, // exp=29
        0x7800, 0x7900, 0x7A00, 0x7B00, // exp=30
        0x7C00, 0x7E00, 0x7E00, 0x7E00, // inf, nan
    };
    nk_u8_t bits = *src;
    nk_u16_t index = bits & 0x7F;
    nk_fui16_t out;
    out.u = (((nk_u16_t)(bits & 0x80)) << 8) | magnitudes[index];
    *dest = out.f;
}
|
|
729
|
+
|
|
730
|
+
/**
|
|
731
|
+
* @brief Convert FP6 E2M3FN to IEEE 754 single-precision float.
|
|
732
|
+
*
|
|
733
|
+
* E2M3FN (FP6) format: 1 sign bit, 2 exponent bits (bias=1), 3 mantissa bits.
|
|
734
|
+
* Range: [-7.5, +7.5], no infinity or NaN (OCP Microscaling FN format).
|
|
735
|
+
* Uses precomputed lookup table for all 64 possible values.
|
|
736
|
+
*
|
|
737
|
+
* References:
|
|
738
|
+
* https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
|
739
|
+
* https://arxiv.org/abs/2401.14112 (FP6-LLM)
|
|
740
|
+
*/
|
|
741
|
+
NK_INTERNAL void nk_e2m3_to_f32_manual_(nk_e2m3_t const *src, nk_f32_t *dest) {
|
|
742
|
+
nk_u8_t raw = *src;
|
|
743
|
+
nk_u32_t sign = (nk_u32_t)((raw >> 5) & 0x01u) << 31;
|
|
744
|
+
nk_u32_t exponent = (raw >> 3) & 0x03u;
|
|
745
|
+
nk_u32_t mantissa = raw & 0x07u;
|
|
746
|
+
nk_fui32_t conv;
|
|
747
|
+
|
|
748
|
+
// Handle zero
|
|
749
|
+
if (exponent == 0 && mantissa == 0) {
|
|
750
|
+
conv.u = sign;
|
|
751
|
+
*dest = conv.f;
|
|
752
|
+
return;
|
|
753
|
+
}
|
|
754
|
+
|
|
755
|
+
// Handle subnormal (exp=0, mant!=0)
|
|
756
|
+
if (exponent == 0) {
|
|
757
|
+
// Subnormal: value = 2^(1-bias) * (mantissa / 2^p) = 2^0 * (mantissa / 8) = mantissa / 8
|
|
758
|
+
nk_f32_t value = (nk_f32_t)mantissa * (1.0f / 8.0f);
|
|
759
|
+
*dest = sign ? -value : value;
|
|
760
|
+
return;
|
|
761
|
+
}
|
|
762
|
+
|
|
763
|
+
// Normal values: rebias from E2M3 (bias=1) to F32 (bias=127)
|
|
764
|
+
// E2M3 exp range: 1-3 (unbiased: 0-2)
|
|
765
|
+
// F32 needs: (e2m3_exp - 1) + 127 = e2m3_exp + 126
|
|
766
|
+
nk_u32_t f32_exponent = (exponent + 126u) << 23;
|
|
767
|
+
nk_u32_t f32_mantissa = mantissa << 20;
|
|
768
|
+
conv.u = sign | f32_exponent | f32_mantissa;
|
|
769
|
+
*dest = conv.f;
|
|
770
|
+
}
|
|
771
|
+
|
|
772
|
+
NK_PUBLIC void nk_e2m3_to_f32_serial(nk_e2m3_t const *src, nk_f32_t *dest) {
    // F32 bit patterns for the 32 non-negative E2M3FN magnitudes; sign OR-ed in after.
    static nk_u32_t const magnitudes[32] = {
        0x00000000, 0x3E000000, 0x3E800000, 0x3EC00000, 0x3F000000, 0x3F200000, 0x3F400000, 0x3F600000, // exp=0 sub
        0x3F800000, 0x3F900000, 0x3FA00000, 0x3FB00000, 0x3FC00000, 0x3FD00000, 0x3FE00000, 0x3FF00000, // exp=1
        0x40000000, 0x40100000, 0x40200000, 0x40300000, 0x40400000, 0x40500000, 0x40600000, 0x40700000, // exp=2
        0x40800000, 0x40900000, 0x40A00000, 0x40B00000, 0x40C00000, 0x40D00000, 0x40E00000, 0x40F00000, // exp=3
    };
    nk_u8_t bits = *src;
    nk_u32_t index = bits & 0x1F;
    nk_fui32_t out;
    out.u = ((nk_u32_t)((bits >> 5) & 0x01u) << 31) | magnitudes[index];
    *dest = out.f;
}
|
|
785
|
+
|
|
786
|
+
/**
|
|
787
|
+
* @brief Convert IEEE 754 single-precision float to FP6 E2M3FN.
|
|
788
|
+
*
|
|
789
|
+
* E2M3FN (FP6) format: 1 sign bit, 2 exponent bits (bias=1), 3 mantissa bits.
|
|
790
|
+
* Range: [-7.5, +7.5], no ∞ or NaN. Saturates to max on overflow.
|
|
791
|
+
* Rounding: RNE (Round to Nearest Even) per IEEE 754.
|
|
792
|
+
* Subnormal threshold: values with |x| < 0.5 use subnormal encoding.
|
|
793
|
+
*/
|
|
794
|
+
NK_PUBLIC void nk_f32_to_e2m3_serial(nk_f32_t const *src, nk_e2m3_t *dest) {
|
|
795
|
+
nk_f32_t x = *src;
|
|
796
|
+
nk_fui32_t conv;
|
|
797
|
+
conv.f = x;
|
|
798
|
+
nk_u32_t sign_bit = conv.u >> 31;
|
|
799
|
+
nk_u32_t abs_bits = conv.u & 0x7FFFFFFFu;
|
|
800
|
+
nk_u8_t sign = (nk_u8_t)(sign_bit << 5);
|
|
801
|
+
|
|
802
|
+
// Zero
|
|
803
|
+
if (abs_bits == 0) {
|
|
804
|
+
*dest = (nk_e2m3_t)sign;
|
|
805
|
+
return;
|
|
806
|
+
}
|
|
807
|
+
|
|
808
|
+
nk_f32_t abs_x = sign_bit ? -x : x;
|
|
809
|
+
|
|
810
|
+
// Clamp to E2M3FN range [-7.5, 7.5]
|
|
811
|
+
// Max value: exp=3, mant=7 → (1 + 7/8) * 2^(3-1) = 1.875 * 4 = 7.5
|
|
812
|
+
if (abs_x >= 7.5f) {
|
|
813
|
+
*dest = (nk_e2m3_t)(sign | 0x1Fu); // Max: 0b011111
|
|
814
|
+
return;
|
|
815
|
+
}
|
|
816
|
+
|
|
817
|
+
// Subnormal range: [0, 1.0). exp=0, mant encodes value/0.125
|
|
818
|
+
if (abs_x < 1.0f) {
|
|
819
|
+
nk_f32_t scaled = abs_x * 8.0f; // Scale to mantissa range [0, 8)
|
|
820
|
+
nk_i32_t mant = (nk_i32_t)scaled;
|
|
821
|
+
nk_f32_t frac = scaled - (nk_f32_t)mant;
|
|
822
|
+
// RNE rounding
|
|
823
|
+
if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
|
|
824
|
+
// If rounds to 8, promote to first normal (exp=1, mant=0)
|
|
825
|
+
if (mant > 7) {
|
|
826
|
+
*dest = (nk_e2m3_t)(sign | 0x08u);
|
|
827
|
+
return;
|
|
828
|
+
}
|
|
829
|
+
*dest = (nk_e2m3_t)(sign | (nk_u8_t)mant);
|
|
830
|
+
return;
|
|
831
|
+
}
|
|
832
|
+
|
|
833
|
+
// Normal range: extract exponent and mantissa
|
|
834
|
+
nk_i32_t exp = (nk_i32_t)((abs_bits >> 23) & 0xFFu) - 127;
|
|
835
|
+
nk_u32_t mantissa = abs_bits & 0x7FFFFFu;
|
|
836
|
+
nk_u32_t significand = (1u << 23) | mantissa;
|
|
837
|
+
|
|
838
|
+
// Round mantissa from 23 to 3 bits
|
|
839
|
+
nk_i32_t shift = 23 - 3;
|
|
840
|
+
nk_u32_t remainder_mask = (1u << shift) - 1;
|
|
841
|
+
nk_u32_t remainder = significand & remainder_mask;
|
|
842
|
+
nk_u32_t halfway = 1u << (shift - 1);
|
|
843
|
+
nk_u32_t significand_rounded = significand >> shift;
|
|
844
|
+
|
|
845
|
+
// RNE rounding
|
|
846
|
+
if (remainder > halfway || (remainder == halfway && (significand_rounded & 1))) { ++significand_rounded; }
|
|
847
|
+
|
|
848
|
+
// Handle carry into exponent
|
|
849
|
+
if (significand_rounded == (1u << 4)) {
|
|
850
|
+
significand_rounded >>= 1;
|
|
851
|
+
++exp;
|
|
852
|
+
}
|
|
853
|
+
|
|
854
|
+
// Rebias exponent: e2m3_exp = f32_exp + 1
|
|
855
|
+
nk_i32_t e2m3_exp = exp + 1;
|
|
856
|
+
|
|
857
|
+
// Clamp to valid range
|
|
858
|
+
if (e2m3_exp > 3) {
|
|
859
|
+
*dest = (nk_e2m3_t)(sign | 0x1Fu); // Max value
|
|
860
|
+
return;
|
|
861
|
+
}
|
|
862
|
+
if (e2m3_exp < 0) {
|
|
863
|
+
*dest = (nk_e2m3_t)sign; // Underflow to zero
|
|
864
|
+
return;
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
nk_u8_t exp_field = (nk_u8_t)e2m3_exp;
|
|
868
|
+
nk_u8_t mant_field = (nk_u8_t)(significand_rounded & 0x07u);
|
|
869
|
+
*dest = (nk_e2m3_t)(sign | (exp_field << 3) | mant_field);
|
|
870
|
+
}
|
|
871
|
+
|
|
872
|
+
/**
|
|
873
|
+
* @brief Convert FP6 E3M2FN to IEEE 754 single-precision float.
|
|
874
|
+
*
|
|
875
|
+
* E3M2FN (FP6) format: 1 sign bit, 3 exponent bits (bias=3), 2 mantissa bits.
|
|
876
|
+
* Range: [-28, +28], no infinity or NaN (OCP Microscaling FN format).
|
|
877
|
+
*/
|
|
878
|
+
NK_INTERNAL void nk_e3m2_to_f32_manual_(nk_e3m2_t const *src, nk_f32_t *dest) {
|
|
879
|
+
nk_u8_t raw = *src;
|
|
880
|
+
nk_u32_t sign = (nk_u32_t)((raw >> 5) & 0x01u) << 31;
|
|
881
|
+
nk_u32_t exponent = (raw >> 2) & 0x07u;
|
|
882
|
+
nk_u32_t mantissa = raw & 0x03u;
|
|
883
|
+
nk_fui32_t conv;
|
|
884
|
+
|
|
885
|
+
// Handle zero
|
|
886
|
+
if (exponent == 0 && mantissa == 0) {
|
|
887
|
+
conv.u = sign;
|
|
888
|
+
*dest = conv.f;
|
|
889
|
+
return;
|
|
890
|
+
}
|
|
891
|
+
|
|
892
|
+
// Handle subnormal (exp=0, mant!=0)
|
|
893
|
+
if (exponent == 0) {
|
|
894
|
+
// Subnormal: value = 2^(-2) * (mantissa / 4)
|
|
895
|
+
nk_f32_t value = (nk_f32_t)mantissa * (1.0f / 16.0f); // 2^(-2) * (1/4) = 1/16
|
|
896
|
+
*dest = sign ? -value : value;
|
|
897
|
+
return;
|
|
898
|
+
}
|
|
899
|
+
|
|
900
|
+
// Normal values: rebias from E3M2 (bias=3) to F32 (bias=127)
|
|
901
|
+
// E3M2 exp range: 1-7 (unbiased: -2 to 4)
|
|
902
|
+
// F32 needs: (e3m2_exp - 3) + 127 = e3m2_exp + 124
|
|
903
|
+
nk_u32_t f32_exponent = (exponent + 124u) << 23;
|
|
904
|
+
nk_u32_t f32_mantissa = mantissa << 21;
|
|
905
|
+
conv.u = sign | f32_exponent | f32_mantissa;
|
|
906
|
+
*dest = conv.f;
|
|
907
|
+
}
|
|
908
|
+
|
|
909
|
+
NK_PUBLIC void nk_e3m2_to_f32_serial(nk_e3m2_t const *src, nk_f32_t *dest) {
    // F32 bit patterns for the 32 non-negative E3M2FN magnitudes; sign OR-ed in after.
    static nk_u32_t const magnitudes[32] = {
        0x00000000, 0x3D800000, 0x3E000000, 0x3E400000, // exp=0 sub
        0x3E800000, 0x3EA00000, 0x3EC00000, 0x3EE00000, // exp=1
        0x3F000000, 0x3F200000, 0x3F400000, 0x3F600000, // exp=2
        0x3F800000, 0x3FA00000, 0x3FC00000, 0x3FE00000, // exp=3
        0x40000000, 0x40200000, 0x40400000, 0x40600000, // exp=4
        0x40800000, 0x40A00000, 0x40C00000, 0x40E00000, // exp=5
        0x41000000, 0x41200000, 0x41400000, 0x41600000, // exp=6
        0x41800000, 0x41A00000, 0x41C00000, 0x41E00000, // exp=7
    };
    nk_u8_t bits = *src;
    nk_u32_t index = bits & 0x1F;
    nk_fui32_t out;
    out.u = ((nk_u32_t)((bits >> 5) & 0x01u) << 31) | magnitudes[index];
    *dest = out.f;
}
|
|
926
|
+
|
|
927
|
+
/**
|
|
928
|
+
* @brief Convert IEEE 754 single-precision float to FP6 E3M2FN.
|
|
929
|
+
*
|
|
930
|
+
* E3M2FN (FP6) format: 1 sign bit, 3 exponent bits (bias=3), 2 mantissa bits.
|
|
931
|
+
* Range: [-28, +28], no ∞ or NaN. Saturates to max on overflow.
|
|
932
|
+
* Rounding: RNE (Round to Nearest Even) per IEEE 754.
|
|
933
|
+
* Subnormal threshold: values with |x| < 0.25 use subnormal encoding.
|
|
934
|
+
*/
|
|
935
|
+
NK_PUBLIC void nk_f32_to_e3m2_serial(nk_f32_t const *src, nk_e3m2_t *dest) {
|
|
936
|
+
nk_f32_t x = *src;
|
|
937
|
+
nk_fui32_t conv;
|
|
938
|
+
conv.f = x;
|
|
939
|
+
nk_u32_t sign_bit = conv.u >> 31;
|
|
940
|
+
nk_u32_t abs_bits = conv.u & 0x7FFFFFFFu;
|
|
941
|
+
nk_u8_t sign = (nk_u8_t)(sign_bit << 5);
|
|
942
|
+
|
|
943
|
+
// Zero
|
|
944
|
+
if (abs_bits == 0) {
|
|
945
|
+
*dest = (nk_e3m2_t)sign;
|
|
946
|
+
return;
|
|
947
|
+
}
|
|
948
|
+
|
|
949
|
+
nk_f32_t abs_x = sign_bit ? -x : x;
|
|
950
|
+
|
|
951
|
+
// Clamp to E3M2FN range [-28, 28]
|
|
952
|
+
// Max value: exp=7, mant=2 → (1 + 2/4) * 2^(7-3) = 1.5 * 16 = 24
|
|
953
|
+
// Actually max is exp=7, mant=3 → (1 + 3/4) * 2⁴ = 1.75 * 16 = 28
|
|
954
|
+
if (abs_x >= 28.0f) {
|
|
955
|
+
*dest = (nk_e3m2_t)(sign | 0x1Fu); // Max: 0b011111 (exp=7, mant=3)
|
|
956
|
+
return;
|
|
957
|
+
}
|
|
958
|
+
|
|
959
|
+
// Subnormal range: [0, 0.25). exp=0, mant encodes value/0.0625
|
|
960
|
+
if (abs_x < 0.25f) {
|
|
961
|
+
nk_f32_t scaled = abs_x * 16.0f; // Scale to mantissa range [0, 4)
|
|
962
|
+
nk_i32_t mant = (nk_i32_t)scaled;
|
|
963
|
+
nk_f32_t frac = scaled - (nk_f32_t)mant;
|
|
964
|
+
// RNE rounding
|
|
965
|
+
if (frac > 0.5f || (frac == 0.5f && (mant & 1))) { ++mant; }
|
|
966
|
+
// If rounds to 4, promote to first normal (exp=1, mant=0)
|
|
967
|
+
if (mant > 3) {
|
|
968
|
+
*dest = (nk_e3m2_t)(sign | 0x04u);
|
|
969
|
+
return;
|
|
970
|
+
}
|
|
971
|
+
*dest = (nk_e3m2_t)(sign | (nk_u8_t)mant);
|
|
972
|
+
return;
|
|
973
|
+
}
|
|
974
|
+
|
|
975
|
+
// Normal range: extract exponent and mantissa
|
|
976
|
+
nk_i32_t exp = (nk_i32_t)((abs_bits >> 23) & 0xFFu) - 127;
|
|
977
|
+
nk_u32_t mantissa = abs_bits & 0x7FFFFFu;
|
|
978
|
+
nk_u32_t significand = (1u << 23) | mantissa;
|
|
979
|
+
|
|
980
|
+
// Round mantissa from 23 to 2 bits
|
|
981
|
+
nk_i32_t shift = 23 - 2;
|
|
982
|
+
nk_u32_t remainder_mask = (1u << shift) - 1;
|
|
983
|
+
nk_u32_t remainder = significand & remainder_mask;
|
|
984
|
+
nk_u32_t halfway = 1u << (shift - 1);
|
|
985
|
+
nk_u32_t significand_rounded = significand >> shift;
|
|
986
|
+
|
|
987
|
+
// RNE rounding
|
|
988
|
+
if (remainder > halfway || (remainder == halfway && (significand_rounded & 1))) { ++significand_rounded; }
|
|
989
|
+
|
|
990
|
+
// Handle carry into exponent
|
|
991
|
+
if (significand_rounded == (1u << 3)) {
|
|
992
|
+
significand_rounded >>= 1;
|
|
993
|
+
++exp;
|
|
994
|
+
}
|
|
995
|
+
|
|
996
|
+
// Rebias exponent: e3m2_exp = f32_exp + 3
|
|
997
|
+
nk_i32_t e3m2_exp = exp + 3;
|
|
998
|
+
|
|
999
|
+
// Clamp to valid range
|
|
1000
|
+
if (e3m2_exp > 7) {
|
|
1001
|
+
*dest = (nk_e3m2_t)(sign | 0x1Fu); // Max value
|
|
1002
|
+
return;
|
|
1003
|
+
}
|
|
1004
|
+
if (e3m2_exp < 0) {
|
|
1005
|
+
*dest = (nk_e3m2_t)sign; // Underflow to zero
|
|
1006
|
+
return;
|
|
1007
|
+
}
|
|
1008
|
+
|
|
1009
|
+
nk_u8_t exp_field = (nk_u8_t)e3m2_exp;
|
|
1010
|
+
nk_u8_t mant_field = (nk_u8_t)(significand_rounded & 0x03u);
|
|
1011
|
+
*dest = (nk_e3m2_t)(sign | (exp_field << 2) | mant_field);
|
|
1012
|
+
}
|
|
1013
|
+
|
|
1014
|
+
/* Widen an f16 to f64 by going through f32 first; f32 -> f64 is exact. */
NK_INTERNAL void nk_f16_to_f64_serial(nk_f16_t const *x, nk_f64_t *y) {
    nk_f32_t widened;
    nk_f16_to_f32_serial(x, &widened);
    *y = (nk_f64_t)widened;
}
|
|
1019
|
+
/* Narrow an f64 to f16 via an intermediate f32 (double rounding is accepted here). */
NK_INTERNAL void nk_f64_to_f16_serial(nk_f64_t const *x, nk_f16_t *y) {
    nk_f32_t narrowed = (nk_f32_t)*x;
    nk_f32_to_f16_serial(&narrowed, y);
}
|
|
1023
|
+
/* Widen a bf16 to f64 by going through f32 first; f32 -> f64 is exact. */
NK_INTERNAL void nk_bf16_to_f64_serial(nk_bf16_t const *x, nk_f64_t *y) {
    nk_f32_t widened;
    nk_bf16_to_f32_serial(x, &widened);
    *y = (nk_f64_t)widened;
}
|
|
1028
|
+
/* Narrow an f64 to bf16 via an intermediate f32 (double rounding is accepted here). */
NK_INTERNAL void nk_f64_to_bf16_serial(nk_f64_t const *x, nk_bf16_t *y) {
    nk_f32_t narrowed = (nk_f32_t)*x;
    nk_f32_to_bf16_serial(&narrowed, y);
}
|
|
1032
|
+
|
|
1033
|
+
/* Convert floating-point numbers to integers with the project-wide narrowing policy:
|
|
1034
|
+
* finite values are clamped and rounded to nearest, ties to even, infinities saturate,
|
|
1035
|
+
* and NaNs map to zero.
|
|
1036
|
+
*/
|
|
1037
|
+
/* Round a finite f64 to the nearest i64 with ties to even (IEEE "round half to even").
 * Truncates toward zero first, then nudges by at most one in the direction of the
 * leftover fraction; the `(integer & 1)` term fires only on an exact .5 tie and
 * pushes an odd result to the neighboring even value.
 * NOTE(review): the initial cast is UB if `x` lies outside i64 range — callers
 * pre-clamp before invoking this helper. */
NK_INTERNAL nk_i64_t nk_rint_even_f64_to_i64_serial_(nk_f64_t x) {
    nk_i64_t integer = (nk_i64_t)x;               // truncation toward zero
    nk_f64_t fraction = x - (nk_f64_t)integer;    // in (-1, 1), same sign as x
    if (fraction > 0.5 || (fraction == 0.5 && (integer & 1))) ++integer;
    else if (fraction < -0.5 || (fraction == -0.5 && (integer & 1))) --integer;
    return integer;
}
|
|
1044
|
+
|
|
1045
|
+
/* Round a finite non-negative f64 to the nearest u64 with ties to even.
 * Same scheme as the signed variant, but only the upward nudge is needed because
 * callers pre-clamp negative inputs to zero before invoking this helper.
 * NOTE(review): the initial cast is UB if `x` is negative or >= 2^64 — callers
 * must clamp first. */
NK_INTERNAL nk_u64_t nk_rint_even_f64_to_u64_serial_(nk_f64_t x) {
    nk_u64_t integer = (nk_u64_t)x;               // truncation toward zero
    nk_f64_t fraction = x - (nk_f64_t)integer;    // in [0, 1) for pre-clamped input
    if (fraction > 0.5 || (fraction == 0.5 && (integer & 1))) ++integer;
    return integer;
}
|
|
1051
|
+
|
|
1052
|
+
/* f32 -> i8 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f32_to_i8_serial(nk_f32_t const *x, nk_i8_t *y) {
    nk_f32_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 127.0f ? 127.0 : (value < -128.0f ? -128.0 : (nk_f64_t)value);
    *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(clamped);
}
|
|
1056
|
+
|
|
1057
|
+
/* f32 -> u8 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f32_to_u8_serial(nk_f32_t const *x, nk_u8_t *y) {
    nk_f32_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 255.0f ? 255.0 : (value < 0 ? 0.0 : (nk_f64_t)value);
    *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(clamped);
}
|
|
1061
|
+
|
|
1062
|
+
/* f32 -> i16 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f32_to_i16_serial(nk_f32_t const *x, nk_i16_t *y) {
    nk_f32_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 32767.0f ? 32767.0 : (value < -32768.0f ? -32768.0 : (nk_f64_t)value);
    *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(clamped);
}
|
|
1068
|
+
|
|
1069
|
+
/* f32 -> u16 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f32_to_u16_serial(nk_f32_t const *x, nk_u16_t *y) {
    nk_f32_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 65535.0f ? 65535.0 : (value < 0 ? 0.0 : (nk_f64_t)value);
    *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(clamped);
}
|
|
1073
|
+
|
|
1074
|
+
/* f64 -> i8 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f64_to_i8_serial(nk_f64_t const *x, nk_i8_t *y) {
    nk_f64_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 127.0 ? 127.0 : (value < -128.0 ? -128.0 : value);
    *y = (nk_i8_t)nk_rint_even_f64_to_i64_serial_(clamped);
}
|
|
1078
|
+
|
|
1079
|
+
/* f64 -> u8 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f64_to_u8_serial(nk_f64_t const *x, nk_u8_t *y) {
    nk_f64_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 255.0 ? 255.0 : (value < 0 ? 0.0 : value);
    *y = (nk_u8_t)nk_rint_even_f64_to_u64_serial_(clamped);
}
|
|
1083
|
+
|
|
1084
|
+
/* f64 -> i16 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f64_to_i16_serial(nk_f64_t const *x, nk_i16_t *y) {
    nk_f64_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 32767.0 ? 32767.0 : (value < -32768.0 ? -32768.0 : value);
    *y = (nk_i16_t)nk_rint_even_f64_to_i64_serial_(clamped);
}
|
|
1088
|
+
|
|
1089
|
+
/* f64 -> u16 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f64_to_u16_serial(nk_f64_t const *x, nk_u16_t *y) {
    nk_f64_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 65535.0 ? 65535.0 : (value < 0 ? 0.0 : value);
    *y = (nk_u16_t)nk_rint_even_f64_to_u64_serial_(clamped);
}
|
|
1093
|
+
|
|
1094
|
+
/* f64 -> i32 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f64_to_i32_serial(nk_f64_t const *x, nk_i32_t *y) {
    nk_f64_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 2147483647.0 ? 2147483647.0 : (value < -2147483648.0 ? -2147483648.0 : value);
    *y = (nk_i32_t)nk_rint_even_f64_to_i64_serial_(clamped);
}
|
|
1100
|
+
|
|
1101
|
+
/* f64 -> u32 with saturation and round-to-nearest-even; NaN maps to zero. */
NK_INTERNAL void nk_f64_to_u32_serial(nk_f64_t const *x, nk_u32_t *y) {
    nk_f64_t value = *x;
    if (value != value) { // NaN is the only value unequal to itself
        *y = 0;
        return;
    }
    nk_f64_t clamped = value > 4294967295.0 ? 4294967295.0 : (value < 0 ? 0.0 : value);
    *y = (nk_u32_t)nk_rint_even_f64_to_u64_serial_(clamped);
}
|
|
1105
|
+
|
|
1106
|
+
/* f64 -> i64 with saturation and round-to-nearest-even; NaN maps to zero.
 * BUG FIX: the literal 9223372036854775807.0 rounds to 2^63 as a double, so an
 * input equal to 2^63 slipped past the old `>` comparison (and clamped inputs
 * were set to 2^63 itself); casting 2^63 to i64 is undefined behavior. We now
 * saturate with `>=` against exactly 2^63 and assign the integer limit directly,
 * never casting an out-of-range double. */
NK_INTERNAL void nk_f64_to_i64_serial(nk_f64_t const *x, nk_i64_t *y) {
    nk_f64_t value = *x;
    if (value != value) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
    else if (value >= 9223372036854775808.0) *y = 9223372036854775807ll;       // saturate at INT64_MAX
    else if (value < -9223372036854775808.0) *y = -9223372036854775807ll - 1;  // saturate at INT64_MIN
    else *y = nk_rint_even_f64_to_i64_serial_(value);                          // in-range: safe to round
}
|
|
1113
|
+
|
|
1114
|
+
/* f64 -> u64 with saturation and round-to-nearest-even; NaN maps to zero.
 * BUG FIX: the literal 18446744073709551615.0 is exactly 2^64 as a double, so
 * the old code passed 2^64 into the rounding helper for any large input, and
 * casting 2^64 to u64 is undefined behavior. Saturate with `>=` against 2^64
 * and assign the integer limit directly instead. */
NK_INTERNAL void nk_f64_to_u64_serial(nk_f64_t const *x, nk_u64_t *y) {
    nk_f64_t value = *x;
    if (value != value) *y = 0; // For IEEE floating-point, NaN is the one value that is not equal to itself
    else if (value >= 18446744073709551616.0) *y = 18446744073709551615ull; // saturate at UINT64_MAX
    else if (value < 0) *y = 0;                                             // negative inputs underflow to zero
    else *y = nk_rint_even_f64_to_u64_serial_(value);                       // in-range: safe to round
}
|
|
1120
|
+
|
|
1121
|
+
/* Saturating i64 -> i8 narrowing. */
NK_INTERNAL void nk_i64_to_i8_serial(nk_i64_t const *x, nk_i8_t *y) {
    nk_i64_t v = *x;
    if (v > 127ll) v = 127ll;
    else if (v < -128ll) v = -128ll;
    *y = (nk_i8_t)v;
}
|
|
1124
|
+
|
|
1125
|
+
/* Saturating i64 -> u8 narrowing; negatives clamp to zero. */
NK_INTERNAL void nk_i64_to_u8_serial(nk_i64_t const *x, nk_u8_t *y) {
    nk_i64_t v = *x;
    if (v > 255ll) v = 255ll;
    else if (v < 0ll) v = 0ll;
    *y = (nk_u8_t)v;
}
|
|
1128
|
+
|
|
1129
|
+
/* Saturating i64 -> i16 narrowing. */
NK_INTERNAL void nk_i64_to_i16_serial(nk_i64_t const *x, nk_i16_t *y) {
    nk_i64_t v = *x;
    if (v > 32767ll) v = 32767ll;
    else if (v < -32768ll) v = -32768ll;
    *y = (nk_i16_t)v;
}
|
|
1132
|
+
|
|
1133
|
+
/* Saturating i64 -> u16 narrowing; negatives clamp to zero. */
NK_INTERNAL void nk_i64_to_u16_serial(nk_i64_t const *x, nk_u16_t *y) {
    nk_i64_t v = *x;
    if (v > 65535ll) v = 65535ll;
    else if (v < 0ll) v = 0ll;
    *y = (nk_u16_t)v;
}
|
|
1136
|
+
|
|
1137
|
+
/* Saturating i64 -> i32 narrowing. */
NK_INTERNAL void nk_i64_to_i32_serial(nk_i64_t const *x, nk_i32_t *y) {
    nk_i64_t v = *x;
    if (v > 2147483647ll) v = 2147483647ll;
    else if (v < -2147483648ll) v = -2147483648ll;
    *y = (nk_i32_t)v;
}
|
|
1140
|
+
|
|
1141
|
+
/* Saturating i64 -> u32 narrowing; negatives clamp to zero. */
NK_INTERNAL void nk_i64_to_u32_serial(nk_i64_t const *x, nk_u32_t *y) {
    nk_i64_t v = *x;
    if (v > 4294967295ll) v = 4294967295ll;
    else if (v < 0ll) v = 0ll;
    *y = (nk_u32_t)v;
}
|
|
1144
|
+
|
|
1145
|
+
/* Saturating u64 -> i8 narrowing; only the upper bound can be exceeded. */
NK_INTERNAL void nk_u64_to_i8_serial(nk_u64_t const *x, nk_i8_t *y) {
    nk_u64_t v = *x;
    if (v > 127ull) v = 127ull;
    *y = (nk_i8_t)v;
}
|
|
1146
|
+
/* Saturating u64 -> u8 narrowing. */
NK_INTERNAL void nk_u64_to_u8_serial(nk_u64_t const *x, nk_u8_t *y) {
    nk_u64_t v = *x;
    if (v > 255ull) v = 255ull;
    *y = (nk_u8_t)v;
}
|
|
1147
|
+
/* Saturating u64 -> i16 narrowing; only the upper bound can be exceeded. */
NK_INTERNAL void nk_u64_to_i16_serial(nk_u64_t const *x, nk_i16_t *y) {
    nk_u64_t v = *x;
    if (v > 32767ull) v = 32767ull;
    *y = (nk_i16_t)v;
}
|
|
1150
|
+
/* Saturating u64 -> u16 narrowing. */
NK_INTERNAL void nk_u64_to_u16_serial(nk_u64_t const *x, nk_u16_t *y) {
    nk_u64_t v = *x;
    if (v > 65535ull) v = 65535ull;
    *y = (nk_u16_t)v;
}
|
|
1153
|
+
|
|
1154
|
+
/* Saturating u64 -> i32 narrowing; only the upper bound can be exceeded. */
NK_INTERNAL void nk_u64_to_i32_serial(nk_u64_t const *x, nk_i32_t *y) {
    nk_u64_t v = *x;
    if (v > 2147483647ull) v = 2147483647ull;
    *y = (nk_i32_t)v;
}
|
|
1157
|
+
|
|
1158
|
+
/* Saturating u64 -> u32 narrowing. */
NK_INTERNAL void nk_u64_to_u32_serial(nk_u64_t const *x, nk_u32_t *y) {
    nk_u64_t v = *x;
    if (v > 4294967295ull) v = 4294967295ull;
    *y = (nk_u32_t)v;
}
|
|
1161
|
+
|
|
1162
|
+
/* Public entry: widen f16 to f64 through the serial f16 -> f32 kernel. */
NK_PUBLIC void nk_f16_to_f64_(nk_f16_t const *src, nk_f64_t *dest) {
    nk_f32_t widened;
    nk_f16_to_f32_serial(src, &widened);
    *dest = (nk_f64_t)widened;
}
|
|
1167
|
+
/* Public entry: widen bf16 to f64 through the serial bf16 -> f32 kernel. */
NK_PUBLIC void nk_bf16_to_f64_(nk_bf16_t const *src, nk_f64_t *dest) {
    nk_f32_t widened;
    nk_bf16_to_f32_serial(src, &widened);
    *dest = (nk_f64_t)widened;
}
|
|
1172
|
+
|
|
1173
|
+
/* Saturating u64 -> i64 narrowing; `>=` because INT64_MAX itself is representable. */
NK_INTERNAL void nk_u64_to_i64_serial(nk_u64_t const *x, nk_i64_t *y) {
    nk_u64_t v = *x;
    if (v >= 9223372036854775807ull) *y = 9223372036854775807ll;
    else *y = (nk_i64_t)v;
}
|
|
1176
|
+
|
|
1177
|
+
/* i8 -> u64: negatives clamp to zero, everything else widens losslessly. */
NK_INTERNAL void nk_i8_to_u64_serial(nk_i8_t const *x, nk_u64_t *y) {
    nk_i8_t v = *x;
    *y = v < 0 ? 0 : (nk_u64_t)v;
}
|
|
1178
|
+
/* i16 -> u64: negatives clamp to zero, everything else widens losslessly. */
NK_INTERNAL void nk_i16_to_u64_serial(nk_i16_t const *x, nk_u64_t *y) {
    nk_i16_t v = *x;
    *y = v < 0 ? 0 : (nk_u64_t)v;
}
|
|
1179
|
+
/* i32 -> u64: negatives clamp to zero, everything else widens losslessly. */
NK_INTERNAL void nk_i32_to_u64_serial(nk_i32_t const *x, nk_u64_t *y) {
    nk_i32_t v = *x;
    *y = v < 0 ? 0 : (nk_u64_t)v;
}
|
|
1180
|
+
/* i64 -> u64: negatives clamp to zero, everything else converts exactly. */
NK_INTERNAL void nk_i64_to_u64_serial(nk_i64_t const *x, nk_u64_t *y) {
    nk_i64_t v = *x;
    *y = v < 0 ? 0 : (nk_u64_t)v;
}
|
|
1181
|
+
|
|
1182
|
+
/* i64 -> f16 via an intermediate f32 (both conversion steps may round). */
NK_INTERNAL void nk_i64_to_f16_serial(nk_i64_t const *x, nk_f16_t *y) {
    nk_f32_t as_f32 = (nk_f32_t)*x;
    nk_f32_to_f16_serial(&as_f32, y);
}
|
|
1186
|
+
/* i64 -> bf16 via an intermediate f32 (both conversion steps may round). */
NK_INTERNAL void nk_i64_to_bf16_serial(nk_i64_t const *x, nk_bf16_t *y) {
    nk_f32_t as_f32 = (nk_f32_t)*x;
    nk_f32_to_bf16_serial(&as_f32, y);
}
|
|
1190
|
+
/* u64 -> f16 via an intermediate f32 (both conversion steps may round). */
NK_INTERNAL void nk_u64_to_f16_serial(nk_u64_t const *x, nk_f16_t *y) {
    nk_f32_t as_f32 = (nk_f32_t)*x;
    nk_f32_to_f16_serial(&as_f32, y);
}
|
|
1194
|
+
/* u64 -> bf16 via an intermediate f32 (both conversion steps may round). */
NK_INTERNAL void nk_u64_to_bf16_serial(nk_u64_t const *x, nk_bf16_t *y) {
    nk_f32_t as_f32 = (nk_f32_t)*x;
    nk_f32_to_bf16_serial(&as_f32, y);
}
|
|
1198
|
+
|
|
1199
|
+
#pragma region - Type Punned Loads and Stores
|
|
1200
|
+
|
|
1201
|
+
/** @brief Type-agnostic 256-bit full load. */
|
|
1202
|
+
NK_INTERNAL void nk_load_b256_serial_(void const *src, nk_b256_vec_t *dst) {
    nk_u64_t const *s = (nk_u64_t const *)src;
    for (nk_size_t i = 0; i < 4; ++i) dst->u64s[i] = s[i]; // copy all four 64-bit lanes
}
|
|
1206
|
+
|
|
1207
|
+
/** @brief Type-agnostic 128-bit full load. */
|
|
1208
|
+
NK_INTERNAL void nk_load_b128_serial_(void const *src, nk_b128_vec_t *dst) {
    nk_u64_t const *s = (nk_u64_t const *)src;
    for (nk_size_t i = 0; i < 2; ++i) dst->u64s[i] = s[i]; // copy both 64-bit lanes
}
|
|
1212
|
+
|
|
1213
|
+
/** @brief Type-agnostic 64-bit full load. */
|
|
1214
|
+
NK_INTERNAL void nk_load_b64_serial_(void const *src, nk_b64_vec_t *dst) {
    nk_u64_t const *s = (nk_u64_t const *)src;
    dst->u64 = s[0];
}
|
|
1215
|
+
|
|
1216
|
+
/** @brief Type-agnostic partial load for 32-bit elements (8 elements max) into 256-bit vector. */
|
|
1217
|
+
NK_INTERNAL void nk_partial_load_b32x8_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
    nk_u32_t const *s = (nk_u32_t const *)src;
    nk_size_t count = n < 8 ? n : 8; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u32s[i] = s[i];
}
|
|
1233
|
+
|
|
1234
|
+
/** @brief Type-agnostic partial load for 32-bit elements (4 elements max) into 128-bit vector. */
|
|
1235
|
+
NK_INTERNAL void nk_partial_load_b32x4_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64s[0] = 0, dst->u64s[1] = 0;
    nk_u32_t const *s = (nk_u32_t const *)src;
    nk_size_t count = n < 4 ? n : 4; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u32s[i] = s[i];
}
|
|
1247
|
+
|
|
1248
|
+
/** @brief Type-agnostic partial load for 8-bit elements (8 elements max) into 64-bit vector. */
|
|
1249
|
+
NK_INTERNAL void nk_partial_load_b8x8_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64 = 0;
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t count = n < 8 ? n : 8; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u8s[i] = s[i];
}
|
|
1265
|
+
|
|
1266
|
+
/** @brief Type-agnostic partial load for 8-bit elements (4 elements max) into 32-bit vector. */
|
|
1267
|
+
NK_INTERNAL nk_b32_vec_t nk_partial_load_b8x4_serial_(void const *src, nk_size_t n) {
    nk_b32_vec_t dst = {0}; // untouched lanes stay zero
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t count = n < 4 ? n : 4; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst.u8s[i] = s[i];
    return dst;
}
|
|
1280
|
+
|
|
1281
|
+
/** @brief Partial store for 8-bit elements (up to 4) from nk_b32_vec_t. */
|
|
1282
|
+
NK_INTERNAL void nk_partial_store_b8x4_serial_(nk_b32_vec_t const *src, void *dst, nk_size_t n) {
    nk_u8_t *d = (nk_u8_t *)dst;
    nk_size_t count = n < 4 ? n : 4; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) d[i] = src->u8s[i];
}
|
|
1293
|
+
|
|
1294
|
+
/** @brief Type-agnostic partial load for 16-bit elements (8 elements max) into 128-bit vector. */
|
|
1295
|
+
NK_INTERNAL void nk_partial_load_b16x8_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64s[0] = 0, dst->u64s[1] = 0;
    nk_u16_t const *s = (nk_u16_t const *)src;
    nk_size_t count = n < 8 ? n : 8; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u16s[i] = s[i];
}
|
|
1311
|
+
|
|
1312
|
+
/** @brief Type-agnostic partial load for 8-bit elements (16 elements max) into 128-bit vector. */
|
|
1313
|
+
NK_INTERNAL void nk_partial_load_b8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64s[0] = 0, dst->u64s[1] = 0;
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t count = n < 16 ? n : 16; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u8s[i] = s[i];
}
|
|
1337
|
+
|
|
1338
|
+
/** @brief Type-agnostic partial load for 16-bit elements (16 elements max) into 256-bit vector. */
|
|
1339
|
+
NK_INTERNAL void nk_partial_load_b16x16_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
    nk_u16_t const *s = (nk_u16_t const *)src;
    nk_size_t count = n < 16 ? n : 16; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u16s[i] = s[i];
}
|
|
1363
|
+
|
|
1364
|
+
/** @brief Partial load for 8-bit elements (32 max) into 256-bit vector (zeros in remaining slots). */
|
|
1365
|
+
NK_INTERNAL void nk_partial_load_b8x32_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t count = n < 32 ? n : 32; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u8s[i] = s[i];
}
|
|
1405
|
+
|
|
1406
|
+
/** @brief Type-agnostic partial store for 32-bit elements (8 elements max) from 256-bit vector. */
|
|
1407
|
+
NK_INTERNAL void nk_partial_store_b32x8_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
    nk_u32_t *d = (nk_u32_t *)dst;
    nk_size_t count = n < 8 ? n : 8; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) d[i] = src->u32s[i];
}
|
|
1422
|
+
|
|
1423
|
+
/** @brief Type-agnostic partial store for 32-bit elements (4 elements max) from 128-bit vector. */
|
|
1424
|
+
NK_INTERNAL void nk_partial_store_b32x4_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
    nk_u32_t *d = (nk_u32_t *)dst;
    nk_size_t count = n < 4 ? n : 4; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) d[i] = src->u32s[i];
}
|
|
1435
|
+
|
|
1436
|
+
/** @brief Type-agnostic partial store for 16-bit elements (8 elements max) from 128-bit vector. */
|
|
1437
|
+
NK_INTERNAL void nk_partial_store_b16x8_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
    nk_u16_t *d = (nk_u16_t *)dst;
    nk_size_t count = n < 8 ? n : 8; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) d[i] = src->u16s[i];
}
|
|
1452
|
+
|
|
1453
|
+
/** @brief Type-agnostic partial store for 16-bit elements (4 elements max) from 64-bit vector. */
|
|
1454
|
+
// NOTE(review): parameter order here is (dst, src, n), unlike every sibling
// partial-store which takes (src, dst, n). Kept as-is to avoid breaking callers,
// but worth unifying in a future major release.
NK_INTERNAL void nk_partial_store_b16x4_serial_(void *dst, nk_b64_vec_t const *src, nk_size_t n) {
    nk_u16_t *d = (nk_u16_t *)dst;
    nk_size_t count = n < 4 ? n : 4; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) d[i] = src->u16s[i];
}
|
|
1465
|
+
|
|
1466
|
+
/** @brief Type-agnostic partial store for 8-bit elements (8 elements max) from 64-bit vector. */
|
|
1467
|
+
NK_INTERNAL void nk_partial_store_b8x8_serial_(nk_b64_vec_t const *src, void *dst, nk_size_t n) {
    nk_u8_t *d = (nk_u8_t *)dst;
    nk_size_t count = n < 8 ? n : 8; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) d[i] = src->u8s[i];
}
|
|
1482
|
+
|
|
1483
|
+
/** @brief Type-agnostic partial load for 64-bit elements (4 elements max) into 256-bit vector. */
|
|
1484
|
+
NK_INTERNAL void nk_partial_load_b64x4_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
    nk_u64_t const *s = (nk_u64_t const *)src;
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
    nk_size_t count = n < 4 ? n : 4; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u64s[i] = s[i];
}
|
|
1496
|
+
|
|
1497
|
+
/** @brief Type-agnostic partial store for 64-bit elements (4 elements max) from 256-bit vector. */
|
|
1498
|
+
NK_INTERNAL void nk_partial_store_b64x4_serial_(nk_b256_vec_t const *src, void *dst, nk_size_t n) {
    nk_u64_t *d = (nk_u64_t *)dst;
    nk_size_t count = n < 4 ? n : 4; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) d[i] = src->u64s[i];
}
|
|
1509
|
+
|
|
1510
|
+
/** @brief Type-agnostic partial load for 32-bit elements (2 elements max) into 64-bit vector. */
|
|
1511
|
+
NK_INTERNAL void nk_partial_load_b32x2_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64 = 0;
    nk_u32_t const *s = (nk_u32_t const *)src;
    nk_size_t count = n < 2 ? n : 2; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u32s[i] = s[i];
}
|
|
1521
|
+
|
|
1522
|
+
/** @brief Type-agnostic partial load for 16-bit elements (4 elements max) into 64-bit vector. */
|
|
1523
|
+
NK_INTERNAL void nk_partial_load_b16x4_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64 = 0;
    nk_u16_t const *s = (nk_u16_t const *)src;
    nk_size_t count = n < 4 ? n : 4; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u16s[i] = s[i];
}
|
|
1535
|
+
|
|
1536
|
+
/** @brief Partial load for 4-bit nibbles (64 max = 32 bytes) into 256-bit vector (zeros in remaining slots). */
|
|
1537
|
+
NK_INTERNAL void nk_partial_load_b4x64_serial_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
    dst->u64s[0] = 0, dst->u64s[1] = 0, dst->u64s[2] = 0, dst->u64s[3] = 0;
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t n_bytes = nk_size_divide_round_up_(n, 2); // two nibbles per byte
    if (n_bytes > 32) n_bytes = 32;                     // never exceed vector capacity
    for (nk_size_t i = 0; i != n_bytes; ++i) dst->u8s[i] = s[i];
}
|
|
1543
|
+
|
|
1544
|
+
/** @brief Partial load for 4-bit nibbles (32 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
|
|
1545
|
+
NK_INTERNAL void nk_partial_load_b4x32_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
    dst->u64s[0] = 0, dst->u64s[1] = 0;
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t n_bytes = nk_size_divide_round_up_(n, 2); // two nibbles per byte
    if (n_bytes > 16) n_bytes = 16;                     // never exceed vector capacity
    for (nk_size_t i = 0; i != n_bytes; ++i) dst->u8s[i] = s[i];
}
|
|
1551
|
+
|
|
1552
|
+
/** @brief Partial load for 1-bit elements (128 max = 16 bytes) into 128-bit vector (zeros in remaining slots). */
|
|
1553
|
+
NK_INTERNAL void nk_partial_load_b1x128_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n_bits) {
    dst->u64s[0] = 0, dst->u64s[1] = 0;
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, 8); // eight bits per byte
    if (n_bytes > 16) n_bytes = 16;                          // never exceed vector capacity
    for (nk_size_t i = 0; i != n_bytes; ++i) dst->u8s[i] = s[i];
}
|
|
1559
|
+
|
|
1560
|
+
/** @brief Partial load for 4-bit nibbles (16 max = 8 bytes) into 64-bit vector (zeros in remaining slots). */
|
|
1561
|
+
/** @brief Partial load for 4-bit nibbles (16 max = 8 bytes) into 64-bit vector (zeros in remaining slots). */
NK_INTERNAL void nk_partial_load_b4x16_serial_(void const *src, nk_b64_vec_t *dst, nk_size_t n) {
    dst->u64 = 0;
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t n_bytes = nk_size_divide_round_up_(n, 2); // two nibbles per byte
    // Consistency fix: write through the union's `u8s` member like every sibling
    // (e.g. nk_partial_load_b8x8_serial_) instead of casting `&dst->u64` to a byte
    // pointer — same bytes, no pointer cast.
    for (nk_size_t i = 0; i < n_bytes && i < 8; i++) dst->u8s[i] = s[i];
}
|
|
1567
|
+
|
|
1568
|
+
/** @brief Type-agnostic partial load for 64-bit elements (2 elements max) into 128-bit vector. */
NK_INTERNAL void nk_partial_load_b64x2_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
    // Zero the whole register first so untouched lanes read as zero.
    dst->u64s[0] = 0, dst->u64s[1] = 0;
    nk_u64_t const *s = (nk_u64_t const *)src;
    nk_size_t count = n < 2 ? n : 2; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) dst->u64s[i] = s[i];
}
|
|
1578
|
+
|
|
1579
|
+
/** @brief Type-agnostic partial store for 64-bit elements (2 elements max) from 128-bit vector. */
|
|
1580
|
+
NK_INTERNAL void nk_partial_store_b64x2_serial_(nk_b128_vec_t const *src, void *dst, nk_size_t n) {
    nk_u64_t *d = (nk_u64_t *)dst;
    nk_size_t count = n < 2 ? n : 2; // clamp, matching the old switch `default` behavior
    for (nk_size_t i = 0; i < count; ++i) d[i] = src->u64s[i];
}
|
|
1589
|
+
|
|
1590
|
+
/** @brief Strided partial load for 32-bit elements (4 max) into 128-bit vector. */
|
|
1591
|
+
NK_INTERNAL void nk_strided_load_b32x4_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
                                               nk_size_t n) {
    dst->u64s[0] = 0, dst->u64s[1] = 0; // untouched lanes stay zero
    nk_u32_t const *s = (nk_u32_t const *)src;
    nk_size_t count = n < 4 ? n : 4;
    for (nk_size_t i = 0; i != count; ++i) dst->u32s[i] = s[i * stride_elements];
}
|
|
1597
|
+
|
|
1598
|
+
/** @brief Strided partial load for 16-bit elements (8 max) into 128-bit vector. */
|
|
1599
|
+
NK_INTERNAL void nk_strided_load_b16x8_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
                                               nk_size_t n) {
    dst->u64s[0] = 0, dst->u64s[1] = 0; // untouched lanes stay zero
    nk_u16_t const *s = (nk_u16_t const *)src;
    nk_size_t count = n < 8 ? n : 8;
    for (nk_size_t i = 0; i != count; ++i) dst->u16s[i] = s[i * stride_elements];
}
|
|
1605
|
+
|
|
1606
|
+
/** @brief Strided partial load for 8-bit elements (16 max) into 128-bit vector. */
|
|
1607
|
+
NK_INTERNAL void nk_strided_load_b8x16_serial_(void const *src, nk_size_t stride_elements, nk_b128_vec_t *dst,
                                               nk_size_t n) {
    dst->u64s[0] = 0, dst->u64s[1] = 0; // untouched lanes stay zero
    nk_u8_t const *s = (nk_u8_t const *)src;
    nk_size_t count = n < 16 ? n : 16;
    for (nk_size_t i = 0; i != count; ++i) dst->u8s[i] = s[i * stride_elements];
}
|
|
1613
|
+
|
|
1614
|
+
/**
|
|
1615
|
+
* @brief Union for type-punned scalar values at language binding boundaries.
|
|
1616
|
+
*
|
|
1617
|
+
* Used to bridge different type systems (Python, JavaScript, etc.) where
|
|
1618
|
+
* scalars arrive as f64 but need to be passed to kernels as typed pointers.
|
|
1619
|
+
* The caller fills the appropriate union member based on the target dtype,
|
|
1620
|
+
* then passes the union address as `void const *` to kernel functions.
|
|
1621
|
+
*/
|
|
1622
|
+
typedef union nk_scalar_buffer_t {
    nk_u8_t bytes[16]; // raw view; 16 bytes covers the largest member (f64c)
    // Real floating-point views
    nk_f64_t f64;
    nk_f32_t f32;
    nk_f16_t f16;
    nk_bf16_t bf16;
    // Complex floating-point views
    nk_f64c_t f64c;
    nk_f32c_t f32c;
    nk_f16c_t f16c;
    nk_bf16c_t bf16c;
    // Integer views, widest first
    nk_i64_t i64;
    nk_u64_t u64;
    nk_i32_t i32;
    nk_u32_t u32;
    nk_i16_t i16;
    nk_u16_t u16;
    nk_i8_t i8;
    nk_u8_t u8;
} nk_scalar_buffer_t;
|
|
1641
|
+
|
|
1642
|
+
/**
|
|
1643
|
+
 * @brief Converts up to 8x values from the `from_ptr` buffer into 8x type-punned buffer objects,
|
|
1644
|
+
* into a complex 64-bit floating point representation.
|
|
1645
|
+
*/
|
|
1646
|
+
NK_INTERNAL void nk_scalar_buffers_fill_f64c_( //
|
|
1647
|
+
void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
|
|
1648
|
+
nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) {
|
|
1649
|
+
|
|
1650
|
+
nk_f32_t temporary_f32;
|
|
1651
|
+
nk_size_t i;
|
|
1652
|
+
switch (from_dtype) {
|
|
1653
|
+
case nk_f64_k: {
|
|
1654
|
+
nk_f64_t const *p = (nk_f64_t const *)from_ptr;
|
|
1655
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1656
|
+
} break;
|
|
1657
|
+
case nk_f32_k: {
|
|
1658
|
+
nk_f32_t const *p = (nk_f32_t const *)from_ptr;
|
|
1659
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1660
|
+
} break;
|
|
1661
|
+
case nk_f16_k: {
|
|
1662
|
+
nk_f16_t const *p = (nk_f16_t const *)from_ptr;
|
|
1663
|
+
for (i = 0; i < from_count; ++i)
|
|
1664
|
+
nk_f16_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1665
|
+
to_buffers[i].f64c.imag = 0;
|
|
1666
|
+
} break;
|
|
1667
|
+
case nk_bf16_k: {
|
|
1668
|
+
nk_bf16_t const *p = (nk_bf16_t const *)from_ptr;
|
|
1669
|
+
for (i = 0; i < from_count; ++i)
|
|
1670
|
+
nk_bf16_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1671
|
+
to_buffers[i].f64c.imag = 0;
|
|
1672
|
+
} break;
|
|
1673
|
+
case nk_e4m3_k: {
|
|
1674
|
+
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1675
|
+
for (i = 0; i < from_count; ++i)
|
|
1676
|
+
nk_e4m3_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1677
|
+
to_buffers[i].f64c.imag = 0;
|
|
1678
|
+
} break;
|
|
1679
|
+
case nk_e5m2_k: {
|
|
1680
|
+
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1681
|
+
for (i = 0; i < from_count; ++i)
|
|
1682
|
+
nk_e5m2_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1683
|
+
to_buffers[i].f64c.imag = 0;
|
|
1684
|
+
} break;
|
|
1685
|
+
case nk_e2m3_k: {
|
|
1686
|
+
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1687
|
+
for (i = 0; i < from_count; ++i)
|
|
1688
|
+
nk_e2m3_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1689
|
+
to_buffers[i].f64c.imag = 0;
|
|
1690
|
+
} break;
|
|
1691
|
+
case nk_e3m2_k: {
|
|
1692
|
+
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1693
|
+
for (i = 0; i < from_count; ++i)
|
|
1694
|
+
nk_e3m2_to_f32_serial(&p[i], &temporary_f32), to_buffers[i].f64c.real = temporary_f32,
|
|
1695
|
+
to_buffers[i].f64c.imag = 0;
|
|
1696
|
+
} break;
|
|
1697
|
+
case nk_i64_k: {
|
|
1698
|
+
nk_i64_t const *p = (nk_i64_t const *)from_ptr;
|
|
1699
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = (nk_f64_t)p[i], to_buffers[i].f64c.imag = 0;
|
|
1700
|
+
} break;
|
|
1701
|
+
case nk_i32_k: {
|
|
1702
|
+
nk_i32_t const *p = (nk_i32_t const *)from_ptr;
|
|
1703
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1704
|
+
} break;
|
|
1705
|
+
case nk_i16_k: {
|
|
1706
|
+
nk_i16_t const *p = (nk_i16_t const *)from_ptr;
|
|
1707
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1708
|
+
} break;
|
|
1709
|
+
case nk_i8_k: {
|
|
1710
|
+
nk_i8_t const *p = (nk_i8_t const *)from_ptr;
|
|
1711
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1712
|
+
} break;
|
|
1713
|
+
case nk_u64_k: {
|
|
1714
|
+
nk_u64_t const *p = (nk_u64_t const *)from_ptr;
|
|
1715
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = (nk_f64_t)p[i], to_buffers[i].f64c.imag = 0;
|
|
1716
|
+
} break;
|
|
1717
|
+
case nk_u32_k: {
|
|
1718
|
+
nk_u32_t const *p = (nk_u32_t const *)from_ptr;
|
|
1719
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1720
|
+
} break;
|
|
1721
|
+
case nk_u16_k: {
|
|
1722
|
+
nk_u16_t const *p = (nk_u16_t const *)from_ptr;
|
|
1723
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1724
|
+
} break;
|
|
1725
|
+
case nk_u8_k: {
|
|
1726
|
+
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1727
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i], to_buffers[i].f64c.imag = 0;
|
|
1728
|
+
} break;
|
|
1729
|
+
case nk_f64c_k: {
|
|
1730
|
+
nk_f64c_t const *p = (nk_f64c_t const *)from_ptr;
|
|
1731
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c = p[i];
|
|
1732
|
+
} break;
|
|
1733
|
+
case nk_f32c_k: {
|
|
1734
|
+
nk_f32c_t const *p = (nk_f32c_t const *)from_ptr;
|
|
1735
|
+
for (i = 0; i < from_count; ++i) to_buffers[i].f64c.real = p[i].real, to_buffers[i].f64c.imag = p[i].imag;
|
|
1736
|
+
} break;
|
|
1737
|
+
case nk_f16c_k: {
|
|
1738
|
+
nk_f16c_t const *p = (nk_f16c_t const *)from_ptr;
|
|
1739
|
+
for (i = 0; i < from_count; ++i) {
|
|
1740
|
+
nk_f16_to_f32_serial(&p[i].real, &temporary_f32), to_buffers[i].f64c.real = temporary_f32;
|
|
1741
|
+
nk_f16_to_f32_serial(&p[i].imag, &temporary_f32), to_buffers[i].f64c.imag = temporary_f32;
|
|
1742
|
+
}
|
|
1743
|
+
} break;
|
|
1744
|
+
case nk_bf16c_k: {
|
|
1745
|
+
nk_bf16c_t const *p = (nk_bf16c_t const *)from_ptr;
|
|
1746
|
+
for (i = 0; i < from_count; ++i) {
|
|
1747
|
+
nk_bf16_to_f32_serial(&p[i].real, &temporary_f32), to_buffers[i].f64c.real = temporary_f32;
|
|
1748
|
+
nk_bf16_to_f32_serial(&p[i].imag, &temporary_f32), to_buffers[i].f64c.imag = temporary_f32;
|
|
1749
|
+
}
|
|
1750
|
+
} break;
|
|
1751
|
+
// Sub-byte: u1 - 8 bits from 1 byte, MSB-first
|
|
1752
|
+
case nk_u1_k: {
|
|
1753
|
+
nk_u8_t byte = *(nk_u8_t const *)from_ptr;
|
|
1754
|
+
for (i = 0; i < 8; ++i) to_buffers[i].f64c.real = (byte >> (7 - i)) & 1, to_buffers[i].f64c.imag = 0;
|
|
1755
|
+
} break;
|
|
1756
|
+
// Sub-byte: i4 - 8 nibbles from 4 bytes, high nibble = even index, sign-extended
|
|
1757
|
+
case nk_i4_k: {
|
|
1758
|
+
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1759
|
+
for (i = 0; i < 4; ++i) {
|
|
1760
|
+
nk_i8_t hi = (nk_i8_t)(p[i] >> 4), lo = (nk_i8_t)(p[i] & 0xF);
|
|
1761
|
+
to_buffers[i * 2].f64c.real = (hi ^ 8) - 8, to_buffers[i * 2].f64c.imag = 0;
|
|
1762
|
+
to_buffers[i * 2 + 1].f64c.real = (lo ^ 8) - 8, to_buffers[i * 2 + 1].f64c.imag = 0;
|
|
1763
|
+
}
|
|
1764
|
+
} break;
|
|
1765
|
+
// Sub-byte: u4 - 8 nibbles from 4 bytes, high nibble = even index
|
|
1766
|
+
case nk_u4_k: {
|
|
1767
|
+
nk_u8_t const *p = (nk_u8_t const *)from_ptr;
|
|
1768
|
+
for (i = 0; i < 4; ++i) {
|
|
1769
|
+
to_buffers[i * 2].f64c.real = p[i] >> 4, to_buffers[i * 2].f64c.imag = 0;
|
|
1770
|
+
to_buffers[i * 2 + 1].f64c.real = p[i] & 0xF, to_buffers[i * 2 + 1].f64c.imag = 0;
|
|
1771
|
+
}
|
|
1772
|
+
} break;
|
|
1773
|
+
default:
|
|
1774
|
+
for (i = 0; i < 8; ++i) to_buffers[i].f64c.real = 0, to_buffers[i].f64c.imag = 0;
|
|
1775
|
+
break;
|
|
1776
|
+
}
|
|
1777
|
+
}
|
|
1778
|
+
|
|
1779
|
+
/**
 * @brief Exports up to 8 `f64c` values from `from_buffers` into a typed buffer,
 *        narrowing each to the target dtype.
 *
 * Inverse of `nk_scalar_buffers_fill_f64c_`. Real targets take only the real
 * component (the imaginary part is dropped); complex targets get both. Narrow
 * floats pass through an f32 intermediate; integer targets use the saturating
 * `nk_f64_to_*_serial` converters. Sub-byte targets (u1, i4, u4) ignore
 * `to_count` and always encode a full batch: u1 packs 8 booleans MSB-first,
 * i4/u4 pack 8 clamped nibbles into 4 bytes (high nibble = even index).
 *
 * @param from_buffers  Input array of at least 8 scalar buffers (f64c member valid).
 * @param to_ptr        Typed destination buffer.
 * @param to_dtype      Dtype tag of the destination elements.
 * @param to_count      Number of elements to write (<= 8); ignored for sub-byte dtypes.
 */
NK_INTERNAL void nk_scalar_buffers_export_f64c_(         //
    nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
    void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) {

    nk_f32_t temporary_f32; // Staging slot for f64 -> f32 -> narrow-float conversion
    nk_size_t i;
    switch (to_dtype) {
    case nk_f64_k: {
        nk_f64_t *p = (nk_f64_t *)to_ptr;
        for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].f64c.real;
    } break;
    case nk_f32_k: {
        nk_f32_t *p = (nk_f32_t *)to_ptr;
        for (i = 0; i < to_count; ++i) p[i] = (nk_f32_t)from_buffers[i].f64c.real;
    } break;
    case nk_f16_k: {
        nk_f16_t *p = (nk_f16_t *)to_ptr;
        for (i = 0; i < to_count; ++i)
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_f16_serial(&temporary_f32, &p[i]);
    } break;
    case nk_bf16_k: {
        nk_bf16_t *p = (nk_bf16_t *)to_ptr;
        for (i = 0; i < to_count; ++i)
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_bf16_serial(&temporary_f32, &p[i]);
    } break;
    case nk_e4m3_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < to_count; ++i)
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e4m3_serial(&temporary_f32, &p[i]);
    } break;
    case nk_e5m2_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < to_count; ++i)
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e5m2_serial(&temporary_f32, &p[i]);
    } break;
    case nk_e2m3_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < to_count; ++i)
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e2m3_serial(&temporary_f32, &p[i]);
    } break;
    case nk_e3m2_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < to_count; ++i)
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_e3m2_serial(&temporary_f32, &p[i]);
    } break;
    case nk_i64_k: {
        nk_i64_t *p = (nk_i64_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_f64_to_i64_serial(&from_buffers[i].f64c.real, &p[i]);
    } break;
    case nk_i32_k: {
        nk_i32_t *p = (nk_i32_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_f64_to_i32_serial(&from_buffers[i].f64c.real, &p[i]);
    } break;
    case nk_i16_k: {
        nk_i16_t *p = (nk_i16_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_f64_to_i16_serial(&from_buffers[i].f64c.real, &p[i]);
    } break;
    case nk_i8_k: {
        nk_i8_t *p = (nk_i8_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_f64_to_i8_serial(&from_buffers[i].f64c.real, &p[i]);
    } break;
    case nk_u64_k: {
        nk_u64_t *p = (nk_u64_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_f64_to_u64_serial(&from_buffers[i].f64c.real, &p[i]);
    } break;
    case nk_u32_k: {
        nk_u32_t *p = (nk_u32_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_f64_to_u32_serial(&from_buffers[i].f64c.real, &p[i]);
    } break;
    case nk_u16_k: {
        nk_u16_t *p = (nk_u16_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_f64_to_u16_serial(&from_buffers[i].f64c.real, &p[i]);
    } break;
    case nk_u8_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_f64_to_u8_serial(&from_buffers[i].f64c.real, &p[i]);
    } break;
    case nk_f64c_k: {
        nk_f64c_t *p = (nk_f64c_t *)to_ptr;
        for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].f64c;
    } break;
    case nk_f32c_k: {
        nk_f32c_t *p = (nk_f32c_t *)to_ptr;
        for (i = 0; i < to_count; ++i)
            p[i].real = (nk_f32_t)from_buffers[i].f64c.real, p[i].imag = (nk_f32_t)from_buffers[i].f64c.imag;
    } break;
    case nk_f16c_k: {
        nk_f16c_t *p = (nk_f16c_t *)to_ptr;
        for (i = 0; i < to_count; ++i) {
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_f16_serial(&temporary_f32, &p[i].real);
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.imag, nk_f32_to_f16_serial(&temporary_f32, &p[i].imag);
        }
    } break;
    case nk_bf16c_k: {
        nk_bf16c_t *p = (nk_bf16c_t *)to_ptr;
        for (i = 0; i < to_count; ++i) {
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.real, nk_f32_to_bf16_serial(&temporary_f32, &p[i].real);
            temporary_f32 = (nk_f32_t)from_buffers[i].f64c.imag, nk_f32_to_bf16_serial(&temporary_f32, &p[i].imag);
        }
    } break;
    // Sub-byte: u1 - 8 bits to 1 byte, MSB-first, non-zero → 1
    case nk_u1_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        nk_u8_t byte = 0;
        for (i = 0; i < 8; ++i) byte |= (from_buffers[i].f64c.real != 0) << (7 - i);
        *p = byte;
    } break;
    // Sub-byte: i4 - 8 nibbles to 4 bytes, high nibble = even index
    case nk_i4_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < 4; ++i) {
            // Truncate toward zero, then clamp to the signed nibble range [-8, 7].
            nk_i64_t hi = (nk_i64_t)from_buffers[i * 2].f64c.real;
            nk_i64_t lo = (nk_i64_t)from_buffers[i * 2 + 1].f64c.real;
            hi = hi > 7 ? 7 : (hi < -8 ? -8 : hi);
            lo = lo > 7 ? 7 : (lo < -8 ? -8 : lo);
            p[i] = (nk_u8_t)(((hi & 0xF) << 4) | (lo & 0xF));
        }
    } break;
    // Sub-byte: u4 - 8 nibbles to 4 bytes, high nibble = even index
    case nk_u4_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < 4; ++i) {
            // Truncate, then clamp to the unsigned nibble range [0, 15].
            nk_u64_t hi = (nk_u64_t)from_buffers[i * 2].f64c.real;
            nk_u64_t lo = (nk_u64_t)from_buffers[i * 2 + 1].f64c.real;
            hi = hi > 15 ? 15 : hi;
            lo = lo > 15 ? 15 : lo;
            p[i] = (nk_u8_t)((hi << 4) | lo);
        }
    } break;
    default: break; // Unknown dtype: leave the destination untouched.
    }
}
|
|
1914
|
+
|
|
1915
|
+
/**
 * @brief Loads up to 8 integer values from a typed buffer into `buf[i].i64`
 *        (lossless widening for signed integers).
 *
 * Used as the signed-integer hub for integer-to-integer casts. Note that u64
 * sources are reinterpreted via a plain cast, so values above INT64_MAX wrap;
 * the export side's saturating converters operate on the wrapped value.
 * Sub-byte dtypes (i4, u4, u1) ignore `from_count` and always decode one full
 * batch (8 nibbles from 4 bytes, or 8 bits from 1 byte MSB-first).
 *
 * @param from_ptr    Typed source buffer.
 * @param from_dtype  Dtype tag of the source elements (integer families only).
 * @param from_count  Number of elements to load (<= 8); ignored for sub-byte dtypes.
 * @param to_buffers  Output array of at least 8 scalar buffers.
 */
NK_INTERNAL void nk_scalar_buffers_fill_i64_(                          //
    void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
    nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) {                  //
    nk_size_t i;
    switch (from_dtype) {
    case nk_i64_k: {
        nk_i64_t const *p = (nk_i64_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].i64 = p[i];
    } break;
    case nk_i32_k: {
        nk_i32_t const *p = (nk_i32_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].i64 = p[i];
    } break;
    case nk_i16_k: {
        nk_i16_t const *p = (nk_i16_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].i64 = p[i];
    } break;
    case nk_i8_k: {
        nk_i8_t const *p = (nk_i8_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].i64 = p[i];
    } break;
    // Sub-byte: i4 - 4 bytes to 8 nibbles, sign-extend each nibble
    case nk_i4_k: {
        nk_u8_t const *p = (nk_u8_t const *)from_ptr;
        for (i = 0; i < 4; ++i) {
            nk_i8_t hi = (nk_i8_t)(p[i] >> 4), lo = (nk_i8_t)(p[i] & 0xF);
            // (x ^ 8) - 8 sign-extends a 4-bit two's-complement nibble.
            to_buffers[i * 2].i64 = (hi ^ 8) - 8;
            to_buffers[i * 2 + 1].i64 = (lo ^ 8) - 8;
        }
    } break;
    case nk_u64_k: {
        // NOTE: values above INT64_MAX wrap when reinterpreted as i64.
        nk_u64_t const *p = (nk_u64_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].i64 = (nk_i64_t)p[i];
    } break;
    case nk_u32_k: {
        nk_u32_t const *p = (nk_u32_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].i64 = (nk_i64_t)p[i];
    } break;
    case nk_u16_k: {
        nk_u16_t const *p = (nk_u16_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].i64 = (nk_i64_t)p[i];
    } break;
    case nk_u8_k: {
        nk_u8_t const *p = (nk_u8_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].i64 = (nk_i64_t)p[i];
    } break;
    case nk_u4_k: {
        nk_u8_t const *p = (nk_u8_t const *)from_ptr;
        for (i = 0; i < 4; ++i) {
            to_buffers[i * 2].i64 = (nk_i64_t)(p[i] >> 4);
            to_buffers[i * 2 + 1].i64 = (nk_i64_t)(p[i] & 0xF);
        }
    } break;
    default: break; // Non-integer dtypes are routed through the f64c hub instead.
    }
}
|
|
1974
|
+
|
|
1975
|
+
/**
 * @brief Exports up to 8 `buf[i].i64` values to a typed integer buffer,
 *        saturating on downcast.
 *
 * Inverse of `nk_scalar_buffers_fill_i64_`. Narrower signed targets use the
 * saturating `nk_i64_to_i*_serial` converters; unsigned targets clamp
 * negatives to 0 via `nk_i64_to_u*_serial`. The sub-byte i4 target ignores
 * `to_count` and always packs a full batch of 8 nibbles into 4 bytes
 * (high nibble = even index), clamping each value to [-8, 7].
 *
 * @param from_buffers  Input array of at least 8 scalar buffers (i64 member valid).
 * @param to_ptr        Typed destination buffer.
 * @param to_dtype      Dtype tag of the destination elements (integer families only).
 * @param to_count      Number of elements to write (<= 8); ignored for sub-byte dtypes.
 */
NK_INTERNAL void nk_scalar_buffers_export_i64_(             //
    nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
    void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) { //
    nk_size_t i;
    switch (to_dtype) {
    case nk_i64_k: {
        nk_i64_t *p = (nk_i64_t *)to_ptr;
        for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].i64;
    } break;
    case nk_i32_k: {
        nk_i32_t *p = (nk_i32_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_i64_to_i32_serial(&from_buffers[i].i64, &p[i]);
    } break;
    case nk_i16_k: {
        nk_i16_t *p = (nk_i16_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_i64_to_i16_serial(&from_buffers[i].i64, &p[i]);
    } break;
    case nk_i8_k: {
        nk_i8_t *p = (nk_i8_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_i64_to_i8_serial(&from_buffers[i].i64, &p[i]);
    } break;
    // Unsigned targets: clamp negatives to 0
    case nk_u64_k: {
        nk_u64_t *p = (nk_u64_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_i64_to_u64_serial(&from_buffers[i].i64, &p[i]);
    } break;
    case nk_u32_k: {
        nk_u32_t *p = (nk_u32_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_i64_to_u32_serial(&from_buffers[i].i64, &p[i]);
    } break;
    case nk_u16_k: {
        nk_u16_t *p = (nk_u16_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_i64_to_u16_serial(&from_buffers[i].i64, &p[i]);
    } break;
    case nk_u8_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_i64_to_u8_serial(&from_buffers[i].i64, &p[i]);
    } break;
    // Sub-byte: i4 - 8 nibbles to 4 bytes, clamp [-8,7]
    case nk_i4_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < 4; ++i) {
            nk_i64_t hi = from_buffers[i * 2].i64, lo = from_buffers[i * 2 + 1].i64;
            hi = hi > 7 ? 7 : (hi < -8 ? -8 : hi);
            lo = lo > 7 ? 7 : (lo < -8 ? -8 : lo);
            p[i] = (nk_u8_t)(((hi & 0xF) << 4) | (lo & 0xF));
        }
    } break;
    default: break; // Unknown dtype: leave the destination untouched.
    }
}
|
|
2029
|
+
|
|
2030
|
+
/**
 * @brief Loads up to 8 unsigned values from a typed buffer into `buf[i].u64`
 *        (lossless widening for unsigned integers).
 *
 * Used as the unsigned hub when both source and destination dtypes are
 * unsigned. Sub-byte dtypes (u4, u1) ignore `from_count` and always decode
 * one full batch: 8 nibbles from 4 bytes (high nibble = even index), or
 * 8 bits from 1 byte, MSB-first.
 *
 * @param from_ptr    Typed source buffer.
 * @param from_dtype  Dtype tag of the source elements (unsigned family only).
 * @param from_count  Number of elements to load (<= 8); ignored for sub-byte dtypes.
 * @param to_buffers  Output array of at least 8 scalar buffers.
 */
NK_INTERNAL void nk_scalar_buffers_fill_u64_(                          //
    void const *from_ptr, nk_dtype_t from_dtype, nk_size_t from_count, //
    nk_scalar_buffer_t to_buffers[nk_at_least_(8)]) {                  //
    nk_size_t i;
    switch (from_dtype) {
    case nk_u64_k: {
        nk_u64_t const *p = (nk_u64_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].u64 = p[i];
    } break;
    case nk_u32_k: {
        nk_u32_t const *p = (nk_u32_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].u64 = p[i];
    } break;
    case nk_u16_k: {
        nk_u16_t const *p = (nk_u16_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].u64 = p[i];
    } break;
    case nk_u8_k: {
        nk_u8_t const *p = (nk_u8_t const *)from_ptr;
        for (i = 0; i < from_count; ++i) to_buffers[i].u64 = p[i];
    } break;
    // Sub-byte: u4 - 4 bytes to 8 nibbles, zero-extend
    case nk_u4_k: {
        nk_u8_t const *p = (nk_u8_t const *)from_ptr;
        for (i = 0; i < 4; ++i) {
            to_buffers[i * 2].u64 = p[i] >> 4;
            to_buffers[i * 2 + 1].u64 = p[i] & 0xF;
        }
    } break;
    // Sub-byte: u1 - 1 byte to 8 bits, MSB-first
    case nk_u1_k: {
        nk_u8_t byte = *(nk_u8_t const *)from_ptr;
        for (i = 0; i < 8; ++i) to_buffers[i].u64 = (byte >> (7 - i)) & 1;
    } break;
    default: break; // Other dtypes are routed through the i64 or f64c hubs instead.
    }
}
|
|
2070
|
+
|
|
2071
|
+
/**
 * @brief Exports up to 8 `buf[i].u64` values to a typed buffer, saturating
 *        on downcast.
 *
 * Inverse of `nk_scalar_buffers_fill_u64_`. Narrower unsigned targets use the
 * saturating `nk_u64_to_u*_serial` converters; signed targets clamp to the
 * respective signed maxima via `nk_u64_to_i*_serial`. Sub-byte targets
 * (u4, u1) ignore `to_count` and always encode a full batch: 8 nibbles
 * clamped to [0, 15] into 4 bytes (high nibble = even index), or 8 booleans
 * packed MSB-first into 1 byte (non-zero becomes 1).
 *
 * @param from_buffers  Input array of at least 8 scalar buffers (u64 member valid).
 * @param to_ptr        Typed destination buffer.
 * @param to_dtype      Dtype tag of the destination elements.
 * @param to_count      Number of elements to write (<= 8); ignored for sub-byte dtypes.
 */
NK_INTERNAL void nk_scalar_buffers_export_u64_(             //
    nk_scalar_buffer_t const from_buffers[nk_at_least_(8)], //
    void *to_ptr, nk_dtype_t to_dtype, nk_size_t to_count) { //
    nk_size_t i;
    switch (to_dtype) {
    case nk_u64_k: {
        nk_u64_t *p = (nk_u64_t *)to_ptr;
        for (i = 0; i < to_count; ++i) p[i] = from_buffers[i].u64;
    } break;
    case nk_u32_k: {
        nk_u32_t *p = (nk_u32_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_u64_to_u32_serial(&from_buffers[i].u64, &p[i]);
    } break;
    case nk_u16_k: {
        nk_u16_t *p = (nk_u16_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_u64_to_u16_serial(&from_buffers[i].u64, &p[i]);
    } break;
    case nk_u8_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_u64_to_u8_serial(&from_buffers[i].u64, &p[i]);
    } break;
    // Signed targets: clamp to i64_max
    case nk_i64_k: {
        nk_i64_t *p = (nk_i64_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_u64_to_i64_serial(&from_buffers[i].u64, &p[i]);
    } break;
    case nk_i32_k: {
        nk_i32_t *p = (nk_i32_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_u64_to_i32_serial(&from_buffers[i].u64, &p[i]);
    } break;
    case nk_i16_k: {
        nk_i16_t *p = (nk_i16_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_u64_to_i16_serial(&from_buffers[i].u64, &p[i]);
    } break;
    case nk_i8_k: {
        nk_i8_t *p = (nk_i8_t *)to_ptr;
        for (i = 0; i < to_count; ++i) nk_u64_to_i8_serial(&from_buffers[i].u64, &p[i]);
    } break;
    // Sub-byte: u4 - 8 nibbles to 4 bytes, clamp [0,15]
    case nk_u4_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        for (i = 0; i < 4; ++i) {
            nk_u64_t hi = from_buffers[i * 2].u64, lo = from_buffers[i * 2 + 1].u64;
            hi = hi > 15 ? 15 : hi;
            lo = lo > 15 ? 15 : lo;
            p[i] = (nk_u8_t)((hi << 4) | lo);
        }
    } break;
    // Sub-byte: u1 - 8 bits to 1 byte, MSB-first, non-zero becomes 1
    case nk_u1_k: {
        nk_u8_t *p = (nk_u8_t *)to_ptr;
        nk_u8_t byte = 0;
        for (i = 0; i < 8; ++i) byte |= (from_buffers[i].u64 != 0) << (7 - i);
        *p = byte;
    } break;
    default: break; // Unknown dtype: leave the destination untouched.
    }
}
|
|
2132
|
+
|
|
2133
|
+
#pragma endregion - Type Punned Loads and Stores
|
|
2134
|
+
|
|
2135
|
+
#pragma region - Public API
|
|
2136
|
+
|
|
2137
|
+
/**
 * @brief Casts `n` elements from `from` (dtype `from_type`) into `to`
 *        (dtype `to_type`) using a serial, batch-of-8 scalar pipeline.
 *
 * Dispatch strategy:
 *   - same dtype         -> raw byte copy (sub-byte totals rounded up to bytes);
 *   - both unsigned      -> u64 hub (lossless widen, saturating narrow);
 *   - both integer       -> i64 hub;
 *   - anything else      -> f64c hub (floats, complex, and cross-category casts).
 *
 * Batching by NK_BITS_PER_BYTE (8) elements keeps sub-byte dtypes aligned:
 * 8 elements of a b-bit dtype occupy exactly b bytes, which is why the
 * per-batch byte steps below are simply the bit widths.
 */
NK_PUBLIC void nk_cast_serial(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
    // Fast path: identical dtypes need only a byte copy.
    if (from_type == to_type) {
        nk_size_t size_bits = nk_dtype_bits(from_type);
        nk_size_t size_bytes = nk_size_divide_round_up_(n * size_bits, NK_BITS_PER_BYTE);
        if (size_bytes > 0) nk_copy_bytes_(to, from, size_bytes);
        return;
    }

    nk_size_t from_bits = nk_dtype_bits(from_type);
    nk_size_t to_bits = nk_dtype_bits(to_type);
    if (from_bits == 0 || to_bits == 0) return; // Unknown dtype: nothing to do.

    // Byte steps per batch of NK_BITS_PER_BYTE elements
    // (8 elements x b bits = b bytes exactly, even for sub-byte dtypes).
    nk_size_t from_step = from_bits;
    nk_size_t to_step = to_bits;

    nk_u8_t const *src = (nk_u8_t const *)from;
    nk_u8_t *dst = (nk_u8_t *)to;
    nk_dtype_family_t from_family = nk_dtype_family(from_type);
    nk_dtype_family_t to_family = nk_dtype_family(to_type);

    nk_size_t batches = n / NK_BITS_PER_BYTE;
    nk_size_t tail = n % NK_BITS_PER_BYTE;
    nk_scalar_buffer_t bufs[NK_BITS_PER_BYTE];

    // Both unsigned: u64 hub
    if (from_family == nk_dtype_family_uint_k && to_family == nk_dtype_family_uint_k) {
        for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
            nk_scalar_buffers_fill_u64_(src, from_type, NK_BITS_PER_BYTE, bufs);
            nk_scalar_buffers_export_u64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
        }
        if (tail) {
            nk_scalar_buffers_fill_u64_(src, from_type, tail, bufs);
            nk_scalar_buffers_export_u64_(bufs, dst, to_type, tail);
        }
        return;
    }

    // Both integers, at least one signed: i64 hub
    if ((from_family == nk_dtype_family_int_k || from_family == nk_dtype_family_uint_k) &&
        (to_family == nk_dtype_family_int_k || to_family == nk_dtype_family_uint_k)) {
        for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
            nk_scalar_buffers_fill_i64_(src, from_type, NK_BITS_PER_BYTE, bufs);
            nk_scalar_buffers_export_i64_(bufs, dst, to_type, NK_BITS_PER_BYTE);
        }
        if (tail) {
            nk_scalar_buffers_fill_i64_(src, from_type, tail, bufs);
            nk_scalar_buffers_export_i64_(bufs, dst, to_type, tail);
        }
        return;
    }

    // Everything else: f64c hub (floats, complex, cross-category)
    for (nk_size_t b = 0; b < batches; ++b, src += from_step, dst += to_step) {
        nk_scalar_buffers_fill_f64c_(src, from_type, NK_BITS_PER_BYTE, bufs);
        nk_scalar_buffers_export_f64c_(bufs, dst, to_type, NK_BITS_PER_BYTE);
    }
    if (tail) {
        nk_scalar_buffers_fill_f64c_(src, from_type, tail, bufs);
        nk_scalar_buffers_export_f64c_(bufs, dst, to_type, tail);
    }
}
|
|
2199
|
+
|
|
2200
|
+
/** @brief Convert E4M3 to BF16 via F32 intermediate. */
|
|
2201
|
+
NK_PUBLIC void nk_e4m3_to_bf16(nk_e4m3_t const *src, nk_bf16_t *dest) {
|
|
2202
|
+
nk_f32_t temp;
|
|
2203
|
+
nk_e4m3_to_f32_serial(src, &temp);
|
|
2204
|
+
nk_f32_to_bf16_serial(&temp, dest);
|
|
2205
|
+
}
|
|
2206
|
+
|
|
2207
|
+
/** @brief Convert E5M2 to BF16 via F32 intermediate. */
|
|
2208
|
+
NK_PUBLIC void nk_e5m2_to_bf16(nk_e5m2_t const *src, nk_bf16_t *dest) {
|
|
2209
|
+
nk_f32_t temp;
|
|
2210
|
+
nk_e5m2_to_f32_serial(src, &temp);
|
|
2211
|
+
nk_f32_to_bf16_serial(&temp, dest);
|
|
2212
|
+
}
|
|
2213
|
+
|
|
2214
|
+
/** @brief Convert E2M3 to BF16 via F32 intermediate. */
|
|
2215
|
+
NK_PUBLIC void nk_e2m3_to_bf16(nk_e2m3_t const *src, nk_bf16_t *dest) {
|
|
2216
|
+
nk_f32_t temp;
|
|
2217
|
+
nk_e2m3_to_f32_serial(src, &temp);
|
|
2218
|
+
nk_f32_to_bf16_serial(&temp, dest);
|
|
2219
|
+
}
|
|
2220
|
+
|
|
2221
|
+
/** @brief Convert E3M2 to BF16 via F32 intermediate. */
|
|
2222
|
+
NK_PUBLIC void nk_e3m2_to_bf16(nk_e3m2_t const *src, nk_bf16_t *dest) {
|
|
2223
|
+
nk_f32_t temp;
|
|
2224
|
+
nk_e3m2_to_f32_serial(src, &temp);
|
|
2225
|
+
nk_f32_to_bf16_serial(&temp, dest);
|
|
2226
|
+
}
|
|
2227
|
+
|
|
2228
|
+
/**
|
|
2229
|
+
* @brief Convert i4 (4-bit signed integer, -8 to 7) to i8.
|
|
2230
|
+
*
|
|
2231
|
+
* Nibbles are packed: low nibble in bits [0:3], high nibble in bits [4:7].
|
|
2232
|
+
* Sign extension: XOR with 8 then subtract 8 converts unsigned nibble to signed.
|
|
2233
|
+
*/
|
|
2234
|
+
NK_PUBLIC void nk_i4_to_i8_serial_(nk_i4x2_t const *src, nk_i8_t *dest, nk_size_t count) {
|
|
2235
|
+
nk_u8_t const *bytes = (nk_u8_t const *)src;
|
|
2236
|
+
for (nk_size_t i = 0; i < count; ++i) {
|
|
2237
|
+
nk_u8_t byte = bytes[i / 2];
|
|
2238
|
+
nk_u8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
|
|
2239
|
+
dest[i] = (nk_i8_t)((nibble ^ 8) - 8); // Sign extend: 0-7 → 0-7, 8-15 → -8 to -1
|
|
2240
|
+
}
|
|
2241
|
+
}
|
|
2242
|
+
|
|
2243
|
+
/**
|
|
2244
|
+
* @brief Convert u4 (4-bit unsigned integer, 0 to 15) to u8.
|
|
2245
|
+
*
|
|
2246
|
+
* Nibbles are packed: low nibble in bits [0:3], high nibble in bits [4:7].
|
|
2247
|
+
*/
|
|
2248
|
+
NK_PUBLIC void nk_u4_to_u8_serial_(nk_u4x2_t const *src, nk_u8_t *dest, nk_size_t count) {
|
|
2249
|
+
nk_u8_t const *bytes = (nk_u8_t const *)src;
|
|
2250
|
+
for (nk_size_t i = 0; i < count; ++i) {
|
|
2251
|
+
nk_u8_t byte = bytes[i / 2];
|
|
2252
|
+
dest[i] = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
|
|
2253
|
+
}
|
|
2254
|
+
}
|
|
2255
|
+
|
|
2256
|
+
#pragma endregion - Public API
|
|
2257
|
+
|
|
2258
|
+
#if defined(__cplusplus)
|
|
2259
|
+
} // extern "C"
|
|
2260
|
+
#endif
|
|
2261
|
+
|
|
2262
|
+
#endif // NK_CAST_SERIAL_H
|