numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -45,6 +45,49 @@
|
|
|
45
45
|
*
|
|
46
46
|
* @sa `dimensions_per_value<T>()` to convert dimension counts to value counts.
|
|
47
47
|
* @sa `bits_per_value<T>()` to infer the size of each value.
|
|
48
|
+
*
|
|
49
|
+
* @section fp8_types FP8 Numeric Types
|
|
50
|
+
*
|
|
51
|
+
* There are several variants of 8-bit floating point types supported by different industry members
|
|
52
|
+
* with different hardware support. None are part of the IEEE 754 standard, but some are part of the
|
|
53
|
+
* Open Compute Project (OCP) 8-bit Floating Point Specification (OFP8):
|
|
54
|
+
*
|
|
55
|
+
* Format Bias Sign Exp Mant Range Infinity NaN Standard
|
|
56
|
+
* E4M3FN 7 1 4 3 ±448 ❌ No Only 0x7F/0xFF OCP, NVIDIA, ONNX
|
|
57
|
+
* E5M2 15 1 5 2 ±57344 ✅ Yes (0x7C/0xFC) 0x7D-7F, 0xFD-FF OCP, IEEE-like
|
|
58
|
+
* E4M3FNUZ 8 1 4 3 ±240 ❌ No 0x80 only GraphCore, ONNX
|
|
59
|
+
* E5M2FNUZ 16 1 5 2 ±57344 ❌ No 0x80 only GraphCore, ONNX
|
|
60
|
+
*
|
|
61
|
+
* In currently available and upcoming hardware, only two series of models prioritize FNUZ over OCP:
|
|
62
|
+
*
|
|
63
|
+
* - GraphCore IPUs were the original platform proposing FNUZ
|
|
64
|
+
* - AMD MI300 series based on CDNA3 implements FNUZ, but not OCP
|
|
65
|
+
* - AMD MI350+ series based on CDNA4 switch to OCP and remove FNUZ
|
|
66
|
+
* - NVIDIA Hopper and Blackwell only support E4M3FN, E5M2
|
|
67
|
+
* - Intel AVX10.2 defines HF8 (E4M3FN) and BF8 (E5M2) - OCP-aligned
|
|
68
|
+
* - Arm implements E4M3 (meaning E4M3FN) and E5M2 with a shared `__mfp8` type and a `FPMR` format selector
|
|
69
|
+
*
|
|
70
|
+
* For brevity, across NumKong, "E4M3" implies "E4M3FN".
|
|
71
|
+
*
|
|
72
|
+
* @section fp6_types FP6 Numeric Types
|
|
73
|
+
*
|
|
74
|
+
* The OCP Microscaling (MX) v1.0 specification defines two 6-bit floating-point formats
|
|
75
|
+
* for block-scaled quantization. Both are "FN" (finite-numeric): all bit patterns map
|
|
76
|
+
* to real numbers with no Inf or NaN codes. Stored byte-aligned with 2 bits of padding.
|
|
77
|
+
*
|
|
78
|
+
* Format Bias Sign Exp Mant Range Subnormals Infinity NaN Standard
|
|
79
|
+
* E2M3 1 1 2 3 ±7.5 14 of 64 ❌ No ❌ OCP MX v1.0
|
|
80
|
+
* E3M2 3 1 3 2 ±28 6 of 64 ❌ No ❌ OCP MX v1.0
|
|
81
|
+
*
|
|
82
|
+
* E2M3 favors mantissa precision (3 bits) for narrow dynamic range — ideal for activations.
|
|
83
|
+
* E3M2 favors exponent range (3 bits) for wider dynamic range — suited for weights.
|
|
84
|
+
* Both follow IEEE 754 subnormal rules: when exp=0, the implicit leading bit is 0,
|
|
85
|
+
* giving value = (-1)^s × 0.mmm × 2^(1-bias). This provides gradual underflow to zero.
|
|
86
|
+
*
|
|
87
|
+
* No hardware directly computes on FP6. On Arm with FEAT_FP8DOT4, E2M3 values can be
|
|
88
|
+
* losslessly promoted to E4M3 (same mantissa width, rebias exponent by +6) and E3M2 to
|
|
89
|
+
* E5M2 (same mantissa width, rebias exponent by +12), then fed to FDOT instructions.
|
|
90
|
+
* Subnormal values (exp=0) require normalization during this promotion.
|
|
48
91
|
*/
|
|
49
92
|
|
|
50
93
|
#ifndef NK_TYPES_HPP
|
|
@@ -2352,10 +2395,13 @@ struct e5m2_t {
|
|
|
2352
2395
|
* @brief Float6 E2M3FN: 1 sign + 2 exponent (bias=1) + 3 mantissa bits, with 2 bits of padding.
|
|
2353
2396
|
*
|
|
2354
2397
|
* Range: [-7.5, +7.5], stored byte-aligned (0b00SEEMMM, upper 2 bits padding).
|
|
2355
|
-
* No Inf/NaN (OCP Microscaling FN format). All 64 bit patterns are valid.
|
|
2398
|
+
* No Inf/NaN (OCP Microscaling FN format). All 64 bit patterns are valid numbers.
|
|
2399
|
+
* 64 total codes: 48 normal, 14 subnormal (exp=0, mant!=0), 2 zeros (+/-0).
|
|
2356
2400
|
* Only 18 of 64 values (28.1%) fall in [-1, +1] — 72% of codes represent |x| > 1.
|
|
2401
|
+
* Subnormal values span [+/-0.125, +/-0.875] using formula 0.mmm x 2^(1-bias).
|
|
2357
2402
|
* Dot products are exact via integer accumulation: every value x 16 is an integer
|
|
2358
2403
|
* in [-120, +120], so products fit in i16 and sums fit in i32 without rounding.
|
|
2404
|
+
* Losslessly promotable to E4M3 by rebiasing exponent +6 (normals) or normalizing (subnormals).
|
|
2359
2405
|
*
|
|
2360
2406
|
* @see https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
|
2361
2407
|
* @see https://arxiv.org/abs/2401.14112 (FP6-LLM paper)
|
|
@@ -2533,10 +2579,13 @@ struct e2m3_t {
|
|
|
2533
2579
|
* @brief Float6 E3M2FN: 1 sign + 3 exponent (bias=3) + 2 mantissa bits, with 2 bits of padding.
|
|
2534
2580
|
*
|
|
2535
2581
|
* Range: [-28, +28], stored byte-aligned (0b00SEEEMM, upper 2 bits padding).
|
|
2536
|
-
* No Inf/NaN (OCP Microscaling FN format). All 64 bit patterns are valid.
|
|
2582
|
+
* No Inf/NaN (OCP Microscaling FN format). All 64 bit patterns are valid numbers.
|
|
2583
|
+
* 64 total codes: 56 normal, 6 subnormal (exp=0, mant!=0), 2 zeros (+/-0).
|
|
2537
2584
|
* 26 of 64 values (40.6%) fall in [-1, +1].
|
|
2585
|
+
* Subnormal values span [+/-0.0625, +/-0.1875] using formula 0.mm x 2^(1-bias).
|
|
2538
2586
|
* Dot products can be exact via integer accumulation: every value x 16 is an integer
|
|
2539
2587
|
* in [-448, +448], so each product fits in i32 (|p| <= 200704) and sums stay exact while they fit the accumulator.
|
|
2588
|
+
* Losslessly promotable to E5M2 by rebiasing exponent +12 (normals) or normalizing (subnormals).
|
|
2540
2589
|
*
|
|
2541
2590
|
* @see https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
|
2542
2591
|
* @see https://arxiv.org/abs/2401.14112 (FP6-LLM paper)
|
|
@@ -5266,7 +5315,7 @@ struct u4x2_t {
|
|
|
5266
5315
|
constexpr std::strong_ordering operator<=>(u4x2_t const &o) const noexcept = default;
|
|
5267
5316
|
};
|
|
5268
5317
|
|
|
5269
|
-
#pragma region
|
|
5318
|
+
#pragma region Enum Conversion
|
|
5270
5319
|
|
|
5271
5320
|
/**
|
|
5272
5321
|
* @brief Maps `nk_dtype_t` enum values to their corresponding C++ wrapper types.
|
|
@@ -5301,9 +5350,9 @@ template <> struct type_for<nk_i4_k> { using type = i4x2_t; };
|
|
|
5301
5350
|
template <> struct type_for<nk_u4_k> { using type = u4x2_t; };
|
|
5302
5351
|
// clang-format on
|
|
5303
5352
|
|
|
5304
|
-
#pragma endregion
|
|
5353
|
+
#pragma endregion Enum Conversion
|
|
5305
5354
|
|
|
5306
|
-
#pragma region
|
|
5355
|
+
#pragma region Numeric Limits
|
|
5307
5356
|
|
|
5308
5357
|
/** @brief Get the maximum representable value for a type. */
|
|
5309
5358
|
template <typename scalar_type_>
|
|
@@ -5339,6 +5388,16 @@ constexpr unsigned dimensions_per_value() noexcept {
|
|
|
5339
5388
|
return bits_per_value<scalar_type_>() / bits_per_dimension<scalar_type_>();
|
|
5340
5389
|
}
|
|
5341
5390
|
|
|
5391
|
+
/**
|
|
5392
|
+
* @brief The mutable reference type for one logical dimension of a value.
|
|
5393
|
+
*
|
|
5394
|
+
* For normal types (1 dim per value): a plain `value_type_ &`.
|
|
5395
|
+
* For sub-byte packed types: a `sub_byte_ref<value_type_>` proxy.
|
|
5396
|
+
*/
|
|
5397
|
+
template <typename value_type_>
|
|
5398
|
+
using value_ref =
|
|
5399
|
+
std::conditional_t<dimensions_per_value<value_type_>() == 1, value_type_ &, sub_byte_ref<value_type_>>;
|
|
5400
|
+
|
|
5342
5401
|
/**
|
|
5343
5402
|
* @brief Extract the word type from a value type.
|
|
5344
5403
|
*
|
|
@@ -5412,9 +5471,9 @@ constexpr std::size_t round_up_to_multiple(std::size_t n) {
|
|
|
5412
5471
|
return divide_round_up<multiple_>(n) * multiple_;
|
|
5413
5472
|
}
|
|
5414
5473
|
|
|
5415
|
-
#pragma endregion
|
|
5474
|
+
#pragma endregion Numeric Limits
|
|
5416
5475
|
|
|
5417
|
-
#pragma region
|
|
5476
|
+
#pragma region SIMD Dispatch Helpers
|
|
5418
5477
|
|
|
5419
5478
|
/** @brief Controls whether template wrappers dispatch to SIMD C kernels. */
|
|
5420
5479
|
enum allow_simd_t {
|
|
@@ -5553,9 +5612,9 @@ constexpr unsigned count_intersection(u1x8_t a, u1x8_t b) noexcept { return a.in
|
|
|
5553
5612
|
/** @brief Count bit-level union for u1x8_t (8 packed bits). Returns popcount of OR. */
|
|
5554
5613
|
constexpr unsigned count_union(u1x8_t a, u1x8_t b) noexcept { return a.union_size(b); }
|
|
5555
5614
|
|
|
5556
|
-
#pragma endregion
|
|
5615
|
+
#pragma endregion SIMD Dispatch Helpers
|
|
5557
5616
|
|
|
5558
|
-
#pragma region
|
|
5617
|
+
#pragma region F118 Mixed Operators
|
|
5559
5618
|
|
|
5560
5619
|
/** @brief `double + f118_t`: promote the left operand to `f118_t`, then use `f118_t` addition. */
constexpr f118_t operator+(double a, f118_t b) noexcept {
    f118_t promoted(a);
    return promoted + b;
}
/** @brief `double - f118_t`: promote the left operand to `f118_t`, then use `f118_t` subtraction. */
constexpr f118_t operator-(double a, f118_t b) noexcept {
    f118_t promoted(a);
    return promoted - b;
}
|
|
@@ -5569,9 +5628,9 @@ constexpr bool operator>(double a, f118_t b) noexcept { return f118_t(a) > b; }
|
|
|
5569
5628
|
constexpr bool operator<=(double a, f118_t b) noexcept { return f118_t(a) <= b; }
|
|
5570
5629
|
constexpr bool operator>=(double a, f118_t b) noexcept { return f118_t(a) >= b; }
|
|
5571
5630
|
|
|
5572
|
-
#pragma endregion
|
|
5631
|
+
#pragma endregion F118 Mixed Operators
|
|
5573
5632
|
|
|
5574
|
-
#pragma region
|
|
5633
|
+
#pragma region Concepts
|
|
5575
5634
|
|
|
5576
5635
|
template <typename matrix_type_, typename element_type_>
|
|
5577
5636
|
concept const_matrix_of = requires(matrix_type_ const &m) {
|
|
@@ -5596,8 +5655,373 @@ concept packed_matrix_like = requires(packed_type_ const &p) {
|
|
|
5596
5655
|
{ p.depth() } -> std::convertible_to<std::size_t>;
|
|
5597
5656
|
};
|
|
5598
5657
|
|
|
5599
|
-
#pragma endregion
|
|
5658
|
+
#pragma endregion Concepts
|
|
5659
|
+
|
|
5660
|
+
} // namespace ashvardanian::numkong
|
|
5661
|
+
|
|
5662
|
+
#if __has_include(<format>)
|
|
5663
|
+
#include <format>
|
|
5664
|
+
#if defined(__cpp_lib_format) && __cpp_lib_format >= 202110L
|
|
5665
|
+
|
|
5666
|
+
namespace ashvardanian::numkong {
|
|
5667
|
+
|
|
5668
|
+
/**
 * Parsed format spec for NumKong scalar types.
 *
 * Supports the standard format spec grammar `[[fill]align][sign][#][0][width][.precision][type]`:
 * - `{}` — clean float value, composes with C++23 range formatting
 * - `{:#}` — annotated with the raw bits, e.g. an `f16_t` holding 3.140625 prints `3.140625 [0x4248]`
 * - `{:.2f}` — precision/type forwarded to `std::formatter<float>`
 * - `{:x}` / `{:#x}` / `{:X}` / `{:#X}` — hex bits (with optional `0x` / `0X` prefix)
 * - `{:b}` / `{:#b}` / `{:B}` — binary bits (with optional `0b` prefix)
 *
 * In float mode the whole spec, including `#`, is also forwarded to `std::formatter<float>`, where `#` selects the
 * alternate form (a decimal point is always printed, e.g. `1.` instead of `1`).
 * In hex/binary mode only `#` and the type character are honored. Fill, alignment, sign, width, and precision are
 * accepted but ignored.
 * A `#` used as the fill character (e.g. `{:#^12}`) is treated as fill, not as the annotate/prefix flag.
 */
struct scalar_format_spec_t {
    enum class mode_t : unsigned char { float_val_k, hex_k, binary_k };
    mode_t mode_ = mode_t::float_val_k;
    bool annotate_ = false; // '#' in float mode → append [0xHHHH]
    bool prefix_ = false;   // '#' in hex/binary mode → 0x/0b prefix
    bool upper_ = false;    // 'X' vs 'x'
    std::formatter<float> float_fmt_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) {
        auto const first = ctx.begin();
        auto const last = ctx.end();

        // The spec runs up to the closing '}' (or to the end, for contexts without one).
        auto spec_end = first;
        while (spec_end != last && *spec_end != '}') ++spec_end;

        // Empty spec: still route through the float formatter's `parse`, so it is always initialized by it.
        if (spec_end == first) return float_fmt_.parse(ctx);

        // Skip an optional `[fill]align` prefix, so a '#' used as the fill character is not taken as the flag.
        auto const is_align = [](char c) { return c == '<' || c == '>' || c == '^'; };
        auto flags_begin = first;
        if (spec_end - first >= 2 && is_align(first[1])) flags_begin = first + 2;
        else if (is_align(first[0])) flags_begin = first + 1;

        auto const has_hash = [](auto from, auto to) {
            for (; from != to; ++from)
                if (*from == '#') return true;
            return false;
        };

        // The last character of the spec is the presentation type.
        char const type_char = *(spec_end - 1);
        if (type_char == 'x' || type_char == 'X') {
            mode_ = mode_t::hex_k;
            upper_ = (type_char == 'X');
            prefix_ = has_hash(flags_begin, spec_end - 1);
            return spec_end;
        }
        if (type_char == 'b' || type_char == 'B') {
            mode_ = mode_t::binary_k;
            prefix_ = has_hash(flags_begin, spec_end - 1);
            return spec_end;
        }

        // Float mode — '#' additionally means "annotate with hex bits".
        annotate_ = has_hash(flags_begin, spec_end);
        return float_fmt_.parse(ctx);
    }
};
|
|
5725
|
+
|
|
5726
|
+
/**
 * Write `bits` to `out` as zero-padded hexadecimal.
 *
 * `width` of 4 or 1 selects that many digits; any other value falls back to 2 digits.
 * Every path uses a literal format string, so each one is checked at compile time.
 *
 * @param prefix Emit a `0x` prefix (`0X` when `upper` is set).
 * @param upper  Use upper-case hex digits.
 */
inline std::format_context::iterator format_hex_(std::format_context::iterator out, unsigned bits, unsigned width,
                                                 bool prefix, bool upper) {
    switch (width) {
    case 4:
        if (upper) return prefix ? std::format_to(out, "0X{:04X}", bits) : std::format_to(out, "{:04X}", bits);
        return prefix ? std::format_to(out, "0x{:04x}", bits) : std::format_to(out, "{:04x}", bits);
    case 1:
        if (upper) return prefix ? std::format_to(out, "0X{:01X}", bits) : std::format_to(out, "{:01X}", bits);
        return prefix ? std::format_to(out, "0x{:01x}", bits) : std::format_to(out, "{:01x}", bits);
    default:
        if (upper) return prefix ? std::format_to(out, "0X{:02X}", bits) : std::format_to(out, "{:02X}", bits);
        return prefix ? std::format_to(out, "0x{:02x}", bits) : std::format_to(out, "{:02x}", bits);
    }
}
|
|
5746
|
+
|
|
5747
|
+
/**
 * Write `bits` to `out` as zero-padded binary.
 *
 * `width` of 16 or 4 selects that many digits; any other value falls back to 8 digits.
 *
 * @param prefix Emit a lower-case `0b` prefix.
 */
inline std::format_context::iterator format_bin_(std::format_context::iterator out, unsigned bits, unsigned width,
                                                 bool prefix) {
    switch (width) {
    case 16: return prefix ? std::format_to(out, "0b{:016b}", bits) : std::format_to(out, "{:016b}", bits);
    case 4: return prefix ? std::format_to(out, "0b{:04b}", bits) : std::format_to(out, "{:04b}", bits);
    default: return prefix ? std::format_to(out, "0b{:08b}", bits) : std::format_to(out, "{:08b}", bits);
    }
}
|
|
5761
|
+
|
|
5762
|
+
/**
 * Append the raw-bits annotation ` [0x...]` to `out`, using lower-case zero-padded hex digits.
 * `width` of 4 or 1 selects that many digits; any other value falls back to 2 digits.
 */
inline std::format_context::iterator format_hex_suffix_(std::format_context::iterator out, unsigned bits,
                                                        unsigned width) {
    switch (width) {
    case 4: return std::format_to(out, " [0x{:04x}]", bits);
    case 1: return std::format_to(out, " [0x{:01x}]", bits);
    default: return std::format_to(out, " [0x{:02x}]", bits);
    }
}
|
|
5769
|
+
|
|
5770
|
+
/**
 * Formatter implementation for float-like scalar types (f16, bf16, e4m3, e5m2, e2m3, e3m2).
 *
 * Hex and binary modes print the raw storage bits. The bits come from `to_bits()` when the type provides it,
 * otherwise from its `raw_` member.
 * Float mode converts with `to_f32()`, delegates to `std::formatter<float>`, and optionally appends the raw bits.
 *
 * @tparam value_type_ The NumKong scalar type.
 * @tparam hex_width_ Number of hex digits (4 for 16-bit, 2 for 8-bit).
 * @tparam bin_width_ Number of binary digits (16 for 16-bit, 8 for 8-bit).
 */
template <typename value_type_, unsigned hex_width_, unsigned bin_width_>
struct float_scalar_formatter_ {
    scalar_format_spec_t spec_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) { return spec_.parse(ctx); }

    std::format_context::iterator format(value_type_ v, std::format_context &ctx) const {
        using spec_mode_t = scalar_format_spec_t::mode_t;
        unsigned const bits = [&] {
            if constexpr (requires { v.to_bits(); }) return static_cast<unsigned>(v.to_bits());
            else return static_cast<unsigned>(v.raw_);
        }();

        if (spec_.mode_ == spec_mode_t::hex_k)
            return format_hex_(ctx.out(), bits, hex_width_, spec_.prefix_, spec_.upper_);
        if (spec_.mode_ == spec_mode_t::binary_k) return format_bin_(ctx.out(), bits, bin_width_, spec_.prefix_);

        // Float mode: the value first, then the optional ` [0x...]` annotation.
        auto const out = spec_.float_fmt_.format(v.to_f32(), ctx);
        return spec_.annotate_ ? format_hex_suffix_(out, bits, hex_width_) : out;
    }
};
|
|
5600
5799
|
|
|
5601
5800
|
} // namespace ashvardanian::numkong
|
|
5602
5801
|
|
|
5802
|
+
// `std::format` support for NumKong float-like scalars.
// The 16-bit types (f16, bf16) print 4 hex / 16 binary digits.
// The byte-stored types (e4m3, e5m2, e2m3, e3m2) print 2 hex / 8 binary digits.
namespace std {
template <>
struct formatter<ashvardanian::numkong::f16_t>
    : ashvardanian::numkong::float_scalar_formatter_<ashvardanian::numkong::f16_t, 4, 16> {};
template <>
struct formatter<ashvardanian::numkong::bf16_t>
    : ashvardanian::numkong::float_scalar_formatter_<ashvardanian::numkong::bf16_t, 4, 16> {};
template <>
struct formatter<ashvardanian::numkong::e4m3_t>
    : ashvardanian::numkong::float_scalar_formatter_<ashvardanian::numkong::e4m3_t, 2, 8> {};
template <>
struct formatter<ashvardanian::numkong::e5m2_t>
    : ashvardanian::numkong::float_scalar_formatter_<ashvardanian::numkong::e5m2_t, 2, 8> {};
template <>
struct formatter<ashvardanian::numkong::e2m3_t>
    : ashvardanian::numkong::float_scalar_formatter_<ashvardanian::numkong::e2m3_t, 2, 8> {};
template <>
struct formatter<ashvardanian::numkong::e3m2_t>
    : ashvardanian::numkong::float_scalar_formatter_<ashvardanian::numkong::e3m2_t, 2, 8> {};
} // namespace std
|
|
5820
|
+
|
|
5821
|
+
/**
 * @brief Formatter for a packed pair of signed nibbles.
 *
 * Hex and binary modes print the whole packed byte (2 hex / 8 binary digits).
 * The default mode prints `(low, high)` with both halves as `int`, optionally followed
 * by a hex annotation of the packed byte.
 */
template <>
struct std::formatter<ashvardanian::numkong::i4x2_t> {
    ashvardanian::numkong::scalar_format_spec_t spec_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) { return spec_.parse(ctx); }

    std::format_context::iterator format(ashvardanian::numkong::i4x2_t v, std::format_context &ctx) const {
        namespace nk = ashvardanian::numkong;
        using spec_mode_t = nk::scalar_format_spec_t::mode_t;
        const unsigned packed = static_cast<unsigned>(v.raw_);

        if (spec_.mode_ == spec_mode_t::hex_k)
            return nk::format_hex_(ctx.out(), packed, 2, spec_.prefix_, spec_.upper_);
        if (spec_.mode_ == spec_mode_t::binary_k) return nk::format_bin_(ctx.out(), packed, 8, spec_.prefix_);

        const int lo = static_cast<int>(v.low().raw());
        const int hi = static_cast<int>(v.high().raw());
        auto out = std::format_to(ctx.out(), "({}, {})", lo, hi);
        if (spec_.annotate_) out = nk::format_hex_suffix_(out, packed, 2);
        return out;
    }
};
|
|
5843
|
+
|
|
5844
|
+
/**
 * @brief Formatter for a packed pair of unsigned nibbles.
 *
 * Hex and binary modes print the whole packed byte (2 hex / 8 binary digits).
 * The default mode prints `(low, high)` with both halves as `unsigned`, optionally followed
 * by a hex annotation of the packed byte.
 */
template <>
struct std::formatter<ashvardanian::numkong::u4x2_t> {
    ashvardanian::numkong::scalar_format_spec_t spec_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) { return spec_.parse(ctx); }

    std::format_context::iterator format(ashvardanian::numkong::u4x2_t v, std::format_context &ctx) const {
        namespace nk = ashvardanian::numkong;
        using spec_mode_t = nk::scalar_format_spec_t::mode_t;
        const unsigned packed = static_cast<unsigned>(v.raw_);

        if (spec_.mode_ == spec_mode_t::hex_k)
            return nk::format_hex_(ctx.out(), packed, 2, spec_.prefix_, spec_.upper_);
        if (spec_.mode_ == spec_mode_t::binary_k) return nk::format_bin_(ctx.out(), packed, 8, spec_.prefix_);

        const unsigned lo = static_cast<unsigned>(v.low().raw());
        const unsigned hi = static_cast<unsigned>(v.high().raw());
        auto out = std::format_to(ctx.out(), "({}, {})", lo, hi);
        if (spec_.annotate_) out = nk::format_hex_suffix_(out, packed, 2);
        return out;
    }
};
|
|
5866
|
+
|
|
5867
|
+
/**
 * @brief Formatter for a packed byte of eight bits.
 *
 * Hex and binary modes print the byte using the parsed prefix / case flags.
 * The default mode always prints an 8-digit `0b`-prefixed binary literal, optionally followed
 * by a hex annotation of the same byte.
 */
