numkong 7.0.0 → 7.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +239 -122
- package/binding.gyp +25 -491
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -94,11 +94,11 @@ extern "C" {
|
|
|
94
94
|
#define nk_define_dot_(input_type, accumulator_type, output_type, load_and_convert) \
|
|
95
95
|
NK_PUBLIC void nk_dot_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
|
|
96
96
|
nk_size_t n, nk_##output_type##_t *result) { \
|
|
97
|
-
nk_##accumulator_type##_t sum = 0,
|
|
97
|
+
nk_##accumulator_type##_t sum = 0, a_value, b_value; \
|
|
98
98
|
for (nk_size_t i = 0; i != n; ++i) { \
|
|
99
|
-
load_and_convert(a + i, &
|
|
100
|
-
load_and_convert(b + i, &
|
|
101
|
-
sum +=
|
|
99
|
+
load_and_convert(a + i, &a_value); \
|
|
100
|
+
load_and_convert(b + i, &b_value); \
|
|
101
|
+
sum += a_value * b_value; \
|
|
102
102
|
} \
|
|
103
103
|
*result = (nk_##output_type##_t)sum; \
|
|
104
104
|
}
|
|
@@ -139,15 +139,15 @@ extern "C" {
|
|
|
139
139
|
result->imag = sum_imag; \
|
|
140
140
|
}
|
|
141
141
|
|
|
142
|
-
#pragma region
|
|
142
|
+
#pragma region F32 and F64 Floats
|
|
143
143
|
|
|
144
144
|
nk_define_dot_(f32, f64, f64, nk_assign_from_to_) // nk_dot_f32_serial
|
|
145
145
|
nk_define_dot_complex_(f32c, f64, f64c, nk_assign_from_to_) // nk_dot_f32c_serial
|
|
146
146
|
nk_define_vdot_complex_(f32c, f64, f64c, nk_assign_from_to_) // nk_vdot_f32c_serial
|
|
147
147
|
|
|
148
|
-
#pragma endregion
|
|
148
|
+
#pragma endregion F32 and F64 Floats
|
|
149
149
|
|
|
150
|
-
#pragma region
|
|
150
|
+
#pragma region F16 and BF16 Floats
|
|
151
151
|
|
|
152
152
|
nk_define_dot_(f16, f32, f32, nk_f16_to_f32_serial) // nk_dot_f16_serial
|
|
153
153
|
nk_define_dot_complex_(f16c, f32, f32c, nk_f16_to_f32_serial) // nk_dot_f16c_serial
|
|
@@ -162,9 +162,9 @@ nk_define_dot_(e5m2, f32, f32, nk_e5m2_to_f32_serial) // nk_dot_e5m2_serial
|
|
|
162
162
|
nk_define_dot_(e2m3, f32, f32, nk_e2m3_to_f32_serial) // nk_dot_e2m3_serial
|
|
163
163
|
nk_define_dot_(e3m2, f32, f32, nk_e3m2_to_f32_serial) // nk_dot_e3m2_serial
|
|
164
164
|
|
|
165
|
-
#pragma endregion
|
|
165
|
+
#pragma endregion F16 and BF16 Floats
|
|
166
166
|
|
|
167
|
-
#pragma region
|
|
167
|
+
#pragma region I8 and U8 Integers
|
|
168
168
|
|
|
169
169
|
nk_define_dot_(i8, i32, i32, nk_assign_from_to_) // nk_dot_i8_serial
|
|
170
170
|
nk_define_dot_(u8, u32, u32, nk_assign_from_to_) // nk_dot_u8_serial
|
|
@@ -207,9 +207,9 @@ NK_PUBLIC void nk_dot_u4_serial(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_
|
|
|
207
207
|
*result = sum;
|
|
208
208
|
}
|
|
209
209
|
|
|
210
|
-
#pragma endregion
|
|
210
|
+
#pragma endregion I8 and U8 Integers
|
|
211
211
|
|
|
212
|
-
#pragma region
|
|
212
|
+
#pragma region F32 and F64 Floats
|
|
213
213
|
|
|
214
214
|
/* Double-precision dot-produce variants
|
|
215
215
|
*
|
|
@@ -325,9 +325,9 @@ NK_INTERNAL void nk_dot_f32x4_finalize_serial(
|
|
|
325
325
|
result->f64s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
326
326
|
}
|
|
327
327
|
|
|
328
|
-
#pragma endregion
|
|
328
|
+
#pragma endregion F32 and F64 Floats
|
|
329
329
|
|
|
330
|
-
#pragma region
|
|
330
|
+
#pragma region F16 and BF16 Floats
|
|
331
331
|
|
|
332
332
|
typedef struct nk_dot_f16x8_state_serial_t {
|
|
333
333
|
nk_f32_t sums[4];
|
|
@@ -364,6 +364,36 @@ NK_INTERNAL void nk_dot_f16x8_finalize_serial(
|
|
|
364
364
|
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
365
365
|
}
|
|
366
366
|
|
|
367
|
+
typedef struct nk_dot_through_f32x4_state_serial_t {
|
|
368
|
+
nk_f32_t sums[4];
|
|
369
|
+
} nk_dot_through_f32x4_state_serial_t;
|
|
370
|
+
|
|
371
|
+
NK_INTERNAL void nk_dot_through_f32x4_init_serial(nk_dot_through_f32x4_state_serial_t *state) {
|
|
372
|
+
state->sums[0] = 0, state->sums[1] = 0, state->sums[2] = 0, state->sums[3] = 0;
|
|
373
|
+
}
|
|
374
|
+
|
|
375
|
+
NK_INTERNAL void nk_dot_through_f32x4_update_serial(nk_dot_through_f32x4_state_serial_t *state, nk_b128_vec_t a,
|
|
376
|
+
nk_b128_vec_t b, nk_size_t depth_offset,
|
|
377
|
+
nk_size_t active_dimensions) {
|
|
378
|
+
nk_unused_(depth_offset);
|
|
379
|
+
nk_unused_(active_dimensions);
|
|
380
|
+
state->sums[0] += a.f32s[0] * b.f32s[0];
|
|
381
|
+
state->sums[1] += a.f32s[1] * b.f32s[1];
|
|
382
|
+
state->sums[2] += a.f32s[2] * b.f32s[2];
|
|
383
|
+
state->sums[3] += a.f32s[3] * b.f32s[3];
|
|
384
|
+
}
|
|
385
|
+
|
|
386
|
+
NK_INTERNAL void nk_dot_through_f32x4_finalize_serial( //
|
|
387
|
+
nk_dot_through_f32x4_state_serial_t const *state_a, nk_dot_through_f32x4_state_serial_t const *state_b, //
|
|
388
|
+
nk_dot_through_f32x4_state_serial_t const *state_c, nk_dot_through_f32x4_state_serial_t const *state_d, //
|
|
389
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
390
|
+
nk_unused_(total_dimensions);
|
|
391
|
+
result->f32s[0] = state_a->sums[0] + state_a->sums[1] + state_a->sums[2] + state_a->sums[3];
|
|
392
|
+
result->f32s[1] = state_b->sums[0] + state_b->sums[1] + state_b->sums[2] + state_b->sums[3];
|
|
393
|
+
result->f32s[2] = state_c->sums[0] + state_c->sums[1] + state_c->sums[2] + state_c->sums[3];
|
|
394
|
+
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
395
|
+
}
|
|
396
|
+
|
|
367
397
|
typedef struct nk_dot_bf16x8_state_serial_t {
|
|
368
398
|
nk_f32_t sums[4];
|
|
369
399
|
} nk_dot_bf16x8_state_serial_t;
|
|
@@ -399,9 +429,9 @@ NK_INTERNAL void nk_dot_bf16x8_finalize_serial(
|
|
|
399
429
|
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
400
430
|
}
|
|
401
431
|
|
|
402
|
-
#pragma endregion
|
|
432
|
+
#pragma endregion F16 and BF16 Floats
|
|
403
433
|
|
|
404
|
-
#pragma region
|
|
434
|
+
#pragma region I8 and U8 Integers
|
|
405
435
|
|
|
406
436
|
typedef struct nk_dot_i8x16_state_serial_t {
|
|
407
437
|
nk_i64_t sums[2];
|
|
@@ -476,9 +506,9 @@ NK_INTERNAL void nk_dot_u8x16_finalize_serial(
|
|
|
476
506
|
result->u32s[3] = (nk_u32_t)(state_d->sums[0] + state_d->sums[1]);
|
|
477
507
|
}
|
|
478
508
|
|
|
479
|
-
#pragma endregion
|
|
509
|
+
#pragma endregion I8 and U8 Integers
|
|
480
510
|
|
|
481
|
-
#pragma region
|
|
511
|
+
#pragma region F16 and BF16 Floats
|
|
482
512
|
|
|
483
513
|
typedef struct nk_dot_e4m3x16_state_serial_t {
|
|
484
514
|
nk_f32_t sums[4];
|
|
@@ -640,9 +670,9 @@ NK_INTERNAL void nk_dot_e3m2x16_finalize_serial(
|
|
|
640
670
|
result->f32s[3] = state_d->sums[0] + state_d->sums[1] + state_d->sums[2] + state_d->sums[3];
|
|
641
671
|
}
|
|
642
672
|
|
|
643
|
-
#pragma endregion
|
|
673
|
+
#pragma endregion F16 and BF16 Floats
|
|
644
674
|
|
|
645
|
-
#pragma region
|
|
675
|
+
#pragma region I8 and U8 Integers
|
|
646
676
|
|
|
647
677
|
// U4x2 state: processes 16 nibbles (8 bytes = 64 bits) per update
|
|
648
678
|
typedef struct nk_dot_u4x16_state_serial_t {
|
|
@@ -694,20 +724,26 @@ NK_INTERNAL void nk_dot_u4x16_finalize_serial(nk_dot_u4x16_state_serial_t const
|
|
|
694
724
|
}
|
|
695
725
|
|
|
696
726
|
NK_INTERNAL void nk_load_i4x16_to_i8x16_serial_(void const *src, nk_b128_vec_t *dst) {
|
|
697
|
-
|
|
727
|
+
nk_i4x2_t const *pairs = (nk_i4x2_t const *)src;
|
|
728
|
+
for (nk_size_t i = 0; i < 8; ++i) nk_i4x2_to_i8x2_serial(&pairs[i], &dst->i8s[i * 2]);
|
|
698
729
|
}
|
|
699
730
|
|
|
700
731
|
NK_INTERNAL void nk_partial_load_i4x16_to_i8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
701
|
-
|
|
732
|
+
nk_i4x2_t const *pairs = (nk_i4x2_t const *)src;
|
|
733
|
+
nk_size_t count_pairs = n / 2;
|
|
734
|
+
for (nk_size_t i = 0; i < count_pairs; ++i) nk_i4x2_to_i8x2_serial(&pairs[i], &dst->i8s[i * 2]);
|
|
702
735
|
for (nk_size_t i = n; i < 16; ++i) dst->i8s[i] = 0;
|
|
703
736
|
}
|
|
704
737
|
|
|
705
738
|
NK_INTERNAL void nk_load_u4x16_to_u8x16_serial_(void const *src, nk_b128_vec_t *dst) {
|
|
706
|
-
|
|
739
|
+
nk_u4x2_t const *pairs = (nk_u4x2_t const *)src;
|
|
740
|
+
for (nk_size_t i = 0; i < 8; ++i) nk_u4x2_to_u8x2_serial(&pairs[i], &dst->u8s[i * 2]);
|
|
707
741
|
}
|
|
708
742
|
|
|
709
743
|
NK_INTERNAL void nk_partial_load_u4x16_to_u8x16_serial_(void const *src, nk_b128_vec_t *dst, nk_size_t n) {
|
|
710
|
-
|
|
744
|
+
nk_u4x2_t const *pairs = (nk_u4x2_t const *)src;
|
|
745
|
+
nk_size_t count_pairs = n / 2;
|
|
746
|
+
for (nk_size_t i = 0; i < count_pairs; ++i) nk_u4x2_to_u8x2_serial(&pairs[i], &dst->u8s[i * 2]);
|
|
711
747
|
for (nk_size_t i = n; i < 16; ++i) dst->u8s[i] = 0;
|
|
712
748
|
}
|
|
713
749
|
|
|
@@ -759,9 +795,9 @@ NK_INTERNAL void nk_dot_i4x16_finalize_serial(nk_dot_i4x16_state_serial_t const
|
|
|
759
795
|
result->i32s[3] = (nk_i32_t)(state_d->sums[0] + state_d->sums[1]);
|
|
760
796
|
}
|
|
761
797
|
|
|
762
|
-
#pragma endregion
|
|
798
|
+
#pragma endregion I8 and U8 Integers
|
|
763
799
|
|
|
764
|
-
#pragma region
|
|
800
|
+
#pragma region Binary
|
|
765
801
|
|
|
766
802
|
NK_PUBLIC void nk_dot_u1_serial(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
|
|
767
803
|
nk_u32_t dot = 0;
|
|
@@ -798,7 +834,7 @@ NK_INTERNAL void nk_dot_u1x128_finalize_serial(nk_dot_u1x128_state_serial_t cons
|
|
|
798
834
|
result->u32s[3] = state_d->dot_count;
|
|
799
835
|
}
|
|
800
836
|
|
|
801
|
-
#pragma endregion
|
|
837
|
+
#pragma endregion Binary
|
|
802
838
|
|
|
803
839
|
/**
|
|
804
840
|
* Serial fallback sum helpers for progressive element-sum accumulation.
|
|
@@ -806,7 +842,7 @@ NK_INTERNAL void nk_dot_u1x128_finalize_serial(nk_dot_u1x128_state_serial_t cons
|
|
|
806
842
|
* on the depth loop's already-loaded vectors, avoiding a separate sum pass.
|
|
807
843
|
*/
|
|
808
844
|
|
|
809
|
-
#pragma region
|
|
845
|
+
#pragma region Stateful Element Sum Helpers (for compensated GEMM)
|
|
810
846
|
|
|
811
847
|
/* i4x32: Haswell i4 (nk_b128_vec_t containing 32 nibbles in 16 bytes) */
|
|
812
848
|
typedef struct nk_sum_i4x32_state_serial_t {
|
|
@@ -818,8 +854,8 @@ NK_INTERNAL void nk_sum_i4x32_init_serial(nk_sum_i4x32_state_serial_t *state) {
|
|
|
818
854
|
NK_INTERNAL void nk_sum_i4x32_update_serial(nk_sum_i4x32_state_serial_t *state, nk_b128_vec_t v) {
|
|
819
855
|
nk_u8_t const *d = (nk_u8_t const *)&v;
|
|
820
856
|
for (int i = 0; i < 16; i++) {
|
|
821
|
-
nk_i8_t low = (nk_i8_t)((d[i] & 0x0F) ^ 0x08) - 8;
|
|
822
|
-
nk_i8_t high = (nk_i8_t)((d[i] >> 4) ^ 0x08) - 8;
|
|
857
|
+
nk_i8_t low = (nk_i8_t)((d[i] & 0x0F) ^ 0x08) - 8; // sign-extend low nibble
|
|
858
|
+
nk_i8_t high = (nk_i8_t)((d[i] >> 4) ^ 0x08) - 8; // sign-extend high nibble
|
|
823
859
|
state->sum += low + high;
|
|
824
860
|
}
|
|
825
861
|
}
|
|
@@ -829,7 +865,7 @@ NK_INTERNAL nk_i32_t nk_sum_i4x32_finalize_serial(nk_sum_i4x32_state_serial_t co
|
|
|
829
865
|
return (nk_i32_t)state->sum;
|
|
830
866
|
}
|
|
831
867
|
|
|
832
|
-
#pragma endregion
|
|
868
|
+
#pragma endregion Stateful Element Sum Helpers
|
|
833
869
|
|
|
834
870
|
#if defined(__cplusplus)
|
|
835
871
|
} // extern "C"
|
|
@@ -8,9 +8,9 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_sierra_instructions AVX-VNNI-INT8 Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm256_dpbssd_epi32
|
|
13
|
-
* _mm256_dpbuud_epi32
|
|
11
|
+
* Intrinsic Instruction
|
|
12
|
+
* _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) i8 × i8 → i32
|
|
13
|
+
* _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) u8 × u8 → u32
|
|
14
14
|
*
|
|
15
15
|
* Sierra Forest CPUs support AVX-VNNI-INT8, adding native signed*signed and
|
|
16
16
|
* unsigned*unsigned 8-bit dot products. This eliminates the algebraic sign
|
|
@@ -248,10 +248,10 @@ NK_PUBLIC void nk_dot_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b
|
|
|
248
248
|
// Uses dpbssd instead of dpbusd — both operands are already signed i8 after
|
|
249
249
|
// LUT + sign application, so no unsigned conversion is needed.
|
|
250
250
|
//
|
|
251
|
-
__m256i const
|
|
252
|
-
|
|
253
|
-
__m256i const
|
|
254
|
-
|
|
251
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
|
|
252
|
+
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
253
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
254
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
255
255
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
256
256
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
257
257
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -277,11 +277,11 @@ nk_dot_e2m3_sierra_cycle:
|
|
|
277
277
|
// Decode a: extract magnitude, dual-VPSHUFB LUT, apply sign
|
|
278
278
|
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
279
279
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
280
|
-
__m256i
|
|
281
|
-
|
|
282
|
-
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
283
|
-
_mm256_shuffle_epi8(
|
|
284
|
-
|
|
280
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
281
|
+
half_select_u8x32);
|
|
282
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
|
|
283
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
|
|
284
|
+
a_high_select_u8x32);
|
|
285
285
|
__m256i a_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
286
286
|
__m256i a_signed_i8x32 = _mm256_blendv_epi8(
|
|
287
287
|
a_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate_mask_u8x32);
|
|
@@ -289,11 +289,11 @@ nk_dot_e2m3_sierra_cycle:
|
|
|
289
289
|
// Decode b: same LUT decode + sign
|
|
290
290
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
291
291
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
292
|
-
__m256i
|
|
293
|
-
|
|
294
|
-
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
295
|
-
_mm256_shuffle_epi8(
|
|
296
|
-
|
|
292
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
293
|
+
half_select_u8x32);
|
|
294
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
|
|
295
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
|
|
296
|
+
b_high_select_u8x32);
|
|
297
297
|
__m256i b_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
298
298
|
__m256i b_signed_i8x32 = _mm256_blendv_epi8(
|
|
299
299
|
b_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate_mask_u8x32);
|
|
@@ -318,10 +318,10 @@ NK_INTERNAL void nk_dot_e2m3x32_update_sierra(nk_dot_e2m3x32_state_sierra_t *sta
|
|
|
318
318
|
nk_unused_(depth_offset);
|
|
319
319
|
nk_unused_(active_dimensions);
|
|
320
320
|
// Same LUT constants...
|
|
321
|
-
__m256i const
|
|
322
|
-
|
|
323
|
-
__m256i const
|
|
324
|
-
|
|
321
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
|
|
322
|
+
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
323
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
324
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
325
325
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
326
326
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
327
327
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -333,11 +333,11 @@ NK_INTERNAL void nk_dot_e2m3x32_update_sierra(nk_dot_e2m3x32_state_sierra_t *sta
|
|
|
333
333
|
// Decode a
|
|
334
334
|
__m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
|
|
335
335
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
336
|
-
__m256i
|
|
337
|
-
|
|
338
|
-
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
339
|
-
_mm256_shuffle_epi8(
|
|
340
|
-
|
|
336
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
337
|
+
half_select_u8x32);
|
|
338
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
|
|
339
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
|
|
340
|
+
a_high_select_u8x32);
|
|
341
341
|
__m256i a_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
342
342
|
__m256i a_signed_i8x32 = _mm256_blendv_epi8(
|
|
343
343
|
a_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate_mask_u8x32);
|
|
@@ -345,11 +345,11 @@ NK_INTERNAL void nk_dot_e2m3x32_update_sierra(nk_dot_e2m3x32_state_sierra_t *sta
|
|
|
345
345
|
// Decode b
|
|
346
346
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
347
347
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
348
|
-
__m256i
|
|
349
|
-
|
|
350
|
-
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
351
|
-
_mm256_shuffle_epi8(
|
|
352
|
-
|
|
348
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
349
|
+
half_select_u8x32);
|
|
350
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
|
|
351
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
|
|
352
|
+
b_high_select_u8x32);
|
|
353
353
|
__m256i b_negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
|
|
354
354
|
__m256i b_signed_i8x32 = _mm256_blendv_epi8(
|
|
355
355
|
b_unsigned_u8x32, _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate_mask_u8x32);
|