numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -6,17 +6,17 @@ All conversions use round-to-nearest-even (RNE) for narrowing and exact widening
|
|
|
6
6
|
|
|
7
7
|
BFloat16 relates to Float32 by truncation with rounding:
|
|
8
8
|
|
|
9
|
-
|
|
9
|
+
$$
|
|
10
10
|
\text{bf16} \approx \text{f32} \gg 16
|
|
11
|
-
|
|
11
|
+
$$
|
|
12
12
|
|
|
13
13
|
With RNE tie-breaking to preserve the least significant bit of the truncated result.
|
|
14
14
|
|
|
15
15
|
Float16 range and precision:
|
|
16
16
|
|
|
17
|
-
|
|
17
|
+
$$
|
|
18
18
|
\text{f16} \in [-65504, 65504], \quad \text{min positive normal} = 2^{-14}
|
|
19
|
-
|
|
19
|
+
$$
|
|
20
20
|
|
|
21
21
|
Reformulating as Python pseudocode:
|
|
22
22
|
|
|
@@ -194,69 +194,69 @@ Measured with Wasmtime v42 (Cranelift backend).
|
|
|
194
194
|
| __f32 ↔ e4m3__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░ |
|
|
195
195
|
| `nk_cast_serial` | ? gb/s | ? gb/s | 0.239 gb/s | ? gb/s | ? gb/s | 0.746 gb/s |
|
|
196
196
|
|
|
197
|
-
### Apple
|
|
197
|
+
### Apple M5
|
|
198
198
|
|
|
199
199
|
#### Native
|
|
200
200
|
|
|
201
201
|
| Kernel | ↓ 256 | ↓ 1K | ↓ 4K | ↑ 256 | ↑ 1K | ↑ 4K |
|
|
202
202
|
| :--------------- | -----------: | -----------: | -----------: | -----------: | -----------: | -----------: |
|
|
203
203
|
| __f32 ↔ bf16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
204
|
-
| `nk_cast_serial` |
|
|
205
|
-
| `nk_cast_neon` |
|
|
204
|
+
| `nk_cast_serial` | 1.37 gb/s | 1.35 gb/s | 1.41 gb/s | 1.37 gb/s | 1.34 gb/s | 1.38 gb/s |
|
|
205
|
+
| `nk_cast_neon` | 19.3 gb/s | 23.7 gb/s | 23.2 gb/s | 59.4 gb/s | 58.9 gb/s | 57.3 gb/s |
|
|
206
206
|
| __f32 ↔ f16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
207
|
-
| `nk_cast_serial` |
|
|
208
|
-
| `nk_cast_neon` |
|
|
207
|
+
| `nk_cast_serial` | 1.37 gb/s | 1.31 gb/s | 1.32 gb/s | 1.37 gb/s | 1.31 gb/s | 1.40 gb/s |
|
|
208
|
+
| `nk_cast_neon` | 20.1 gb/s | 21.9 gb/s | 25.0 gb/s | 52.1 gb/s | 60.2 gb/s | 70.2 gb/s |
|
|
209
209
|
| __f32 ↔ e5m2__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
210
|
-
| `nk_cast_serial` |
|
|
211
|
-
| `nk_cast_neon` |
|
|
210
|
+
| `nk_cast_serial` | 0.681 gb/s | 0.621 gb/s | 0.600 gb/s | 1.17 gb/s | 1.17 gb/s | 1.23 gb/s |
|
|
211
|
+
| `nk_cast_neon` | 8.50 gb/s | 8.45 gb/s | 8.35 gb/s | 40.6 gb/s | 46.5 gb/s | 46.5 gb/s |
|
|
212
212
|
| __f32 ↔ e4m3__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
213
|
-
| `nk_cast_serial` |
|
|
214
|
-
| `nk_cast_neon` |
|
|
213
|
+
| `nk_cast_serial` | 0.683 gb/s | 0.618 gb/s | 0.586 gb/s | 1.02 gb/s | 1.01 gb/s | 1.02 gb/s |
|
|
214
|
+
| `nk_cast_neon` | 7.85 gb/s | 7.91 gb/s | 7.66 gb/s | 18.9 gb/s | 19.2 gb/s | 18.3 gb/s |
|
|
215
215
|
| __f32 ↔ e3m2__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
216
|
-
| `nk_cast_serial` |
|
|
217
|
-
| `nk_cast_neon` |
|
|
216
|
+
| `nk_cast_serial` | 0.702 gb/s | 0.632 gb/s | 0.596 gb/s | 1.17 gb/s | 1.13 gb/s | 1.15 gb/s |
|
|
217
|
+
| `nk_cast_neon` | 8.94 gb/s | 9.02 gb/s | 8.91 gb/s | 24.9 gb/s | 25.0 gb/s | 24.4 gb/s |
|
|
218
218
|
| __f32 ↔ e2m3__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
219
|
-
| `nk_cast_serial` |
|
|
220
|
-
| `nk_cast_neon` |
|
|
219
|
+
| `nk_cast_serial` | 0.921 gb/s | 0.843 gb/s | 0.715 gb/s | 1.21 gb/s | 1.21 gb/s | 1.26 gb/s |
|
|
220
|
+
| `nk_cast_neon` | 8.89 gb/s | 9.03 gb/s | 8.82 gb/s | 24.9 gb/s | 25.1 gb/s | 24.6 gb/s |
|
|
221
221
|
| __f32 ↔ i16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
222
|
-
| `nk_cast_serial` |
|
|
223
|
-
| `nk_cast_neon` |
|
|
222
|
+
| `nk_cast_serial` | 0.785 gb/s | 0.679 gb/s | 0.678 gb/s | 1.44 gb/s | 1.39 gb/s | 1.49 gb/s |
|
|
223
|
+
| `nk_cast_neon` | 19.4 gb/s | 22.6 gb/s | 23.9 gb/s | 19.9 gb/s | 23.2 gb/s | 25.9 gb/s |
|
|
224
224
|
| __f32 ↔ u16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
225
|
-
| `nk_cast_serial` |
|
|
226
|
-
| `nk_cast_neon` |
|
|
225
|
+
| `nk_cast_serial` | 0.916 gb/s | 0.822 gb/s | 0.726 gb/s | 1.37 gb/s | 1.36 gb/s | 1.48 gb/s |
|
|
226
|
+
| `nk_cast_neon` | 20.3 gb/s | 20.6 gb/s | 22.1 gb/s | 15.6 gb/s | 18.5 gb/s | 17.4 gb/s |
|
|
227
227
|
| __f32 ↔ i8__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
228
|
-
| `nk_cast_serial` |
|
|
229
|
-
| `nk_cast_neon` |
|
|
228
|
+
| `nk_cast_serial` | 0.725 gb/s | 0.616 gb/s | 0.578 gb/s | 1.21 gb/s | 1.21 gb/s | 1.28 gb/s |
|
|
229
|
+
| `nk_cast_neon` | 18.2 gb/s | 24.5 gb/s | 21.7 gb/s | 16.3 gb/s | 18.9 gb/s | 19.8 gb/s |
|
|
230
230
|
| __f32 ↔ u8__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
231
|
-
| `nk_cast_serial` |
|
|
232
|
-
| `nk_cast_neon` |
|
|
231
|
+
| `nk_cast_serial` | 0.967 gb/s | 0.795 gb/s | 0.723 gb/s | 1.29 gb/s | 1.25 gb/s | 1.40 gb/s |
|
|
232
|
+
| `nk_cast_neon` | 17.5 gb/s | 19.8 gb/s | 19.4 gb/s | 13.8 gb/s | 17.8 gb/s | 15.1 gb/s |
|
|
233
233
|
| __f64 ↔ f32__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
234
|
-
| `nk_cast_serial` |
|
|
235
|
-
| `nk_cast_neon` |
|
|
234
|
+
| `nk_cast_serial` | 2.65 gb/s | 2.60 gb/s | 2.70 gb/s | 2.59 gb/s | 2.55 gb/s | 2.65 gb/s |
|
|
235
|
+
| `nk_cast_neon` | 2.87 gb/s | 2.60 gb/s | 2.73 gb/s | 2.64 gb/s | 2.63 gb/s | 2.57 gb/s |
|
|
236
236
|
| __f64 ↔ i64__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
237
|
-
| `nk_cast_serial` |
|
|
238
|
-
| `nk_cast_neon` |
|
|
237
|
+
| `nk_cast_serial` | 2.42 gb/s | 2.00 gb/s | 1.86 gb/s | 3.79 gb/s | 3.61 gb/s | 4.03 gb/s |
|
|
238
|
+
| `nk_cast_neon` | 2.51 gb/s | 1.94 gb/s | 1.78 gb/s | 3.83 gb/s | 3.68 gb/s | 3.79 gb/s |
|
|
239
239
|
| __f64 ↔ u64__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
240
|
-
| `nk_cast_serial` |
|
|
241
|
-
| `nk_cast_neon` |
|
|
240
|
+
| `nk_cast_serial` | 2.56 gb/s | 2.19 gb/s | 2.06 gb/s | 3.71 gb/s | 3.50 gb/s | 3.87 gb/s |
|
|
241
|
+
| `nk_cast_neon` | 2.68 gb/s | 2.10 gb/s | 1.97 gb/s | 3.68 gb/s | 3.61 gb/s | 3.58 gb/s |
|
|
242
242
|
| __f64 ↔ i32__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
243
|
-
| `nk_cast_serial` |
|
|
244
|
-
| `nk_cast_neon` |
|
|
243
|
+
| `nk_cast_serial` | 1.58 gb/s | 1.32 gb/s | 1.29 gb/s | 2.65 gb/s | 2.58 gb/s | 2.84 gb/s |
|
|
244
|
+
| `nk_cast_neon` | 1.61 gb/s | 1.33 gb/s | 1.24 gb/s | 2.73 gb/s | 2.63 gb/s | 2.66 gb/s |
|
|
245
245
|
| __f64 ↔ u32__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
246
|
-
| `nk_cast_serial` |
|
|
247
|
-
| `nk_cast_neon` |
|
|
246
|
+
| `nk_cast_serial` | 1.83 gb/s | 1.53 gb/s | 1.47 gb/s | 2.55 gb/s | 2.48 gb/s | 2.69 gb/s |
|
|
247
|
+
| `nk_cast_neon` | 1.89 gb/s | 1.53 gb/s | 1.38 gb/s | 2.56 gb/s | 2.54 gb/s | 2.59 gb/s |
|
|
248
248
|
|
|
249
249
|
#### WASM
|
|
250
250
|
|
|
251
|
-
Measured with Wasmtime
|
|
251
|
+
Measured with Wasmtime v43 (Cranelift backend).
|
|
252
252
|
|
|
253
253
|
| Kernel | ↓ 256 | ↓ 1K | ↓ 4K | ↑ 256 | ↑ 1K | ↑ 4K |
|
|
254
254
|
| :--------------- | -----------: | -----------: | -----------: | -----------: | -----------: | -----------: |
|
|
255
255
|
| __f32 ↔ bf16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
256
|
-
| `nk_cast_serial` |
|
|
256
|
+
| `nk_cast_serial` | 0.514 gb/s | 0.522 gb/s | 0.538 gb/s | 0.511 gb/s | 0.526 gb/s | 0.519 gb/s |
|
|
257
257
|
| __f32 ↔ f16__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
258
|
-
| `nk_cast_serial` | 0.
|
|
258
|
+
| `nk_cast_serial` | 0.368 gb/s | 0.363 gb/s | 0.360 gb/s | 0.490 gb/s | 0.480 gb/s | 0.489 gb/s |
|
|
259
259
|
| __f32 ↔ e5m2__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
260
|
-
| `nk_cast_serial` | 0.
|
|
260
|
+
| `nk_cast_serial` | 0.323 gb/s | 0.312 gb/s | 0.304 gb/s | 0.423 gb/s | 0.425 gb/s | 0.425 gb/s |
|
|
261
261
|
| __f32 ↔ e4m3__ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ | ░░░░░░░░░░░░ |
|
|
262
|
-
| `nk_cast_serial` | 0.
|
|
262
|
+
| `nk_cast_serial` | 0.315 gb/s | 0.304 gb/s | 0.295 gb/s | 0.396 gb/s | 0.396 gb/s | 0.397 gb/s |
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Type Conversions for Diamond Rapids.
|
|
3
|
+
* @file include/numkong/cast/diamond.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 23, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/cast/icelake.h
|
|
8
|
+
*
|
|
9
|
+
* Uses VCVTHF82PH (E4M3→FP16) and VCVTBF82PH (E5M2→FP16) for native 1-instruction
|
|
10
|
+
* FP8→FP16 conversion. Both conversions are exact (no rounding needed).
|
|
11
|
+
*/
|
|
12
|
+
#ifndef NK_CAST_DIAMOND_H
|
|
13
|
+
#define NK_CAST_DIAMOND_H
|
|
14
|
+
|
|
15
|
+
#if NK_TARGET_X86_
|
|
16
|
+
#if NK_TARGET_DIAMOND
|
|
17
|
+
|
|
18
|
+
#include "numkong/types.h"
|
|
19
|
+
|
|
20
|
+
#if defined(__cplusplus)
|
|
21
|
+
extern "C" {
|
|
22
|
+
#endif
|
|
23
|
+
|
|
24
|
+
#if defined(__clang__)
|
|
25
|
+
#pragma clang attribute push( \
|
|
26
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx10.2-512,f16c,fma,bmi,bmi2"))), \
|
|
27
|
+
apply_to = function)
|
|
28
|
+
#elif defined(__GNUC__)
|
|
29
|
+
#pragma GCC push_options
|
|
30
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx10.2-512", "f16c", "fma", \
|
|
31
|
+
"bmi", "bmi2")
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
NK_INTERNAL void nk_load_e4m3x32_to_f16x32_diamond_(nk_e4m3_t const *src, nk_b512_vec_t *dst) {
|
|
35
|
+
dst->zmm_ph = _mm512_cvthf8_ph(_mm256_loadu_epi8(src));
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
NK_INTERNAL void nk_partial_load_e4m3x32_to_f16x32_diamond_(nk_e4m3_t const *src, nk_b512_vec_t *dst, nk_size_t count) {
|
|
39
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count);
|
|
40
|
+
dst->zmm_ph = _mm512_cvthf8_ph(_mm256_maskz_loadu_epi8(mask, src));
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
NK_INTERNAL void nk_load_e5m2x32_to_f16x32_diamond_(nk_e5m2_t const *src, nk_b512_vec_t *dst) {
|
|
44
|
+
dst->zmm_ph = _mm512_cvtbf8_ph(_mm256_loadu_epi8(src));
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
NK_INTERNAL void nk_partial_load_e5m2x32_to_f16x32_diamond_(nk_e5m2_t const *src, nk_b512_vec_t *dst, nk_size_t count) {
|
|
48
|
+
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, count);
|
|
49
|
+
dst->zmm_ph = _mm512_cvtbf8_ph(_mm256_maskz_loadu_epi8(mask, src));
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
#if defined(__clang__)
|
|
53
|
+
#pragma clang attribute pop
|
|
54
|
+
#elif defined(__GNUC__)
|
|
55
|
+
#pragma GCC pop_options
|
|
56
|
+
#endif
|
|
57
|
+
|
|
58
|
+
#if defined(__cplusplus)
|
|
59
|
+
} // extern "C"
|
|
60
|
+
#endif
|
|
61
|
+
|
|
62
|
+
#endif // NK_TARGET_DIAMOND
|
|
63
|
+
#endif // NK_TARGET_X86_
|
|
64
|
+
#endif // NK_CAST_DIAMOND_H
|
|
@@ -6,12 +6,12 @@
|
|
|
6
6
|
*
|
|
7
7
|
* @section haswell_cast_instructions Key F16C/AVX2 Conversion Instructions
|
|
8
8
|
*
|
|
9
|
-
* Intrinsic
|
|
10
|
-
* _mm256_cvtph_ps
|
|
11
|
-
* _mm256_cvtps_ph
|
|
12
|
-
* _mm256_cvtepi16_epi32
|
|
13
|
-
* _mm256_slli_epi32
|
|
14
|
-
* _mm256_blendv_ps
|
|
9
|
+
* Intrinsic Instruction Haswell Genoa
|
|
10
|
+
* _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy @ p01 4cy @ p12+p23
|
|
11
|
+
* _mm256_cvtps_ph VCVTPS2PH (XMM, YMM, I8) 5cy @ p01 4cy @ p12+p23
|
|
12
|
+
* _mm256_cvtepi16_epi32 VPMOVSXWD (YMM, XMM) 1cy @ p5 2cy @ p12
|
|
13
|
+
* _mm256_slli_epi32 VPSLLD (YMM, YMM, I8) 1cy @ p0 1cy @ p23
|
|
14
|
+
* _mm256_blendv_ps VBLENDVPS (YMM, YMM, YMM, YMM) 2cy @ p015 1cy @ p01
|
|
15
15
|
*
|
|
16
16
|
* F16C provides hardware F16<->F32 conversion. BF16 lacks hardware support and is emulated via
|
|
17
17
|
* bit manipulation (shift upper 16 bits). FP8 formats (E4M3/E5M2) use lookup tables for subnormal
|
|
@@ -38,14 +38,14 @@ extern "C" {
|
|
|
38
38
|
#endif
|
|
39
39
|
|
|
40
40
|
NK_PUBLIC void nk_f32_to_f16_haswell(nk_f32_t const *from, nk_f16_t *to) {
|
|
41
|
-
*to = _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(*from), _MM_FROUND_TO_NEAREST_INT));
|
|
41
|
+
*(nk_u16_t *)to = (nk_u16_t)_mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(*from), _MM_FROUND_TO_NEAREST_INT));
|
|
42
42
|
}
|
|
43
43
|
|
|
44
44
|
NK_PUBLIC void nk_f16_to_f32_haswell(nk_f16_t const *from, nk_f32_t *to) {
|
|
45
|
-
*to = _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(*from)));
|
|
45
|
+
*to = _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(*(nk_u16_t const *)from)));
|
|
46
46
|
}
|
|
47
47
|
|
|
48
|
-
#pragma region
|
|
48
|
+
#pragma region Type Punned Loads and Stores
|
|
49
49
|
|
|
50
50
|
/** @brief Type-agnostic 256-bit full load (Haswell AVX2). */
|
|
51
51
|
NK_INTERNAL void nk_load_b256_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
@@ -99,9 +99,9 @@ NK_INTERNAL void nk_partial_store_b64x4_haswell_(nk_b256_vec_t const *src, void
|
|
|
99
99
|
_mm256_maskstore_pd((double *)dst, mask_i64x4, _mm256_castsi256_pd(src->ymm));
|
|
100
100
|
}
|
|
101
101
|
|
|
102
|
-
#pragma endregion
|
|
102
|
+
#pragma endregion Type Punned Loads and Stores
|
|
103
103
|
|
|
104
|
-
#pragma region
|
|
104
|
+
#pragma region Vectorized Conversions
|
|
105
105
|
|
|
106
106
|
/** @brief Convert 8x bf16 → 8x f32 by shifting left 16 bits (AVX2). */
|
|
107
107
|
NK_INTERNAL __m256 nk_bf16x8_to_f32x8_haswell_(__m128i bf16_i16x8) {
|
|
@@ -116,9 +116,9 @@ NK_INTERNAL __m128i nk_f32x8_to_bf16x8_haswell_(__m256 f32x8) {
|
|
|
116
116
|
__m256i rounded_i32x8 = _mm256_add_epi32(bits_i32x8, _mm256_add_epi32(_mm256_set1_epi32(0x7FFF), lsb_i32x8));
|
|
117
117
|
__m256i bf16_i32x8 = _mm256_srli_epi32(rounded_i32x8, 16);
|
|
118
118
|
// Pack 8x i32 to 8x i16
|
|
119
|
-
__m128i
|
|
120
|
-
__m128i
|
|
121
|
-
return _mm_packus_epi32(
|
|
119
|
+
__m128i low_i32x4 = _mm256_castsi256_si128(bf16_i32x8);
|
|
120
|
+
__m128i high_i32x4 = _mm256_extracti128_si256(bf16_i32x8, 1);
|
|
121
|
+
return _mm_packus_epi32(low_i32x4, high_i32x4);
|
|
122
122
|
}
|
|
123
123
|
|
|
124
124
|
/** @brief Integer upcasts to f32x8 (AVX2). */
|
|
@@ -132,10 +132,10 @@ NK_INTERNAL __m256 nk_u16x8_to_f32x8_haswell_(__m128i u16x8) {
|
|
|
132
132
|
}
|
|
133
133
|
NK_INTERNAL __m256 nk_i32x8_to_f32x8_haswell_(__m256i i32x8) { return _mm256_cvtepi32_ps(i32x8); }
|
|
134
134
|
NK_INTERNAL __m256 nk_u32x8_to_f32x8_haswell_(__m256i u32x8) {
|
|
135
|
-
__m256i
|
|
136
|
-
__m256i
|
|
137
|
-
return _mm256_add_ps(_mm256_cvtepi32_ps(
|
|
138
|
-
_mm256_mul_ps(_mm256_cvtepi32_ps(
|
|
135
|
+
__m256i low_i32x8 = _mm256_and_si256(u32x8, _mm256_set1_epi32(0xFFFF));
|
|
136
|
+
__m256i high_i32x8 = _mm256_srli_epi32(u32x8, 16);
|
|
137
|
+
return _mm256_add_ps(_mm256_cvtepi32_ps(low_i32x8),
|
|
138
|
+
_mm256_mul_ps(_mm256_cvtepi32_ps(high_i32x8), _mm256_set1_ps(65536.0f)));
|
|
139
139
|
}
|
|
140
140
|
|
|
141
141
|
/** @brief Saturating f32x8 downcasts to integers (AVX2). */
|
|
@@ -172,167 +172,10 @@ NK_INTERNAL __m128i nk_f32x8_to_u8x8_haswell_(__m256 f32x8) {
|
|
|
172
172
|
return _mm_packus_epi16(u16x8, _mm_setzero_si128());
|
|
173
173
|
}
|
|
174
174
|
|
|
175
|
-
/** @brief Convert 16x e4m3 → 16x bf16 via arithmetic + small LUT for subnormals (AVX2).
|
|
176
|
-
* E4M3 format: S EEEE MMM (bias=7). BF16: S EEEEEEEE MMMMMMM (bias=127).
|
|
177
|
-
* Normal values: BF16 = sign | ((lower7 << 4) + 0x3C00).
|
|
178
|
-
* Subnormals (8 values): looked up via vpshufb from an 8-entry LUT.
|
|
179
|
-
* Handles all corner cases: zero, subnormals, normals, and NaN. */
|
|
180
|
-
NK_INTERNAL __m256i nk_e4m3x16_to_bf16x16_haswell_(__m128i e4m3x16) {
|
|
181
|
-
__m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3x16);
|
|
182
|
-
__m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
|
|
183
|
-
__m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
|
|
184
|
-
|
|
185
|
-
// Normal path: BF16 = ((lower7 << 4) + 0x3C00) | (sign << 8)
|
|
186
|
-
__m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 4), _mm256_set1_epi16(0x3C00));
|
|
187
|
-
sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
|
|
188
|
-
__m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
|
|
189
|
-
|
|
190
|
-
// Subnormal LUT via shuffle_epi8 (8 entries: mantissa 0-7 → BF16)
|
|
191
|
-
// E4M3 subnormal BF16 values: 0x0000, 0x3B00, 0x3B80, 0x3BC0, 0x3C00, 0x3C20, 0x3C40, 0x3C60
|
|
192
|
-
// Split into low bytes and high bytes for reconstruction
|
|
193
|
-
__m256i const lo_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
|
|
194
|
-
0x60, 0x40, 0x20, 0x00, (char)0xC0, (char)0x80, 0x00, 0x00, //
|
|
195
|
-
0x60, 0x40, 0x20, 0x00, (char)0xC0, (char)0x80, 0x00, 0x00)); //
|
|
196
|
-
__m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
|
|
197
|
-
0x3C, 0x3C, 0x3C, 0x3C, 0x3B, 0x3B, 0x3B, 0x00, //
|
|
198
|
-
0x3C, 0x3C, 0x3C, 0x3C, 0x3B, 0x3B, 0x3B, 0x00)); //
|
|
199
|
-
|
|
200
|
-
// Extract mantissa (bits 0-2) as byte indices for shuffle
|
|
201
|
-
__m256i byte_idx_i8x32 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi8(0x07));
|
|
202
|
-
__m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
|
|
203
|
-
__m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
|
|
204
|
-
|
|
205
|
-
// Combine low and high bytes into 16-bit values
|
|
206
|
-
__m256i subnorm_abs_i16x16 = _mm256_or_si256( //
|
|
207
|
-
_mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
|
|
208
|
-
_mm256_slli_epi16(hi_bytes_i8x32, 8)); //
|
|
209
|
-
__m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
|
|
210
|
-
|
|
211
|
-
// Blend: if exponent == 0, use subnormal result; else use normal result
|
|
212
|
-
__m256i exp_bits_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x78));
|
|
213
|
-
__m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
|
|
214
|
-
__m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
|
|
215
|
-
|
|
216
|
-
// Handle NaN: E4M3 index 127 (0x7F) → BF16 NaN (0x7FC0)
|
|
217
|
-
__m256i is_nan_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7F));
|
|
218
|
-
__m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7FC0));
|
|
219
|
-
return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
|
|
220
|
-
}
|
|
221
|
-
|
|
222
|
-
/** @brief Convert 16x e5m2 → 16x bf16 via arithmetic + small LUT for subnormals (AVX2).
|
|
223
|
-
* E5M2 format: S EEEEE MM (bias=15). BF16: S EEEEEEEE MMMMMMM (bias=127).
|
|
224
|
-
* Normal values: BF16 = sign | ((lower7 << 5) + 0x3800).
|
|
225
|
-
* Subnormals (4 values): looked up via vpshufb from a 4-entry LUT.
|
|
226
|
-
* Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
|
|
227
|
-
NK_INTERNAL __m256i nk_e5m2x16_to_bf16x16_haswell_(__m128i e5m2x16) {
|
|
228
|
-
__m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2x16);
|
|
229
|
-
__m256i sign_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16((short)0x80));
|
|
230
|
-
__m256i lower7_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F));
|
|
231
|
-
|
|
232
|
-
// Normal path: BF16 = ((lower7 << 5) + 0x3800) | (sign << 8)
|
|
233
|
-
__m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 5), _mm256_set1_epi16(0x3800));
|
|
234
|
-
sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
|
|
235
|
-
__m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
|
|
236
|
-
|
|
237
|
-
// Subnormal LUT via shuffle_epi8 (4 entries: mantissa 0-3 → BF16)
|
|
238
|
-
// E5M2 subnormal BF16 values: 0x0000, 0x3780, 0x3800, 0x3840
|
|
239
|
-
__m256i const lo_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
|
|
240
|
-
0x00, 0x00, 0x00, 0x00, 0x40, 0x00, (char)0x80, 0x00, //
|
|
241
|
-
0x00, 0x00, 0x00, 0x00, 0x40, 0x00, (char)0x80, 0x00)); //
|
|
242
|
-
__m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
|
|
243
|
-
0x00, 0x00, 0x00, 0x00, 0x38, 0x38, 0x37, 0x00, //
|
|
244
|
-
0x00, 0x00, 0x00, 0x00, 0x38, 0x38, 0x37, 0x00)); //
|
|
245
|
-
|
|
246
|
-
// Extract mantissa (bits 0-1) as byte indices for shuffle
|
|
247
|
-
__m256i byte_idx_i8x32 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi8(0x03));
|
|
248
|
-
__m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
|
|
249
|
-
__m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
|
|
250
|
-
|
|
251
|
-
// Combine low and high bytes into 16-bit values
|
|
252
|
-
__m256i subnorm_abs_i16x16 = _mm256_or_si256( //
|
|
253
|
-
_mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
|
|
254
|
-
_mm256_slli_epi16(hi_bytes_i8x32, 8)); //
|
|
255
|
-
__m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
|
|
256
|
-
|
|
257
|
-
// Blend: if exponent == 0, use subnormal result; else use normal result
|
|
258
|
-
__m256i exp_bits_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7C));
|
|
259
|
-
__m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
|
|
260
|
-
__m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
|
|
261
|
-
|
|
262
|
-
// Handle Inf (0x7C) and NaN (0x7D-0x7F)
|
|
263
|
-
__m256i is_inf_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7C));
|
|
264
|
-
__m256i is_nan_i16x16 = _mm256_cmpgt_epi16(lower7_i16x16, _mm256_set1_epi16(0x7C));
|
|
265
|
-
__m256i inf_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7F80));
|
|
266
|
-
__m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7FC0));
|
|
267
|
-
result_i16x16 = _mm256_blendv_epi8(result_i16x16, inf_i16x16, is_inf_i16x16);
|
|
268
|
-
return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
|
|
269
|
-
}
|
|
270
|
-
|
|
271
|
-
/** @brief Convert 16x e4m3 → 16x f16 via arithmetic + small LUT for subnormals (AVX2).
|
|
272
|
-
* E4M3 format: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
273
|
-
* Normal values: F16 = sign | ((lower7 << 7) + 0x2000).
|
|
274
|
-
* Subnormals (8 values): looked up via vpshufb from an 8-entry LUT.
|
|
275
|
-
* Handles all corner cases: zero, subnormals, normals, and NaN. */
|
|
276
|
-
NK_INTERNAL __m256i nk_e4m3x16_to_f16x16_haswell_(__m128i e4m3x16) {
|
|
277
|
-
__m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3x16);
|
|
278
|
-
__m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
|
|
279
|
-
__m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
|
|
280
|
-
|
|
281
|
-
// Normal path: F16 = ((lower7 << 7) + 0x2000) | (sign << 8)
|
|
282
|
-
__m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 7), _mm256_set1_epi16(0x2000));
|
|
283
|
-
sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
|
|
284
|
-
__m256i normal_i16x16 = _mm256_or_si256(sign_i16x16, normal_abs_i16x16);
|
|
285
|
-
|
|
286
|
-
// Subnormal LUT via shuffle_epi8 (8 entries: mantissa 0-7 → F16)
|
|
287
|
-
// E4M3 subnormal F16 values: 0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300
|
|
288
|
-
// All low bytes are 0x00, high bytes: 0x00, 0x18, 0x1C, 0x1E, 0x20, 0x21, 0x22, 0x23
|
|
289
|
-
// _mm_set_epi8 order: b15..u1 (unused), b7=idx7, b6=idx6, ..., b0=idx0
|
|
290
|
-
__m256i const lo_lut_i8x32 = _mm256_setzero_si256();
|
|
291
|
-
__m256i const hi_lut_i8x32 = _mm256_broadcastsi128_si256(_mm_set_epi8( //
|
|
292
|
-
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, //
|
|
293
|
-
0x23, 0x22, 0x21, 0x20, 0x1E, 0x1C, 0x18, 0x00)); //
|
|
294
|
-
|
|
295
|
-
// Extract mantissa (bits 0-2) as byte indices for shuffle
|
|
296
|
-
__m256i byte_idx_i8x32 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi8(0x07));
|
|
297
|
-
__m256i lo_bytes_i8x32 = _mm256_shuffle_epi8(lo_lut_i8x32, byte_idx_i8x32);
|
|
298
|
-
__m256i hi_bytes_i8x32 = _mm256_shuffle_epi8(hi_lut_i8x32, byte_idx_i8x32);
|
|
299
|
-
|
|
300
|
-
// Combine low and high bytes into 16-bit values
|
|
301
|
-
__m256i subnorm_abs_i16x16 = _mm256_or_si256( //
|
|
302
|
-
_mm256_and_si256(lo_bytes_i8x32, _mm256_set1_epi16(0x00FF)), //
|
|
303
|
-
_mm256_slli_epi16(hi_bytes_i8x32, 8)); //
|
|
304
|
-
__m256i subnorm_i16x16 = _mm256_or_si256(subnorm_abs_i16x16, sign_i16x16);
|
|
305
|
-
|
|
306
|
-
// Blend: if exponent == 0, use subnormal result; else use normal result
|
|
307
|
-
__m256i exp_bits_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x78));
|
|
308
|
-
__m256i is_subnormal_i16x16 = _mm256_cmpeq_epi16(exp_bits_i16x16, _mm256_setzero_si256());
|
|
309
|
-
__m256i result_i16x16 = _mm256_blendv_epi8(normal_i16x16, subnorm_i16x16, is_subnormal_i16x16);
|
|
310
|
-
|
|
311
|
-
// Handle NaN: E4M3 index 127 (0x7F) → F16 NaN (0x7E00)
|
|
312
|
-
__m256i is_nan_i16x16 = _mm256_cmpeq_epi16(lower7_i16x16, _mm256_set1_epi16(0x7F));
|
|
313
|
-
__m256i nan_i16x16 = _mm256_or_si256(sign_i16x16, _mm256_set1_epi16(0x7E00));
|
|
314
|
-
return _mm256_blendv_epi8(result_i16x16, nan_i16x16, is_nan_i16x16);
|
|
315
|
-
}
|
|
316
|
-
|
|
317
|
-
/** @brief Convert 16x e5m2 → 16x f16 via simple bit shift (AVX2).
|
|
318
|
-
* E5M2 format: S EEEEE MM (bias=15). F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
319
|
-
* Same exponent bias means F16 = (lower7 << 8) | (sign << 15).
|
|
320
|
-
* Handles all corner cases: zero, subnormals, normals, infinity, and NaN. */
|
|
321
|
-
NK_INTERNAL __m256i nk_e5m2x16_to_f16x16_haswell_(__m128i e5m2x16) {
|
|
322
|
-
__m256i e5m2_i16x16 = _mm256_cvtepu8_epi16(e5m2x16);
|
|
323
|
-
__m256i sign_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16((short)0x80));
|
|
324
|
-
__m256i lower7_i16x16 = _mm256_and_si256(e5m2_i16x16, _mm256_set1_epi16(0x7F));
|
|
325
|
-
|
|
326
|
-
// F16 = (lower7 << 8) | (sign << 15)
|
|
327
|
-
// Works for all cases: subnormals, normals, infinity, and NaN
|
|
328
|
-
__m256i result_i16x16 = _mm256_slli_epi16(lower7_i16x16, 8);
|
|
329
|
-
sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
|
|
330
|
-
return _mm256_or_si256(result_i16x16, sign_i16x16);
|
|
331
|
-
}
|
|
332
|
-
|
|
333
175
|
/** @brief Convert 8x e4m3 → 8x f32 via bit manipulation (AVX2).
|
|
334
176
|
* E4M3 format: S EEEE MMM (bias=7). F32: sign<<31, (exp+120)<<23, mant<<20.
|
|
335
|
-
* Subnormals (exp=0):
|
|
177
|
+
* Subnormals (exp=0): looked up via vpermps from an 8-entry register LUT.
|
|
178
|
+
* NaN detection uses a single comparison on the 7-bit magnitude (0x7F). */
|
|
336
179
|
NK_INTERNAL __m256 nk_e4m3x8_to_f32x8_haswell_(__m128i e4m3_i8x8) {
|
|
337
180
|
__m256i e4m3_i32x8 = _mm256_cvtepu8_epi32(e4m3_i8x8);
|
|
338
181
|
|
|
@@ -348,21 +191,26 @@ NK_INTERNAL __m256 nk_e4m3x8_to_f32x8_haswell_(__m128i e4m3_i8x8) {
|
|
|
348
191
|
__m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 20);
|
|
349
192
|
__m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
|
|
350
193
|
|
|
351
|
-
// Subnormal path:
|
|
352
|
-
__m256
|
|
353
|
-
|
|
194
|
+
// Subnormal path: vpermps from 8-entry register LUT (3 cy latency, no memory access)
|
|
195
|
+
__m256 subnorm_lut_f32x8 = _mm256_setr_ps(0, 1.0f / 512, 2.0f / 512, 3.0f / 512, //
|
|
196
|
+
4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512);
|
|
197
|
+
__m256i subnorm_bits_i32x8 = _mm256_or_si256( //
|
|
198
|
+
_mm256_castps_si256(_mm256_permutevar8x32_ps(subnorm_lut_f32x8, mant_i32x8)), f32_sign_i32x8);
|
|
354
199
|
|
|
355
|
-
//
|
|
200
|
+
// Bitwise select: if exp==0, use subnormal; otherwise use normal
|
|
356
201
|
__m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
|
|
357
|
-
|
|
358
|
-
|
|
202
|
+
__m256i result_i32x8 = _mm256_or_si256( //
|
|
203
|
+
_mm256_and_si256(exp_zero_mask, subnorm_bits_i32x8), //
|
|
204
|
+
_mm256_andnot_si256(exp_zero_mask, normal_bits_i32x8));
|
|
359
205
|
|
|
360
|
-
// NaN
|
|
361
|
-
__m256i
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
206
|
+
// NaN: E4M3FN has NaN only at magnitude 0x7F (exp=15, mant=7)
|
|
207
|
+
__m256i lower7_i32x8 = _mm256_and_si256(e4m3_i32x8, _mm256_set1_epi32(0x7F));
|
|
208
|
+
__m256i is_nan_mask = _mm256_cmpeq_epi32(lower7_i32x8, _mm256_set1_epi32(0x7F));
|
|
209
|
+
__m256i nan_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_set1_epi32(0x7FC00000));
|
|
210
|
+
result_i32x8 = _mm256_or_si256( //
|
|
211
|
+
_mm256_and_si256(is_nan_mask, nan_i32x8), //
|
|
212
|
+
_mm256_andnot_si256(is_nan_mask, result_i32x8));
|
|
213
|
+
return _mm256_castsi256_ps(result_i32x8);
|
|
366
214
|
}
|
|
367
215
|
|
|
368
216
|
/** @brief Convert 8x e5m2 → 8x f32 via bit manipulation (AVX2).
|
|
@@ -676,9 +524,9 @@ NK_INTERNAL __m128i nk_f32x8_to_e3m2x8_haswell_(__m256 f32x8) {
|
|
|
676
524
|
return packed_i8x8;
|
|
677
525
|
}
|
|
678
526
|
|
|
679
|
-
#pragma endregion
|
|
527
|
+
#pragma endregion Vectorized Conversions
|
|
680
528
|
|
|
681
|
-
#pragma region
|
|
529
|
+
#pragma region Converting Loads and Stores
|
|
682
530
|
|
|
683
531
|
/** @brief Full load for f16 elements (8) with conversion to f32 via F16C. */
|
|
684
532
|
NK_INTERNAL void nk_load_f16x8_to_f32x8_haswell_(void const *src, nk_b256_vec_t *dst) {
|
|
@@ -794,9 +642,9 @@ NK_INTERNAL void nk_partial_load_u32x8_to_f32x8_haswell_(nk_u32_t const *src, nk
|
|
|
794
642
|
dst->ymm_ps = nk_u32x8_to_f32x8_haswell_(vec.ymm);
|
|
795
643
|
}
|
|
796
644
|
|
|
797
|
-
#pragma endregion
|
|
645
|
+
#pragma endregion Converting Loads and Stores
|
|
798
646
|
|
|
799
|
-
#pragma region
|
|
647
|
+
#pragma region Public API
|
|
800
648
|
|
|
801
649
|
NK_PUBLIC void nk_cast_haswell(void const *from, nk_dtype_t from_type, nk_size_t n, void *to, nk_dtype_t to_type) {
|
|
802
650
|
// Same-type fast path
|
|
@@ -958,7 +806,7 @@ NK_PUBLIC void nk_cast_haswell(void const *from, nk_dtype_t from_type, nk_size_t
|
|
|
958
806
|
}
|
|
959
807
|
}
|
|
960
808
|
|
|
961
|
-
#pragma endregion
|
|
809
|
+
#pragma endregion Public API
|
|
962
810
|
|
|
963
811
|
#if defined(__clang__)
|
|
964
812
|
#pragma clang attribute pop
|