template <>
struct std::formatter<ashvardanian::numkong::u1x8_t> {
    ashvardanian::numkong::scalar_format_spec_t spec_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) { return spec_.parse(ctx); }

    std::format_context::iterator format(ashvardanian::numkong::u1x8_t v, std::format_context &ctx) const {
        namespace nk = ashvardanian::numkong;
        using spec_mode_t = nk::scalar_format_spec_t::mode_t;
        const unsigned packed = static_cast<unsigned>(v.raw_);

        if (spec_.mode_ == spec_mode_t::hex_k)
            return nk::format_hex_(ctx.out(), packed, 2, spec_.prefix_, spec_.upper_);
        if (spec_.mode_ == spec_mode_t::binary_k) return nk::format_bin_(ctx.out(), packed, 8, spec_.prefix_);

        auto out = std::format_to(ctx.out(), "0b{:08b}", packed);
        if (spec_.annotate_) out = nk::format_hex_suffix_(out, packed, 2);
        return out;
    }
};
|
|
5888
|
+
|
|
5889
|
+
/**
 * @brief Formatter for a single signed nibble (-8..7). Supports `{}`, `{:#}`, `{:x}`, `{:b}`.
 *        Float-precision specs (e.g. `{:.2f}`) are not meaningful and ignored.
 *
 * Hex / binary output shows only the low 4 bits (1 hex / 4 binary digits), so negative values
 * print their two's-complement nibble rather than a sign-extended word.
 */
