numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -8,12 +8,12 @@
|
|
|
8
8
|
*
|
|
9
9
|
* @section dot_haswell_instructions Key AVX2/FMA Dot Product Instructions
|
|
10
10
|
*
|
|
11
|
-
* Intrinsic
|
|
12
|
-
* _mm256_fmadd_ps/pd
|
|
13
|
-
* _mm256_mul_ps/pd
|
|
14
|
-
* _mm256_add_ps/pd
|
|
15
|
-
* _mm256_cvtph_ps
|
|
16
|
-
* _mm256_cvtps_pd
|
|
11
|
+
* Intrinsic Instruction Haswell Genoa
|
|
12
|
+
* _mm256_fmadd_ps/pd VFMADD (YMM, YMM, YMM) 5cy @ p01 4cy @ p01
|
|
13
|
+
* _mm256_mul_ps/pd VMULPS/PD (YMM, YMM, YMM) 5cy @ p01 3cy @ p01
|
|
14
|
+
* _mm256_add_ps/pd VADDPS/PD (YMM, YMM, YMM) 3cy @ p01 3cy @ p23
|
|
15
|
+
* _mm256_cvtph_ps VCVTPH2PS (YMM, XMM) 5cy @ p01 4cy @ p12+p23
|
|
16
|
+
* _mm256_cvtps_pd VCVTPS2PD (YMM, XMM) 2cy @ p01 4cy @ p12+p23
|
|
17
17
|
*
|
|
18
18
|
* For small numeric types (F16, BF16, E4M3, E5M2) we use F32 accumulators. For F32 dot products,
|
|
19
19
|
* upcasting to F64 and downcasting back is faster than stable summation algorithms. For F64 we
|
|
@@ -141,7 +141,7 @@ NK_INTERNAL nk_f64_t nk_dot_stable_sum_f64x4_haswell_(__m256d sum_f64x4, __m256d
|
|
|
141
141
|
return tentative_sum + (lower_error + upper_error + rounding_error);
|
|
142
142
|
}
|
|
143
143
|
|
|
144
|
-
#pragma region
|
|
144
|
+
#pragma region F32 and F64 Floats
|
|
145
145
|
|
|
146
146
|
NK_PUBLIC void nk_dot_f32_haswell(nk_f32_t const *a_scalars, nk_f32_t const *b_scalars, nk_size_t count_scalars,
|
|
147
147
|
nk_f64_t *result) {
|
|
@@ -479,30 +479,35 @@ NK_INTERNAL void nk_dot_f32x4_finalize_haswell(
|
|
|
479
479
|
result->ymm_pd = sum_abcd_f64x4;
|
|
480
480
|
}
|
|
481
481
|
|
|
482
|
-
#pragma endregion
|
|
482
|
+
#pragma endregion F32 and F64 Floats
|
|
483
483
|
|
|
484
|
-
#pragma region
|
|
484
|
+
#pragma region F16 and BF16 Floats
|
|
485
485
|
|
|
486
486
|
NK_PUBLIC void nk_dot_bf16_haswell(nk_bf16_t const *a_scalars, nk_bf16_t const *b_scalars, nk_size_t count_scalars,
|
|
487
487
|
nk_f32_t *result) {
|
|
488
|
-
|
|
488
|
+
__m256i a_bf16_i16x16, b_bf16_i16x16;
|
|
489
489
|
__m256 sum_f32x8 = _mm256_setzero_ps();
|
|
490
|
+
__m256i mask_high_u32x8 = _mm256_set1_epi32((int)0xFFFF0000);
|
|
490
491
|
nk_dot_bf16_haswell_cycle:
|
|
491
|
-
if (count_scalars <
|
|
492
|
+
if (count_scalars < 16) {
|
|
492
493
|
nk_b256_vec_t a_vec, b_vec;
|
|
493
494
|
nk_partial_load_b16x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
494
495
|
nk_partial_load_b16x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
495
|
-
|
|
496
|
-
|
|
496
|
+
a_bf16_i16x16 = a_vec.ymm;
|
|
497
|
+
b_bf16_i16x16 = b_vec.ymm;
|
|
497
498
|
count_scalars = 0;
|
|
498
499
|
}
|
|
499
500
|
else {
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
a_scalars +=
|
|
501
|
+
a_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)a_scalars);
|
|
502
|
+
b_bf16_i16x16 = _mm256_loadu_si256((__m256i const *)b_scalars);
|
|
503
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
503
504
|
}
|
|
504
|
-
|
|
505
|
-
|
|
505
|
+
__m256 a_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(a_bf16_i16x16, 16));
|
|
506
|
+
__m256 b_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(b_bf16_i16x16, 16));
|
|
507
|
+
sum_f32x8 = _mm256_fmadd_ps(a_even_f32x8, b_even_f32x8, sum_f32x8);
|
|
508
|
+
__m256 a_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(a_bf16_i16x16, mask_high_u32x8));
|
|
509
|
+
__m256 b_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(b_bf16_i16x16, mask_high_u32x8));
|
|
510
|
+
sum_f32x8 = _mm256_fmadd_ps(a_odd_f32x8, b_odd_f32x8, sum_f32x8);
|
|
506
511
|
if (count_scalars) goto nk_dot_bf16_haswell_cycle;
|
|
507
512
|
*result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_f32x8);
|
|
508
513
|
}
|
|
@@ -534,7 +539,7 @@ NK_PUBLIC void nk_dot_bf16c_haswell(nk_bf16c_t const *a_pairs, nk_bf16c_t const
|
|
|
534
539
|
nk_f32c_t *result) {
|
|
535
540
|
// Convert BF16 to F32, then use F32 complex dot product with sign-flipping optimization.
|
|
536
541
|
// Uses same XOR trick as f32c to double throughput by deferring sign flips until after loop.
|
|
537
|
-
__m128i
|
|
542
|
+
__m128i a_bf16_i16x8, b_bf16_i16x8;
|
|
538
543
|
__m256 sum_real_f32x8 = _mm256_setzero_ps();
|
|
539
544
|
__m256 sum_imag_f32x8 = _mm256_setzero_ps();
|
|
540
545
|
__m256i const sign_flip_i64x4 = _mm256_set1_epi64x(0x8000000000000000);
|
|
@@ -547,19 +552,19 @@ nk_dot_bf16c_haswell_cycle:
|
|
|
547
552
|
nk_b256_vec_t a_vec, b_vec;
|
|
548
553
|
nk_partial_load_b16x16_serial_(a_pairs, &a_vec, count_pairs * 2);
|
|
549
554
|
nk_partial_load_b16x16_serial_(b_pairs, &b_vec, count_pairs * 2);
|
|
550
|
-
|
|
551
|
-
|
|
555
|
+
a_bf16_i16x8 = a_vec.xmms[0];
|
|
556
|
+
b_bf16_i16x8 = b_vec.xmms[0];
|
|
552
557
|
count_pairs = 0;
|
|
553
558
|
}
|
|
554
559
|
else {
|
|
555
|
-
|
|
556
|
-
|
|
560
|
+
a_bf16_i16x8 = _mm_loadu_si128((__m128i const *)a_pairs);
|
|
561
|
+
b_bf16_i16x8 = _mm_loadu_si128((__m128i const *)b_pairs);
|
|
557
562
|
a_pairs += 4, b_pairs += 4, count_pairs -= 4;
|
|
558
563
|
}
|
|
559
564
|
|
|
560
565
|
// Convert BF16 to F32
|
|
561
|
-
__m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(
|
|
562
|
-
__m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(
|
|
566
|
+
__m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16_i16x8);
|
|
567
|
+
__m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16_i16x8);
|
|
563
568
|
|
|
564
569
|
// Complex multiply-accumulate: swap b for imaginary part
|
|
565
570
|
__m256 b_swapped_f32x8 = _mm256_castsi256_ps(
|
|
@@ -579,7 +584,7 @@ nk_dot_bf16c_haswell_cycle:
|
|
|
579
584
|
NK_PUBLIC void nk_vdot_bf16c_haswell(nk_bf16c_t const *a_pairs, nk_bf16c_t const *b_pairs, nk_size_t count_pairs,
|
|
580
585
|
nk_f32c_t *result) {
|
|
581
586
|
// Conjugate complex dot product: conj(a) * b
|
|
582
|
-
__m128i
|
|
587
|
+
__m128i a_bf16_i16x8, b_bf16_i16x8;
|
|
583
588
|
__m256 sum_real_f32x8 = _mm256_setzero_ps();
|
|
584
589
|
__m256 sum_imag_f32x8 = _mm256_setzero_ps();
|
|
585
590
|
__m256i const sign_flip_i64x4 = _mm256_set1_epi64x(0x8000000000000000);
|
|
@@ -592,19 +597,19 @@ nk_vdot_bf16c_haswell_cycle:
|
|
|
592
597
|
nk_b256_vec_t a_vec, b_vec;
|
|
593
598
|
nk_partial_load_b16x16_serial_(a_pairs, &a_vec, count_pairs * 2);
|
|
594
599
|
nk_partial_load_b16x16_serial_(b_pairs, &b_vec, count_pairs * 2);
|
|
595
|
-
|
|
596
|
-
|
|
600
|
+
a_bf16_i16x8 = a_vec.xmms[0];
|
|
601
|
+
b_bf16_i16x8 = b_vec.xmms[0];
|
|
597
602
|
count_pairs = 0;
|
|
598
603
|
}
|
|
599
604
|
else {
|
|
600
|
-
|
|
601
|
-
|
|
605
|
+
a_bf16_i16x8 = _mm_loadu_si128((__m128i const *)a_pairs);
|
|
606
|
+
b_bf16_i16x8 = _mm_loadu_si128((__m128i const *)b_pairs);
|
|
602
607
|
a_pairs += 4, b_pairs += 4, count_pairs -= 4;
|
|
603
608
|
}
|
|
604
609
|
|
|
605
610
|
// Convert BF16 to F32
|
|
606
|
-
__m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(
|
|
607
|
-
__m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(
|
|
611
|
+
__m256 a_f32x8 = nk_bf16x8_to_f32x8_haswell_(a_bf16_i16x8);
|
|
612
|
+
__m256 b_f32x8 = nk_bf16x8_to_f32x8_haswell_(b_bf16_i16x8);
|
|
608
613
|
|
|
609
614
|
// Conjugate complex multiply-accumulate
|
|
610
615
|
sum_real_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_real_f32x8);
|
|
@@ -724,10 +729,10 @@ NK_PUBLIC void nk_dot_e2m3_haswell(nk_e2m3_t const *a_scalars, nk_e2m3_t const *
|
|
|
724
729
|
// lut_lower[0..15]: {0,2,4,6,8,10,12,14, 16,18,20,22,24,26,28,30}
|
|
725
730
|
// lut_upper[0..15]: {32,36,40,44,48,52,56,60, 64,72,80,88,96,104,112,120}
|
|
726
731
|
//
|
|
727
|
-
__m256i const
|
|
728
|
-
|
|
729
|
-
__m256i const
|
|
730
|
-
|
|
732
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
|
|
733
|
+
24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
734
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
|
|
735
|
+
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
731
736
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
732
737
|
__m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
|
|
733
738
|
__m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
|
|
@@ -756,18 +761,18 @@ nk_dot_e2m3_haswell_cycle:
|
|
|
756
761
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
757
762
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
758
763
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
759
|
-
__m256i
|
|
760
|
-
|
|
761
|
-
__m256i
|
|
762
|
-
|
|
764
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
765
|
+
half_select_u8x32);
|
|
766
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
767
|
+
half_select_u8x32);
|
|
763
768
|
|
|
764
769
|
// Dual VPSHUFB: lookup in both halves, blend based on bit 4
|
|
765
|
-
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
766
|
-
_mm256_shuffle_epi8(
|
|
767
|
-
|
|
768
|
-
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
769
|
-
_mm256_shuffle_epi8(
|
|
770
|
-
|
|
770
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
|
|
771
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
|
|
772
|
+
a_high_select_u8x32);
|
|
773
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
|
|
774
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
|
|
775
|
+
b_high_select_u8x32);
|
|
771
776
|
|
|
772
777
|
// Combined sign: (a ^ b) & 0x20, negate b where signs differ
|
|
773
778
|
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
|
|
@@ -795,10 +800,10 @@ NK_PUBLIC void nk_dot_e3m2_haswell(nk_e3m2_t const *a_scalars, nk_e3m2_t const *
|
|
|
795
800
|
// lut_upper[0..15]: low bytes of {32,40,48,56,64,80,96,112,128,160,192,224,256,320,384,448}
|
|
796
801
|
// High byte is 1 iff magnitude index >= 28 (values 256-448), else 0.
|
|
797
802
|
//
|
|
798
|
-
__m256i const
|
|
803
|
+
__m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
|
|
799
804
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
|
|
800
805
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
801
|
-
__m256i const
|
|
806
|
+
__m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
|
|
802
807
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
803
808
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
804
809
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -831,42 +836,44 @@ nk_dot_e3m2_haswell_cycle:
|
|
|
831
836
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
|
|
832
837
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
833
838
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
834
|
-
__m256i
|
|
835
|
-
|
|
836
|
-
__m256i
|
|
837
|
-
|
|
839
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
840
|
+
half_select_u8x32);
|
|
841
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
842
|
+
half_select_u8x32);
|
|
838
843
|
|
|
839
844
|
// Dual VPSHUFB: lookup low bytes in both halves, blend based on bit 4
|
|
840
|
-
__m256i
|
|
841
|
-
_mm256_shuffle_epi8(
|
|
842
|
-
|
|
843
|
-
__m256i
|
|
844
|
-
_mm256_shuffle_epi8(
|
|
845
|
-
|
|
845
|
+
__m256i a_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
|
|
846
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32),
|
|
847
|
+
a_high_select_u8x32);
|
|
848
|
+
__m256i b_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
|
|
849
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32),
|
|
850
|
+
b_high_select_u8x32);
|
|
846
851
|
|
|
847
852
|
// High byte: 1 iff magnitude >= 28 (signed compare safe: 27 < 128)
|
|
848
|
-
__m256i
|
|
849
|
-
|
|
853
|
+
__m256i a_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
|
|
854
|
+
ones_u8x32);
|
|
855
|
+
__m256i b_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
|
|
856
|
+
ones_u8x32);
|
|
850
857
|
|
|
851
858
|
// Interleave low and high bytes into i16 (little-endian: low byte first)
|
|
852
|
-
__m256i
|
|
853
|
-
__m256i
|
|
854
|
-
__m256i
|
|
855
|
-
__m256i
|
|
859
|
+
__m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
|
|
860
|
+
__m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
|
|
861
|
+
__m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
|
|
862
|
+
__m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
|
|
856
863
|
|
|
857
864
|
// Combined sign: (a ^ b) & 0x20, widen to i16 via unpack, create +1/-1 sign vector
|
|
858
865
|
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
|
|
859
866
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
|
|
860
|
-
__m256i
|
|
861
|
-
__m256i
|
|
862
|
-
__m256i
|
|
863
|
-
__m256i
|
|
864
|
-
__m256i
|
|
865
|
-
__m256i
|
|
867
|
+
__m256i negate_low_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
868
|
+
__m256i negate_high_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
869
|
+
__m256i sign_low_i16x16 = _mm256_or_si256(negate_low_i16x16, ones_i16x16);
|
|
870
|
+
__m256i sign_high_i16x16 = _mm256_or_si256(negate_high_i16x16, ones_i16x16);
|
|
871
|
+
__m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, sign_low_i16x16);
|
|
872
|
+
__m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, sign_high_i16x16);
|
|
866
873
|
|
|
867
874
|
// VPMADDWD: a_unsigned_i16 × b_signed_i16 → i32 accumulator
|
|
868
|
-
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(
|
|
869
|
-
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(
|
|
875
|
+
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_low_i16x16, b_signed_low_i16x16));
|
|
876
|
+
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_high_i16x16, b_signed_high_i16x16));
|
|
870
877
|
|
|
871
878
|
if (count_scalars) goto nk_dot_e3m2_haswell_cycle;
|
|
872
879
|
*result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
|
|
@@ -946,10 +953,34 @@ NK_INTERNAL void nk_dot_through_f32_finalize_haswell_(
|
|
|
946
953
|
typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_f16x8_state_haswell_t;
|
|
947
954
|
|
|
948
955
|
/**
|
|
949
|
-
* @brief Running state for
|
|
950
|
-
* @note
|
|
956
|
+
* @brief Running state for 256-bit dot accumulation over bf16 scalars on Haswell.
|
|
957
|
+
* @note Processes 16 bf16 per tile step via unpack(zero, bf16) → 2×8 f32 FMA.
|
|
951
958
|
*/
|
|
952
|
-
typedef struct nk_dot_through_f32_state_haswell_t_
|
|
959
|
+
typedef struct nk_dot_through_f32_state_haswell_t_ nk_dot_bf16x16_state_haswell_t;
|
|
960
|
+
|
|
961
|
+
NK_INTERNAL void nk_dot_bf16x16_init_haswell(nk_dot_bf16x16_state_haswell_t *state) {
|
|
962
|
+
nk_dot_through_f32_init_haswell_(state);
|
|
963
|
+
}
|
|
964
|
+
|
|
965
|
+
NK_INTERNAL void nk_dot_bf16x16_update_haswell(nk_dot_bf16x16_state_haswell_t *state, nk_b256_vec_t a, nk_b256_vec_t b,
|
|
966
|
+
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
967
|
+
nk_unused_(depth_offset);
|
|
968
|
+
nk_unused_(active_dimensions);
|
|
969
|
+
__m256i mask_high_u32x8 = _mm256_set1_epi32((int)0xFFFF0000);
|
|
970
|
+
__m256 a_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(a.ymm, 16));
|
|
971
|
+
__m256 b_even_f32x8 = _mm256_castsi256_ps(_mm256_slli_epi32(b.ymm, 16));
|
|
972
|
+
state->sum_f32x8 = _mm256_fmadd_ps(a_even_f32x8, b_even_f32x8, state->sum_f32x8);
|
|
973
|
+
__m256 a_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(a.ymm, mask_high_u32x8));
|
|
974
|
+
__m256 b_odd_f32x8 = _mm256_castsi256_ps(_mm256_and_si256(b.ymm, mask_high_u32x8));
|
|
975
|
+
state->sum_f32x8 = _mm256_fmadd_ps(a_odd_f32x8, b_odd_f32x8, state->sum_f32x8);
|
|
976
|
+
}
|
|
977
|
+
|
|
978
|
+
NK_INTERNAL void nk_dot_bf16x16_finalize_haswell( //
|
|
979
|
+
nk_dot_bf16x16_state_haswell_t const *state_a, nk_dot_bf16x16_state_haswell_t const *state_b, //
|
|
980
|
+
nk_dot_bf16x16_state_haswell_t const *state_c, nk_dot_bf16x16_state_haswell_t const *state_d, //
|
|
981
|
+
nk_size_t total_dimensions, nk_b128_vec_t *result) {
|
|
982
|
+
nk_dot_through_f32_finalize_haswell_(state_a, state_b, state_c, state_d, total_dimensions, result);
|
|
983
|
+
}
|
|
953
984
|
|
|
954
985
|
/**
|
|
955
986
|
* @brief Running state for 128-bit dot accumulation over e4m3 scalars on Haswell.
|
|
@@ -991,10 +1022,10 @@ NK_INTERNAL void nk_dot_e2m3x32_update_haswell(nk_dot_e2m3x32_state_haswell_t *s
|
|
|
991
1022
|
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
992
1023
|
nk_unused_(depth_offset);
|
|
993
1024
|
nk_unused_(active_dimensions);
|
|
994
|
-
__m256i const
|
|
1025
|
+
__m256i const lut_low_u8x32 = _mm256_set_epi8( //
|
|
995
1026
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
|
|
996
1027
|
30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
|
997
|
-
__m256i const
|
|
1028
|
+
__m256i const lut_high_u8x32 = _mm256_set_epi8( //
|
|
998
1029
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32, //
|
|
999
1030
|
120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
|
|
1000
1031
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -1011,18 +1042,18 @@ NK_INTERNAL void nk_dot_e2m3x32_update_haswell(nk_dot_e2m3x32_state_haswell_t *s
|
|
|
1011
1042
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
|
|
1012
1043
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
1013
1044
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
1014
|
-
__m256i
|
|
1015
|
-
|
|
1016
|
-
__m256i
|
|
1017
|
-
|
|
1045
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
1046
|
+
half_select_u8x32);
|
|
1047
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
1048
|
+
half_select_u8x32);
|
|
1018
1049
|
|
|
1019
1050
|
// Dual VPSHUFB + blend
|
|
1020
|
-
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
1021
|
-
_mm256_shuffle_epi8(
|
|
1022
|
-
|
|
1023
|
-
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(
|
|
1024
|
-
_mm256_shuffle_epi8(
|
|
1025
|
-
|
|
1051
|
+
__m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
|
|
1052
|
+
_mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
|
|
1053
|
+
a_high_select_u8x32);
|
|
1054
|
+
__m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
|
|
1055
|
+
_mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
|
|
1056
|
+
b_high_select_u8x32);
|
|
1026
1057
|
|
|
1027
1058
|
// Combined sign + conditional negate
|
|
1028
1059
|
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e2m3_u8x32, b_e2m3_u8x32), sign_mask_u8x32);
|
|
@@ -1086,9 +1117,9 @@ NK_INTERNAL void nk_dot_e3m2x32_update_haswell(nk_dot_e3m2x32_state_haswell_t *s
|
|
|
1086
1117
|
nk_size_t depth_offset, nk_size_t active_dimensions) {
|
|
1087
1118
|
nk_unused_(depth_offset);
|
|
1088
1119
|
nk_unused_(active_dimensions);
|
|
1089
|
-
__m256i const
|
|
1120
|
+
__m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
|
|
1090
1121
|
28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
|
|
1091
|
-
__m256i const
|
|
1122
|
+
__m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
|
|
1092
1123
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
|
|
1093
1124
|
(char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
|
|
1094
1125
|
__m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
|
|
@@ -1107,42 +1138,44 @@ NK_INTERNAL void nk_dot_e3m2x32_update_haswell(nk_dot_e3m2x32_state_haswell_t *s
|
|
|
1107
1138
|
__m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
|
|
1108
1139
|
__m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
|
|
1109
1140
|
__m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
|
|
1110
|
-
__m256i
|
|
1111
|
-
|
|
1112
|
-
__m256i
|
|
1113
|
-
|
|
1141
|
+
__m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
|
|
1142
|
+
half_select_u8x32);
|
|
1143
|
+
__m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
|
|
1144
|
+
half_select_u8x32);
|
|
1114
1145
|
|
|
1115
1146
|
// Dual VPSHUFB for low bytes
|
|
1116
|
-
__m256i
|
|
1117
|
-
_mm256_shuffle_epi8(
|
|
1118
|
-
|
|
1119
|
-
__m256i
|
|
1120
|
-
_mm256_shuffle_epi8(
|
|
1121
|
-
|
|
1147
|
+
__m256i a_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
|
|
1148
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32),
|
|
1149
|
+
a_high_select_u8x32);
|
|
1150
|
+
__m256i b_low_byte_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
|
|
1151
|
+
_mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32),
|
|
1152
|
+
b_high_select_u8x32);
|
|
1122
1153
|
|
|
1123
1154
|
// High byte: 1 iff magnitude >= 28
|
|
1124
|
-
__m256i
|
|
1125
|
-
|
|
1155
|
+
__m256i a_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
|
|
1156
|
+
ones_u8x32);
|
|
1157
|
+
__m256i b_high_byte_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
|
|
1158
|
+
ones_u8x32);
|
|
1126
1159
|
|
|
1127
1160
|
// Interleave low and high bytes into i16
|
|
1128
|
-
__m256i
|
|
1129
|
-
__m256i
|
|
1130
|
-
__m256i
|
|
1131
|
-
__m256i
|
|
1161
|
+
__m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
|
|
1162
|
+
__m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_byte_u8x32, a_high_byte_u8x32);
|
|
1163
|
+
__m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
|
|
1164
|
+
__m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_byte_u8x32, b_high_byte_u8x32);
|
|
1132
1165
|
|
|
1133
1166
|
// Combined sign: (a ^ b) & 0x20, widen to i16, create +1/-1 sign vector via VPSIGNW
|
|
1134
1167
|
__m256i sign_combined_u8x32 = _mm256_and_si256(_mm256_xor_si256(a_e3m2_u8x32, b_e3m2_u8x32), sign_mask_u8x32);
|
|
1135
1168
|
__m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(sign_combined_u8x32, sign_mask_u8x32);
|
|
1136
|
-
__m256i
|
|
1137
|
-
__m256i
|
|
1138
|
-
__m256i
|
|
1139
|
-
__m256i
|
|
1140
|
-
__m256i
|
|
1141
|
-
__m256i
|
|
1169
|
+
__m256i negate_low_i16x16 = _mm256_unpacklo_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
1170
|
+
__m256i negate_high_i16x16 = _mm256_unpackhi_epi8(negate_mask_u8x32, negate_mask_u8x32);
|
|
1171
|
+
__m256i sign_low_i16x16 = _mm256_or_si256(negate_low_i16x16, ones_i16x16);
|
|
1172
|
+
__m256i sign_high_i16x16 = _mm256_or_si256(negate_high_i16x16, ones_i16x16);
|
|
1173
|
+
__m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, sign_low_i16x16);
|
|
1174
|
+
__m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, sign_high_i16x16);
|
|
1142
1175
|
|
|
1143
1176
|
// VPMADDWD: a_unsigned_i16 × b_signed_i16 → i32 (two halves → two accumulators)
|
|
1144
|
-
state->sum_a_i32x8 = _mm256_add_epi32(state->sum_a_i32x8, _mm256_madd_epi16(
|
|
1145
|
-
state->sum_b_i32x8 = _mm256_add_epi32(state->sum_b_i32x8, _mm256_madd_epi16(
|
|
1177
|
+
state->sum_a_i32x8 = _mm256_add_epi32(state->sum_a_i32x8, _mm256_madd_epi16(a_low_i16x16, b_signed_low_i16x16));
|
|
1178
|
+
state->sum_b_i32x8 = _mm256_add_epi32(state->sum_b_i32x8, _mm256_madd_epi16(a_high_i16x16, b_signed_high_i16x16));
|
|
1146
1179
|
}
|
|
1147
1180
|
|
|
1148
1181
|
NK_INTERNAL void nk_dot_e3m2x32_finalize_haswell( //
|
|
@@ -1176,9 +1209,9 @@ NK_INTERNAL void nk_dot_e3m2x32_finalize_haswell(
|
|
|
1176
1209
|
results->xmm = _mm_castps_si128(sum_f32x4);
|
|
1177
1210
|
}
|
|
1178
1211
|
|
|
1179
|
-
#pragma endregion
|
|
1212
|
+
#pragma endregion F16 and BF16 Floats
|
|
1180
1213
|
|
|
1181
|
-
#pragma region
|
|
1214
|
+
#pragma region I8 and U8 Integers
|
|
1182
1215
|
|
|
1183
1216
|
NK_PUBLIC void nk_dot_i8_haswell(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
|
|
1184
1217
|
nk_i32_t *result) {
|
|
@@ -1275,33 +1308,33 @@ nk_dot_i4_haswell_cycle:
|
|
|
1275
1308
|
}
|
|
1276
1309
|
|
|
1277
1310
|
// Extract low and high nibbles
|
|
1278
|
-
__m128i
|
|
1279
|
-
__m128i
|
|
1280
|
-
__m128i
|
|
1281
|
-
__m128i
|
|
1311
|
+
__m128i a_low_u8x16 = _mm_and_si128(a_i4x32, nibble_mask_u8x16);
|
|
1312
|
+
__m128i a_high_u8x16 = _mm_and_si128(_mm_srli_epi16(a_i4x32, 4), nibble_mask_u8x16);
|
|
1313
|
+
__m128i b_low_u8x16 = _mm_and_si128(b_i4x32, nibble_mask_u8x16);
|
|
1314
|
+
__m128i b_high_u8x16 = _mm_and_si128(_mm_srli_epi16(b_i4x32, 4), nibble_mask_u8x16);
|
|
1282
1315
|
|
|
1283
1316
|
// XOR with 8 to get cx, dx values for the algebraic transformation
|
|
1284
|
-
__m128i
|
|
1285
|
-
__m128i
|
|
1286
|
-
__m128i
|
|
1287
|
-
__m128i
|
|
1317
|
+
__m128i c_low_u8x16 = _mm_xor_si128(a_low_u8x16, xor_mask_u8x16);
|
|
1318
|
+
__m128i c_high_u8x16 = _mm_xor_si128(a_high_u8x16, xor_mask_u8x16);
|
|
1319
|
+
__m128i d_low_u8x16 = _mm_xor_si128(b_low_u8x16, xor_mask_u8x16);
|
|
1320
|
+
__m128i d_high_u8x16 = _mm_xor_si128(b_high_u8x16, xor_mask_u8x16);
|
|
1288
1321
|
|
|
1289
1322
|
// Widen u8 to i16 and multiply using MADD (2× instead of 4×)
|
|
1290
|
-
__m256i
|
|
1291
|
-
__m256i
|
|
1292
|
-
__m256i
|
|
1293
|
-
__m256i
|
|
1323
|
+
__m256i c_low_i16x16 = _mm256_cvtepu8_epi16(c_low_u8x16);
|
|
1324
|
+
__m256i c_high_i16x16 = _mm256_cvtepu8_epi16(c_high_u8x16);
|
|
1325
|
+
__m256i d_low_i16x16 = _mm256_cvtepu8_epi16(d_low_u8x16);
|
|
1326
|
+
__m256i d_high_i16x16 = _mm256_cvtepu8_epi16(d_high_u8x16);
|
|
1294
1327
|
|
|
1295
1328
|
// Multiply i16×i16 and accumulate to i32 using MADD
|
|
1296
|
-
sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(
|
|
1297
|
-
sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(
|
|
1329
|
+
sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(c_low_i16x16, d_low_i16x16));
|
|
1330
|
+
sum_cd_i32x8 = _mm256_add_epi32(sum_cd_i32x8, _mm256_madd_epi16(c_high_i16x16, d_high_i16x16));
|
|
1298
1331
|
|
|
1299
1332
|
// Optimization: Use SAD for correction sums (5cy vs 24cy for 8× widenings)
|
|
1300
1333
|
// PSADBW sums 8× u8 values to a single i64 in each 64-bit lane
|
|
1301
|
-
sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(
|
|
1302
|
-
sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(
|
|
1303
|
-
sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(
|
|
1304
|
-
sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(
|
|
1334
|
+
sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(c_low_u8x16, zeros_u8x16));
|
|
1335
|
+
sum_cx_i64x2 = _mm_add_epi64(sum_cx_i64x2, _mm_sad_epu8(c_high_u8x16, zeros_u8x16));
|
|
1336
|
+
sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(d_low_u8x16, zeros_u8x16));
|
|
1337
|
+
sum_dx_i64x2 = _mm_add_epi64(sum_dx_i64x2, _mm_sad_epu8(d_high_u8x16, zeros_u8x16));
|
|
1305
1338
|
|
|
1306
1339
|
if (n_bytes) goto nk_dot_i4_haswell_cycle;
|
|
1307
1340
|
|
|
@@ -1347,20 +1380,20 @@ nk_dot_u4_haswell_cycle:
|
|
|
1347
1380
|
}
|
|
1348
1381
|
|
|
1349
1382
|
// Extract low and high nibbles
|
|
1350
|
-
__m128i
|
|
1351
|
-
__m128i
|
|
1352
|
-
__m128i
|
|
1353
|
-
__m128i
|
|
1383
|
+
__m128i a_low_u8x16 = _mm_and_si128(a_u4x32, nibble_mask_u8x16);
|
|
1384
|
+
__m128i a_high_u8x16 = _mm_and_si128(_mm_srli_epi16(a_u4x32, 4), nibble_mask_u8x16);
|
|
1385
|
+
__m128i b_low_u8x16 = _mm_and_si128(b_u4x32, nibble_mask_u8x16);
|
|
1386
|
+
__m128i b_high_u8x16 = _mm_and_si128(_mm_srli_epi16(b_u4x32, 4), nibble_mask_u8x16);
|
|
1354
1387
|
|
|
1355
1388
|
// Widen u8 to i16
|
|
1356
|
-
__m256i
|
|
1357
|
-
__m256i
|
|
1358
|
-
__m256i
|
|
1359
|
-
__m256i
|
|
1389
|
+
__m256i a_low_i16x16 = _mm256_cvtepu8_epi16(a_low_u8x16);
|
|
1390
|
+
__m256i a_high_i16x16 = _mm256_cvtepu8_epi16(a_high_u8x16);
|
|
1391
|
+
__m256i b_low_i16x16 = _mm256_cvtepu8_epi16(b_low_u8x16);
|
|
1392
|
+
__m256i b_high_i16x16 = _mm256_cvtepu8_epi16(b_high_u8x16);
|
|
1360
1393
|
|
|
1361
1394
|
// Multiply i16×i16 and accumulate to i32 using MADD
|
|
1362
|
-
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(
|
|
1363
|
-
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(
|
|
1395
|
+
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_low_i16x16, b_low_i16x16));
|
|
1396
|
+
sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(a_high_i16x16, b_high_i16x16));
|
|
1364
1397
|
|
|
1365
1398
|
if (n_bytes) goto nk_dot_u4_haswell_cycle;
|
|
1366
1399
|
|
|
@@ -1496,28 +1529,28 @@ NK_INTERNAL void nk_dot_i4x32_update_haswell(nk_dot_i4x32_state_haswell_t *state
|
|
|
1496
1529
|
__m128i b_i4x32 = b.xmm;
|
|
1497
1530
|
|
|
1498
1531
|
// Extract low and high nibbles
|
|
1499
|
-
__m128i
|
|
1500
|
-
__m128i
|
|
1501
|
-
__m128i
|
|
1502
|
-
__m128i
|
|
1532
|
+
__m128i a_low_u8x16 = _mm_and_si128(a_i4x32, nibble_mask_u8x16);
|
|
1533
|
+
__m128i a_high_u8x16 = _mm_and_si128(_mm_srli_epi16(a_i4x32, 4), nibble_mask_u8x16);
|
|
1534
|
+
__m128i b_low_u8x16 = _mm_and_si128(b_i4x32, nibble_mask_u8x16);
|
|
1535
|
+
__m128i b_high_u8x16 = _mm_and_si128(_mm_srli_epi16(b_i4x32, 4), nibble_mask_u8x16);
|
|
1503
1536
|
|
|
1504
1537
|
// XOR with 8 for algebraic transformation
|
|
1505
|
-
__m128i
|
|
1506
|
-
__m128i
|
|
1507
|
-
__m128i
|
|
1508
|
-
__m128i
|
|
1538
|
+
__m128i c_low_u8x16 = _mm_xor_si128(a_low_u8x16, xor_mask_u8x16);
|
|
1539
|
+
__m128i c_high_u8x16 = _mm_xor_si128(a_high_u8x16, xor_mask_u8x16);
|
|
1540
|
+
__m128i d_low_u8x16 = _mm_xor_si128(b_low_u8x16, xor_mask_u8x16);
|
|
1541
|
+
__m128i d_high_u8x16 = _mm_xor_si128(b_high_u8x16, xor_mask_u8x16);
|
|
1509
1542
|
|
|
1510
1543
|
// Widen u8 to i16 and multiply using MADD
|
|
1511
|
-
__m256i
|
|
1512
|
-
__m256i
|
|
1513
|
-
__m256i
|
|
1514
|
-
__m256i
|
|
1544
|
+
__m256i c_low_i16x16 = _mm256_cvtepu8_epi16(c_low_u8x16);
|
|
1545
|
+
__m256i c_high_i16x16 = _mm256_cvtepu8_epi16(c_high_u8x16);
|
|
1546
|
+
__m256i d_low_i16x16 = _mm256_cvtepu8_epi16(d_low_u8x16);
|
|
1547
|
+
__m256i d_high_i16x16 = _mm256_cvtepu8_epi16(d_high_u8x16);
|
|
1515
1548
|
|
|
1516
1549
|
// Multiply and accumulate (no SAD — correction deferred to finalize)
|
|
1517
1550
|
state->biased_product_sum_i32x8 = _mm256_add_epi32(state->biased_product_sum_i32x8,
|
|
1518
|
-
_mm256_madd_epi16(
|
|
1551
|
+
_mm256_madd_epi16(c_low_i16x16, d_low_i16x16));
|
|
1519
1552
|
state->biased_product_sum_i32x8 = _mm256_add_epi32(state->biased_product_sum_i32x8,
|
|
1520
|
-
_mm256_madd_epi16(
|
|
1553
|
+
_mm256_madd_epi16(c_high_i16x16, d_high_i16x16));
|
|
1521
1554
|
}
|
|
1522
1555
|
|
|
1523
1556
|
NK_INTERNAL void nk_dot_i4x32_finalize_haswell( //
|
|
@@ -1585,20 +1618,22 @@ NK_INTERNAL void nk_dot_u4x32_update_haswell(nk_dot_u4x32_state_haswell_t *state
|
|
|
1585
1618
|
__m128i b_u4x32 = b.xmm;
|
|
1586
1619
|
|
|
1587
1620
|
// Extract low and high nibbles
|
|
1588
|
-
__m128i
|
|
1589
|
-
__m128i
|
|
1590
|
-
__m128i
|
|
1591
|
-
__m128i
|
|
1621
|
+
__m128i a_low_u8x16 = _mm_and_si128(a_u4x32, nibble_mask_u8x16);
|
|
1622
|
+
__m128i a_high_u8x16 = _mm_and_si128(_mm_srli_epi16(a_u4x32, 4), nibble_mask_u8x16);
|
|
1623
|
+
__m128i b_low_u8x16 = _mm_and_si128(b_u4x32, nibble_mask_u8x16);
|
|
1624
|
+
__m128i b_high_u8x16 = _mm_and_si128(_mm_srli_epi16(b_u4x32, 4), nibble_mask_u8x16);
|
|
1592
1625
|
|
|
1593
1626
|
// Widen u8 to i16
|
|
1594
|
-
__m256i
|
|
1595
|
-
__m256i
|
|
1596
|
-
__m256i
|
|
1597
|
-
__m256i
|
|
1627
|
+
__m256i a_low_i16x16 = _mm256_cvtepu8_epi16(a_low_u8x16);
|
|
1628
|
+
__m256i a_high_i16x16 = _mm256_cvtepu8_epi16(a_high_u8x16);
|
|
1629
|
+
__m256i b_low_i16x16 = _mm256_cvtepu8_epi16(b_low_u8x16);
|
|
1630
|
+
__m256i b_high_i16x16 = _mm256_cvtepu8_epi16(b_high_u8x16);
|
|
1598
1631
|
|
|
1599
1632
|
// Multiply and accumulate
|
|
1600
|
-
state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8,
|
|
1601
|
-
|
|
1633
|
+
state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8,
|
|
1634
|
+
_mm256_madd_epi16(a_low_i16x16, b_low_i16x16));
|
|
1635
|
+
state->product_sum_i32x8 = _mm256_add_epi32(state->product_sum_i32x8,
|
|
1636
|
+
_mm256_madd_epi16(a_high_i16x16, b_high_i16x16));
|
|
1602
1637
|
}
|
|
1603
1638
|
|
|
1604
1639
|
NK_INTERNAL void nk_dot_u4x32_finalize_haswell( //
|
|
@@ -1619,23 +1654,23 @@ NK_INTERNAL void nk_dot_u4x32_finalize_haswell(
|
|
|
1619
1654
|
_mm256_extracti128_si256(state_d->product_sum_i32x8, 1));
|
|
1620
1655
|
|
|
1621
1656
|
// 4-way transpose to get [a,b,c,d] in lanes
|
|
1622
|
-
__m128i
|
|
1623
|
-
__m128i
|
|
1624
|
-
__m128i
|
|
1625
|
-
__m128i
|
|
1626
|
-
__m128i
|
|
1627
|
-
__m128i
|
|
1628
|
-
__m128i
|
|
1629
|
-
__m128i
|
|
1657
|
+
__m128i transpose_ab_low_i32x4 = _mm_unpacklo_epi32(product_a_i32x4, product_b_i32x4);
|
|
1658
|
+
__m128i transpose_cd_low_i32x4 = _mm_unpacklo_epi32(product_c_i32x4, product_d_i32x4);
|
|
1659
|
+
__m128i transpose_ab_high_i32x4 = _mm_unpackhi_epi32(product_a_i32x4, product_b_i32x4);
|
|
1660
|
+
__m128i transpose_cd_high_i32x4 = _mm_unpackhi_epi32(product_c_i32x4, product_d_i32x4);
|
|
1661
|
+
__m128i product_lane0_i32x4 = _mm_unpacklo_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
1662
|
+
__m128i product_lane1_i32x4 = _mm_unpackhi_epi64(transpose_ab_low_i32x4, transpose_cd_low_i32x4);
|
|
1663
|
+
__m128i product_lane2_i32x4 = _mm_unpacklo_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
1664
|
+
__m128i product_lane3_i32x4 = _mm_unpackhi_epi64(transpose_ab_high_i32x4, transpose_cd_high_i32x4);
|
|
1630
1665
|
|
|
1631
1666
|
// Sum product lanes
|
|
1632
|
-
result->xmm = _mm_add_epi32(_mm_add_epi32(
|
|
1633
|
-
_mm_add_epi32(
|
|
1667
|
+
result->xmm = _mm_add_epi32(_mm_add_epi32(product_lane0_i32x4, product_lane1_i32x4),
|
|
1668
|
+
_mm_add_epi32(product_lane2_i32x4, product_lane3_i32x4));
|
|
1634
1669
|
}
|
|
1635
1670
|
|
|
1636
|
-
#pragma endregion
|
|
1671
|
+
#pragma endregion I8 and U8 Integers
|
|
1637
1672
|
|
|
1638
|
-
#pragma region
|
|
1673
|
+
#pragma region Binary
|
|
1639
1674
|
|
|
1640
1675
|
NK_PUBLIC void nk_dot_u1_haswell(nk_u1x8_t const *a, nk_u1x8_t const *b, nk_size_t n_bits, nk_u32_t *result) {
|
|
1641
1676
|
nk_size_t n_bytes = nk_size_divide_round_up_(n_bits, NK_BITS_PER_BYTE);
|
|
@@ -1671,7 +1706,7 @@ NK_INTERNAL void nk_dot_u1x128_finalize_haswell( //
|
|
|
1671
1706
|
result->u32s[3] = state_d->dot_count;
|
|
1672
1707
|
}
|
|
1673
1708
|
|
|
1674
|
-
#pragma endregion
|
|
1709
|
+
#pragma endregion Binary
|
|
1675
1710
|
|
|
1676
1711
|
#if defined(__clang__)
|
|
1677
1712
|
#pragma clang attribute pop
|