template <>
struct std::formatter<ashvardanian::numkong::sub_byte_ref<ashvardanian::numkong::i4x2_t>> {
    ashvardanian::numkong::scalar_format_spec_t spec_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) { return spec_.parse(ctx); }

    std::format_context::iterator format(ashvardanian::numkong::sub_byte_ref<ashvardanian::numkong::i4x2_t> v,
                                         std::format_context &ctx) const {
        namespace nk = ashvardanian::numkong;
        using spec_mode_t = nk::scalar_format_spec_t::mode_t;
        // Mask away any sign extension so hex/binary views show exactly the stored nibble.
        const unsigned nibble = static_cast<unsigned>(v.get()) & 0x0Fu;

        if (spec_.mode_ == spec_mode_t::hex_k)
            return nk::format_hex_(ctx.out(), nibble, 1, spec_.prefix_, spec_.upper_);
        if (spec_.mode_ == spec_mode_t::binary_k) return nk::format_bin_(ctx.out(), nibble, 4, spec_.prefix_);

        const int value = static_cast<int>(v.get());
        auto out = std::format_to(ctx.out(), "{}", value);
        if (spec_.annotate_) out = nk::format_hex_suffix_(out, nibble, 1);
        return out;
    }
};
|
|
5915
|
+
|
|
5916
|
+
/**
 * @brief Formatter for a single unsigned nibble (0..15). Supports `{}`, `{:#}`, `{:x}`, `{:b}`.
 *        Float-precision specs (e.g. `{:.2f}`) are not meaningful and ignored.
 *
 * Hex / binary output uses 1 hex / 4 binary digits; the default mode prints the decimal value.
 */
template <>
struct std::formatter<ashvardanian::numkong::sub_byte_ref<ashvardanian::numkong::u4x2_t>> {
    ashvardanian::numkong::scalar_format_spec_t spec_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) { return spec_.parse(ctx); }

    std::format_context::iterator format(ashvardanian::numkong::sub_byte_ref<ashvardanian::numkong::u4x2_t> v,
                                         std::format_context &ctx) const {
        namespace nk = ashvardanian::numkong;
        using spec_mode_t = nk::scalar_format_spec_t::mode_t;
        const unsigned nibble = static_cast<unsigned>(v.get());

        if (spec_.mode_ == spec_mode_t::hex_k)
            return nk::format_hex_(ctx.out(), nibble, 1, spec_.prefix_, spec_.upper_);
        if (spec_.mode_ == spec_mode_t::binary_k) return nk::format_bin_(ctx.out(), nibble, 4, spec_.prefix_);

        const unsigned value = static_cast<unsigned>(v.get());
        auto out = std::format_to(ctx.out(), "{}", value);
        if (spec_.annotate_) out = nk::format_hex_suffix_(out, nibble, 1);
        return out;
    }
};
|
|
5942
|
+
|
|
5943
|
+
/**
 * @brief Formatter for a single bit. Only `{}` is supported — hex and binary are not meaningful.
 *
 * Prints `1` when the referenced bit is set and `0` otherwise. Any non-empty format spec
 * (e.g. `{:x}`) is rejected with a descriptive `std::format_error`. Without this check the spec
 * would be left unconsumed and the library would report a generic / misleading parse error.
 */
template <>
struct std::formatter<ashvardanian::numkong::sub_byte_ref<ashvardanian::numkong::u1x8_t>> {
    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) {
        auto it = ctx.begin();
        // An empty spec leaves the iterator on the closing brace (or at the end of the spec range).
        if (it != ctx.end() && *it != '}')
            throw std::format_error("numkong: single-bit references only support the empty `{}` format spec");
        return it;
    }

    std::format_context::iterator format(ashvardanian::numkong::sub_byte_ref<ashvardanian::numkong::u1x8_t> v,
                                         std::format_context &ctx) const {
        return std::format_to(ctx.out(), "{}", v.get() ? 1u : 0u);
    }
};
|
|
5953
|
+
|
|
5954
|
+
/**
 * @brief Formatter for complex half-precision values, printed as `(real, imag)`.
 *
 * Hex / binary modes print each component's raw bit pattern (4 hex / 16 binary digits).
 * The default mode renders each component through the parsed float spec, so e.g. `{:.2f}`
 * applies to both parts, consistent with the scalar formatters. Previously the components were
 * printed with a bare `{}` and any precision/width spec was silently dropped. When annotation is
 * requested, each component is followed by ` [0x....]` with its raw bits.
 */
template <>
struct std::formatter<ashvardanian::numkong::f16c_t> {
    ashvardanian::numkong::scalar_format_spec_t spec_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) { return spec_.parse(ctx); }

    std::format_context::iterator format(ashvardanian::numkong::f16c_t v, std::format_context &ctx) const {
        namespace nk = ashvardanian::numkong;
        using mode_t = nk::scalar_format_spec_t::mode_t;
        unsigned re_bits = static_cast<unsigned>(v.real().to_bits());
        unsigned im_bits = static_cast<unsigned>(v.imag().to_bits());
        switch (spec_.mode_) {
        case mode_t::hex_k: {
            auto out = std::format_to(ctx.out(), "(");
            out = nk::format_hex_(out, re_bits, 4, spec_.prefix_, spec_.upper_);
            out = std::format_to(out, ", ");
            out = nk::format_hex_(out, im_bits, 4, spec_.prefix_, spec_.upper_);
            return std::format_to(out, ")");
        }
        case mode_t::binary_k: {
            auto out = std::format_to(ctx.out(), "(");
            out = nk::format_bin_(out, re_bits, 16, spec_.prefix_);
            out = std::format_to(out, ", ");
            out = nk::format_bin_(out, im_bits, 16, spec_.prefix_);
            return std::format_to(out, ")");
        }
        default: {
            auto out = std::format_to(ctx.out(), "(");
            out = format_component_(v.real().to_f32(), re_bits, out, ctx);
            out = std::format_to(out, ", ");
            out = format_component_(v.imag().to_f32(), im_bits, out, ctx);
            return std::format_to(out, ")");
        }
        }
    }

  private:
    /// Writes one component at `out` using the parsed float spec, optionally followed by ` [0x....]`.
    template <typename component_type_>
    std::format_context::iterator format_component_(component_type_ value, unsigned bits,
                                                    std::format_context::iterator out,
                                                    std::format_context &ctx) const {
        // The delegated float formatter writes through `ctx.out()`, so sync it with our position first.
        ctx.advance_to(out);
        out = spec_.float_fmt_.format(value, ctx);
        if (spec_.annotate_) out = std::format_to(out, " [0x{:04x}]", bits);
        return out;
    }
};
|
|
5988
|
+
|
|
5989
|
+
/**
 * @brief Formatter for complex brain-float values, printed as `(real, imag)`.
 *
 * Hex / binary modes print each component's raw bit pattern (4 hex / 16 binary digits).
 * The default mode renders each component through the parsed float spec, so e.g. `{:.2f}`
 * applies to both parts, consistent with the scalar formatters. Previously the components were
 * printed with a bare `{}` and any precision/width spec was silently dropped. When annotation is
 * requested, each component is followed by ` [0x....]` with its raw bits.
 */
template <>
struct std::formatter<ashvardanian::numkong::bf16c_t> {
    ashvardanian::numkong::scalar_format_spec_t spec_;

    constexpr std::format_parse_context::iterator parse(std::format_parse_context &ctx) { return spec_.parse(ctx); }

    std::format_context::iterator format(ashvardanian::numkong::bf16c_t v, std::format_context &ctx) const {
        namespace nk = ashvardanian::numkong;
        using mode_t = nk::scalar_format_spec_t::mode_t;
        unsigned re_bits = static_cast<unsigned>(v.real().to_bits());
        unsigned im_bits = static_cast<unsigned>(v.imag().to_bits());
        switch (spec_.mode_) {
        case mode_t::hex_k: {
            auto out = std::format_to(ctx.out(), "(");
            out = nk::format_hex_(out, re_bits, 4, spec_.prefix_, spec_.upper_);
            out = std::format_to(out, ", ");
            out = nk::format_hex_(out, im_bits, 4, spec_.prefix_, spec_.upper_);
            return std::format_to(out, ")");
        }
        case mode_t::binary_k: {
            auto out = std::format_to(ctx.out(), "(");
            out = nk::format_bin_(out, re_bits, 16, spec_.prefix_);
            out = std::format_to(out, ", ");
            out = nk::format_bin_(out, im_bits, 16, spec_.prefix_);
            return std::format_to(out, ")");
        }
        default: {
            auto out = std::format_to(ctx.out(), "(");
            out = format_component_(v.real().to_f32(), re_bits, out, ctx);
            out = std::format_to(out, ", ");
            out = format_component_(v.imag().to_f32(), im_bits, out, ctx);
            return std::format_to(out, ")");
        }
        }
    }

  private:
    /// Writes one component at `out` using the parsed float spec, optionally followed by ` [0x....]`.
    template <typename component_type_>
    std::format_context::iterator format_component_(component_type_ value, unsigned bits,
                                                    std::format_context::iterator out,
                                                    std::format_context &ctx) const {
        // The delegated float formatter writes through `ctx.out()`, so sync it with our position first.
        ctx.advance_to(out);
        out = spec_.float_fmt_.format(value, ctx);
        if (spec_.annotate_) out = std::format_to(out, " [0x{:04x}]", bits);
        return out;
    }
};
|
|
6023
|
+
|
|
6024
|
+
#endif // __cpp_lib_format
|
|
6025
|
+
#endif // __has_include(<format>)
|
|
6026
|
+
|
|
5603
6027
|
#endif // NK_TYPES_HPP
|