numkong 7.4.5 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +99 -5
- package/c/dispatch_e5m2.c +23 -3
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +1167 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +33 -11
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +41 -3
- package/include/numkong/each/serial.h +39 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +15 -4
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +225 -161
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +301 -0
- package/include/numkong/spatials/serial.h +39 -0
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +54 -4
- package/include/numkong/tensor.hpp +107 -23
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +59 -14
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -175,60 +175,63 @@ NK_INTERNAL __m256i nk_f32x16_to_bf16x16_skylake_(__m512 a) {
|
|
|
175
175
|
return _mm512_cvtepi32_epi16(x);
|
|
176
176
|
}
|
|
177
177
|
|
|
178
|
-
/** @brief Convert 16x e4m3 → 16x f32 via
|
|
179
|
-
* E4M3
|
|
180
|
-
*
|
|
178
|
+
/** @brief Convert 16x e4m3 → 16x f32 via Giesen-style fake-F16 cast (AVX-512F + AVX-512BW/VL).
|
|
179
|
+
* E4M3 `byte = S EEEE MMM` (bias 7). Shifting the magnitude into F16 positions
|
|
180
|
+
* `((byte & 0x7F) << 7) | ((byte & 0x80) << 8)` yields a fake F16 whose F16 value
|
|
181
|
+
* differs from the true E4M3 magnitude by exactly 2⁸ (bias delta 15 − 7). The
|
|
182
|
+
* fake F16 is widened via `vcvtph2ps` and corrected by ×256 in F32. Subnormal
|
|
183
|
+
* handling falls out for free via F16 subnormal semantics. NaN (|byte|==0x7F) is
|
|
184
|
+
* the sole E4M3 special value that would misinterpret as finite; blended
|
|
185
|
+
 * explicitly as the F16 quiet NaN `0x7E00` before widening (the input sign is not kept). */
|
|
181
186
|
NK_INTERNAL __m512 nk_e4m3x16_to_f32x16_skylake_(__m128i e4m3_i8x16) {
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
__m512
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
187
|
+
__m256i const magnitude_mask_u16x16 = _mm256_set1_epi16(0x7F);
|
|
188
|
+
__m256i const sign_mask_u16x16 = _mm256_set1_epi16((short)0x80);
|
|
189
|
+
__m256i const f16_nan_u16x16 = _mm256_set1_epi16(0x7E00);
|
|
190
|
+
__m256i word_u16x16 = _mm256_cvtepu8_epi16(e4m3_i8x16);
|
|
191
|
+
__m256i magnitude_u16x16 = _mm256_and_si256(word_u16x16, magnitude_mask_u16x16);
|
|
192
|
+
__mmask16 is_nan = _mm256_cmpeq_epi16_mask(magnitude_u16x16, magnitude_mask_u16x16);
|
|
193
|
+
__m256i shifted_magnitude_u16x16 = _mm256_slli_epi16(magnitude_u16x16, 7);
|
|
194
|
+
__m256i shifted_sign_u16x16 = _mm256_slli_epi16(_mm256_and_si256(word_u16x16, sign_mask_u16x16), 8);
|
|
195
|
+
__m256i f16_bits_u16x16 = _mm256_or_si256(shifted_magnitude_u16x16, shifted_sign_u16x16);
|
|
196
|
+
f16_bits_u16x16 = _mm256_mask_mov_epi16(f16_bits_u16x16, is_nan, f16_nan_u16x16);
|
|
197
|
+
__m512 fake_f32x16 = _mm512_cvtph_ps(f16_bits_u16x16);
|
|
198
|
+
return _mm512_mul_ps(fake_f32x16, _mm512_set1_ps(256.0f));
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
/** @brief Convert 16x e4m3 → 16x f16 via arithmetic + 8-entry subnormal LUT (AVX-512BW + AVX-512VL).
|
|
202
|
+
* E4M3: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
|
|
203
|
+
* Normal (exp != 0): F16 = ((lower7 << 7) + 0x2000) | (sign << 8) — bias delta 8 added at the
|
|
204
|
+
* exp-position (8 << 10 = 0x2000) after placing magnitude bits at F16 positions 13..7.
|
|
205
|
+
* Subnormal (exp == 0): looked up from 8-entry F16 LUT — values 0, 1/512, 2/512, …, 7/512 encoded as
|
|
206
|
+
* F16 normals (the smallest E4M3 subnormal 1/512 = 2⁻⁹ is well within F16 normal range).
|
|
207
|
+
* NaN (|byte| == 0x7F): blended in as F16 quiet NaN with original sign. */
|
|
208
|
+
NK_INTERNAL __m256i nk_e4m3x16_to_f16x16_skylake_(__m128i e4m3_u8x16) {
|
|
209
|
+
__m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3_u8x16);
|
|
210
|
+
__m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
|
|
211
|
+
__m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
|
|
212
|
+
__m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 7), _mm256_set1_epi16(0x2000));
|
|
213
|
+
__m256i subn_lut_i16x16 = _mm256_set_epi16( //
|
|
214
|
+
0x2300, 0x2200, 0x2100, 0x2000, 0x1E00, 0x1C00, 0x1800, 0x0000, 0x2300, 0x2200, 0x2100, 0x2000, 0x1E00, 0x1C00,
|
|
215
|
+
0x1800, 0x0000);
|
|
216
|
+
__m256i mant_idx_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x07));
|
|
217
|
+
__m256i subn_abs_i16x16 = _mm256_permutexvar_epi16(mant_idx_i16x16, subn_lut_i16x16);
|
|
218
|
+
__mmask16 is_subnormal = _mm256_testn_epi16_mask(e4m3_i16x16, _mm256_set1_epi16(0x78));
|
|
219
|
+
__m256i abs_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_abs_i16x16, subn_abs_i16x16);
|
|
220
|
+
__m256i shifted_sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
|
|
221
|
+
__m256i result_i16x16 = _mm256_or_si256(abs_i16x16, shifted_sign_i16x16);
|
|
222
|
+
__mmask16 is_nan = _mm256_cmpeq_epi16_mask(lower7_i16x16, _mm256_set1_epi16(0x7F));
|
|
223
|
+
__m256i nan_i16x16 = _mm256_or_si256(shifted_sign_i16x16, _mm256_set1_epi16(0x7E00));
|
|
224
|
+
return _mm256_mask_blend_epi16(is_nan, result_i16x16, nan_i16x16);
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
/** @brief Convert 16x e5m2 → 16x f32 via free-shift widen (AVX-512 + F16C).
|
|
228
|
+
* E5M2 shares F16's exponent bias (15): `(byte << 8)` is the matching F16 bit
|
|
229
|
+
* pattern for every E5M2 value (normals, subnormals, zero, ±Inf, NaN — all
|
|
230
|
+
* bit-exact). Widen u8 → u16, shift, then VCVTPH2PS to F32. Three ops total. */
|
|
214
231
|
NK_INTERNAL __m512 nk_e5m2x16_to_f32x16_skylake_(__m128i e5m2_i8x16) {
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
__m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e5m2_i32x16, 2), _mm512_set1_epi32(0x1F));
|
|
219
|
-
__m512i mantissa_i32x16 = _mm512_and_si512(e5m2_i32x16, _mm512_set1_epi32(0x03));
|
|
220
|
-
__m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e5m2_i32x16, 7), 31);
|
|
221
|
-
|
|
222
|
-
// Normal path: sign | ((exp+112)<<23) | (mantissa<<21)
|
|
223
|
-
__m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(112)), 23);
|
|
224
|
-
__m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 21);
|
|
225
|
-
__m512 result_f32x16 = _mm512_castsi512_ps(
|
|
226
|
-
_mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
|
|
227
|
-
|
|
228
|
-
// Subnormal fix: for exp==0 lanes, replace with (mantissa / 65536) | sign using masked OR
|
|
229
|
-
__mmask16 is_subnormal = _mm512_testn_epi32_mask(e5m2_i32x16, _mm512_set1_epi32(0x7C));
|
|
230
|
-
__m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 65536.0f));
|
|
231
|
-
return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
|
|
232
|
+
__m256i e5m2_u16x16 = _mm256_cvtepu8_epi16(e5m2_i8x16);
|
|
233
|
+
__m256i f16_bits_u16x16 = _mm256_slli_epi16(e5m2_u16x16, 8);
|
|
234
|
+
return _mm512_cvtph_ps(f16_bits_u16x16);
|
|
232
235
|
}
|
|
233
236
|
|
|
234
237
|
/** @brief Convert 16x e2m3 → 16x f32 via bit manipulation (AVX-512).
|
|
@@ -660,6 +663,18 @@ NK_INTERNAL void nk_partial_load_e4m3x16_to_f32x16_skylake_(void const *src, nk_
|
|
|
660
663
|
dst->zmm_ps = nk_e4m3x16_to_f32x16_skylake_(e4m3_partial.xmm);
|
|
661
664
|
}
|
|
662
665
|
|
|
666
|
+
/** @brief Load 16 e4m3 values and convert to 16 f16 (Skylake AVX-512BW). */
|
|
667
|
+
NK_INTERNAL void nk_load_e4m3x16_to_f16x16_skylake_(void const *src, nk_b256_vec_t *dst) {
|
|
668
|
+
dst->ymm = nk_e4m3x16_to_f16x16_skylake_(_mm_loadu_si128((__m128i const *)src));
|
|
669
|
+
}
|
|
670
|
+
|
|
671
|
+
/** @brief Partial load of up to 16 e4m3 values with conversion to f16 (Skylake AVX-512BW). */
|
|
672
|
+
NK_INTERNAL void nk_partial_load_e4m3x16_to_f16x16_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
|
|
673
|
+
nk_b128_vec_t e4m3_partial;
|
|
674
|
+
nk_partial_load_b8x16_skylake_(src, &e4m3_partial, n);
|
|
675
|
+
dst->ymm = nk_e4m3x16_to_f16x16_skylake_(e4m3_partial.xmm);
|
|
676
|
+
}
|
|
677
|
+
|
|
663
678
|
/** @brief Load 16 e5m2 values and convert to 16 f32 (Skylake AVX-512). */
|
|
664
679
|
NK_INTERNAL void nk_load_e5m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
|
|
665
680
|
dst->zmm_ps = nk_e5m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
|
package/include/numkong/cast.h
CHANGED
|
@@ -175,6 +175,7 @@ NK_PUBLIC void nk_cast_v128relaxed(void const *from, nk_dtype_t from_type, nk_si
|
|
|
175
175
|
#include "numkong/cast/icelake.h"
|
|
176
176
|
#include "numkong/cast/sapphire.h"
|
|
177
177
|
#include "numkong/cast/rvv.h"
|
|
178
|
+
#include "numkong/cast/v128relaxed.h"
|
|
178
179
|
#include "numkong/cast/powervsx.h"
|
|
179
180
|
#include "numkong/cast/loongsonasx.h"
|
|
180
181
|
|
|
@@ -52,9 +52,10 @@
|
|
|
52
52
|
#if NK_TARGET_SMEF64
|
|
53
53
|
|
|
54
54
|
#include "numkong/types.h"
|
|
55
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
55
56
|
#include "numkong/spatial/neon.h" // `nk_f64_sqrt_neon`
|
|
56
|
-
#include "numkong/dots/sme.h" // nk_sme_zero_za64_tile_0_
|
|
57
|
-
#include "numkong/curved/serial.h" // `nk_bilinear_f64_serial
|
|
57
|
+
#include "numkong/dots/sme.h" // `nk_sme_zero_za64_tile_0_`
|
|
58
|
+
#include "numkong/curved/serial.h" // `nk_bilinear_f64_serial`
|
|
58
59
|
|
|
59
60
|
#if defined(__cplusplus)
|
|
60
61
|
extern "C" {
|
|
@@ -90,8 +91,8 @@ NK_PUBLIC void nk_dot2_f64_sve_accumulate_(svbool_t predicate_b64x, svfloat64_t
|
|
|
90
91
|
* @brief f32 bilinear: GEMV via FMOPA (widening f32→f64, exact accumulation).
|
|
91
92
|
* ZA0.D = C staging, ZA1.D = GEMV accumulator.
|
|
92
93
|
*/
|
|
93
|
-
|
|
94
|
-
nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions, nk_f64_t *result) {
|
|
94
|
+
__arm_new("za") static void nk_bilinear_f32_smef64_streaming_( //
|
|
95
|
+
nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions, nk_f64_t *result) NK_STREAMING_ {
|
|
95
96
|
svbool_t predicate_body_b64x = svptrue_b64();
|
|
96
97
|
nk_size_t tile_dimension = svcntd();
|
|
97
98
|
nk_f64_t outer_sum_f64 = 0.0;
|
|
@@ -124,24 +125,25 @@ __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32_smef64_strea
|
|
|
124
125
|
svfloat64_t v_f64x = svread_ver_za64_f64_m(svdup_f64(0.0), row_predicate_b64x, 1, 0);
|
|
125
126
|
svfloat64_t a_f64x = svcvt_f64_f32_x(
|
|
126
127
|
row_predicate_b64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_b64x, (nk_u32_t const *)(a + row))));
|
|
127
|
-
outer_sum_f64 +=
|
|
128
|
+
outer_sum_f64 += nk_svaddv_f64_(predicate_body_b64x, svmul_f64_x(row_predicate_b64x, a_f64x, v_f64x));
|
|
128
129
|
}
|
|
129
130
|
|
|
130
131
|
*result = outer_sum_f64;
|
|
131
132
|
}
|
|
132
133
|
|
|
133
|
-
NK_PUBLIC void nk_bilinear_f32_smef64(
|
|
134
|
-
|
|
134
|
+
NK_PUBLIC void nk_bilinear_f32_smef64( //
|
|
135
|
+
nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions, nk_f64_t *result) {
|
|
136
|
+
nk_sme_start_streaming_();
|
|
135
137
|
nk_bilinear_f32_smef64_streaming_(a, b, c, dimensions, result);
|
|
138
|
+
nk_sme_stop_streaming_();
|
|
136
139
|
}
|
|
137
140
|
|
|
138
141
|
/**
|
|
139
142
|
* @brief f32 Mahalanobis: GEMV v = C×d via FMOPA, where d = a − b (exact in f64).
|
|
140
143
|
* ZA0.D = C staging, ZA1.D = GEMV accumulator.
|
|
141
144
|
*/
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
nk_size_t dimensions) {
|
|
145
|
+
__arm_new("za") static nk_f64_t nk_mahalanobis_f32_smef64_streaming_( //
|
|
146
|
+
nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions) NK_STREAMING_ {
|
|
145
147
|
|
|
146
148
|
svbool_t predicate_body_b64x = svptrue_b64();
|
|
147
149
|
nk_size_t tile_dimension = svcntd();
|
|
@@ -179,15 +181,17 @@ __arm_locally_streaming __arm_new("za") static nk_f64_t
|
|
|
179
181
|
svfloat64_t b_f64x = svcvt_f64_f32_x(
|
|
180
182
|
row_predicate_b64x, svreinterpret_f32_u64(svld1uw_u64(row_predicate_b64x, (nk_u32_t const *)(b + row))));
|
|
181
183
|
svfloat64_t d_f64x = svsub_f64_x(row_predicate_b64x, a_f64x, b_f64x);
|
|
182
|
-
outer_sum_f64 +=
|
|
184
|
+
outer_sum_f64 += nk_svaddv_f64_(predicate_body_b64x, svmul_f64_x(row_predicate_b64x, d_f64x, v_f64x));
|
|
183
185
|
}
|
|
184
186
|
|
|
185
187
|
return outer_sum_f64;
|
|
186
188
|
}
|
|
187
189
|
|
|
188
|
-
NK_PUBLIC void nk_mahalanobis_f32_smef64(
|
|
189
|
-
|
|
190
|
+
NK_PUBLIC void nk_mahalanobis_f32_smef64( //
|
|
191
|
+
nk_f32_t const *a, nk_f32_t const *b, nk_f32_t const *c, nk_size_t dimensions, nk_f64_t *result) {
|
|
192
|
+
nk_sme_start_streaming_();
|
|
190
193
|
nk_f64_t quadratic = nk_mahalanobis_f32_smef64_streaming_(a, b, c, dimensions);
|
|
194
|
+
nk_sme_stop_streaming_();
|
|
191
195
|
*result = nk_f64_sqrt_neon(quadratic > 0 ? quadratic : 0);
|
|
192
196
|
}
|
|
193
197
|
|
|
@@ -195,9 +199,8 @@ NK_PUBLIC void nk_mahalanobis_f32_smef64(nk_f32_t const *a, nk_f32_t const *b, n
|
|
|
195
199
|
* @brief f64 bilinear: row-by-row streaming SVE with Dot2 compensation.
|
|
196
200
|
* 4-row fast path shares b_f64x loads; 1-row tail for remainder.
|
|
197
201
|
*/
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
nk_f64_t *result) {
|
|
202
|
+
static void nk_bilinear_f64_smef64_ssve_( //
|
|
203
|
+
nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t dimensions, nk_f64_t *result) NK_STREAMING_ {
|
|
201
204
|
svbool_t predicate_all_b64x = svptrue_b64();
|
|
202
205
|
nk_f64_t outer_sum = 0.0, outer_comp = 0.0;
|
|
203
206
|
nk_size_t row = 0;
|
|
@@ -226,14 +229,18 @@ __arm_locally_streaming static void nk_bilinear_f64_smef64_streaming_(nk_f64_t c
|
|
|
226
229
|
predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
227
230
|
}
|
|
228
231
|
|
|
229
|
-
nk_f64_dot2_(
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
nk_f64_dot2_(
|
|
236
|
-
|
|
232
|
+
nk_f64_dot2_(
|
|
233
|
+
&outer_sum, &outer_comp, a0,
|
|
234
|
+
nk_svaddv_f64_(predicate_all_b64x, sum_0_f64x) + nk_svaddv_f64_(predicate_all_b64x, compensation_0_f64x));
|
|
235
|
+
nk_f64_dot2_(
|
|
236
|
+
&outer_sum, &outer_comp, a1,
|
|
237
|
+
nk_svaddv_f64_(predicate_all_b64x, sum_1_f64x) + nk_svaddv_f64_(predicate_all_b64x, compensation_1_f64x));
|
|
238
|
+
nk_f64_dot2_(
|
|
239
|
+
&outer_sum, &outer_comp, a2,
|
|
240
|
+
nk_svaddv_f64_(predicate_all_b64x, sum_2_f64x) + nk_svaddv_f64_(predicate_all_b64x, compensation_2_f64x));
|
|
241
|
+
nk_f64_dot2_(
|
|
242
|
+
&outer_sum, &outer_comp, a3,
|
|
243
|
+
nk_svaddv_f64_(predicate_all_b64x, sum_3_f64x) + nk_svaddv_f64_(predicate_all_b64x, compensation_3_f64x));
|
|
237
244
|
}
|
|
238
245
|
|
|
239
246
|
// 1-row tail
|
|
@@ -250,24 +257,27 @@ __arm_locally_streaming static void nk_bilinear_f64_smef64_streaming_(nk_f64_t c
|
|
|
250
257
|
predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
251
258
|
}
|
|
252
259
|
|
|
253
|
-
nk_f64_t cb_j =
|
|
260
|
+
nk_f64_t cb_j = nk_svaddv_f64_(predicate_all_b64x, sum_f64x) +
|
|
261
|
+
nk_svaddv_f64_(predicate_all_b64x, compensation_f64x);
|
|
254
262
|
nk_f64_dot2_(&outer_sum, &outer_comp, a[row], cb_j);
|
|
255
263
|
}
|
|
256
264
|
|
|
257
265
|
*result = outer_sum + outer_comp;
|
|
258
266
|
}
|
|
259
267
|
|
|
260
|
-
NK_PUBLIC void nk_bilinear_f64_smef64(
|
|
261
|
-
|
|
262
|
-
|
|
268
|
+
NK_PUBLIC void nk_bilinear_f64_smef64( //
|
|
269
|
+
nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t dimensions, nk_f64_t *result) {
|
|
270
|
+
nk_sme_start_streaming_();
|
|
271
|
+
nk_bilinear_f64_smef64_ssve_(a, b, c, dimensions, result);
|
|
272
|
+
nk_sme_stop_streaming_();
|
|
263
273
|
}
|
|
264
274
|
|
|
265
275
|
/**
|
|
266
276
|
* @brief f64 Mahalanobis: row-by-row streaming SVE with Dot2 compensation.
|
|
267
277
|
* 4-row fast path shares (a−b) column vector; 1-row tail for remainder.
|
|
268
278
|
*/
|
|
269
|
-
|
|
270
|
-
|
|
279
|
+
static nk_f64_t nk_mahalanobis_f64_smef64_ssve_( //
|
|
280
|
+
nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t dimensions) NK_STREAMING_ {
|
|
271
281
|
svbool_t predicate_all_b64x = svptrue_b64();
|
|
272
282
|
nk_f64_t outer_sum = 0.0, outer_comp = 0.0;
|
|
273
283
|
nk_size_t row = 0;
|
|
@@ -298,14 +308,18 @@ __arm_locally_streaming static nk_f64_t nk_mahalanobis_f64_smef64_streaming_(nk_
|
|
|
298
308
|
predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
299
309
|
}
|
|
300
310
|
|
|
301
|
-
nk_f64_dot2_(
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
nk_f64_dot2_(
|
|
308
|
-
|
|
311
|
+
nk_f64_dot2_(
|
|
312
|
+
&outer_sum, &outer_comp, d0,
|
|
313
|
+
nk_svaddv_f64_(predicate_all_b64x, sum_0_f64x) + nk_svaddv_f64_(predicate_all_b64x, compensation_0_f64x));
|
|
314
|
+
nk_f64_dot2_(
|
|
315
|
+
&outer_sum, &outer_comp, d1,
|
|
316
|
+
nk_svaddv_f64_(predicate_all_b64x, sum_1_f64x) + nk_svaddv_f64_(predicate_all_b64x, compensation_1_f64x));
|
|
317
|
+
nk_f64_dot2_(
|
|
318
|
+
&outer_sum, &outer_comp, d2,
|
|
319
|
+
nk_svaddv_f64_(predicate_all_b64x, sum_2_f64x) + nk_svaddv_f64_(predicate_all_b64x, compensation_2_f64x));
|
|
320
|
+
nk_f64_dot2_(
|
|
321
|
+
&outer_sum, &outer_comp, d3,
|
|
322
|
+
nk_svaddv_f64_(predicate_all_b64x, sum_3_f64x) + nk_svaddv_f64_(predicate_all_b64x, compensation_3_f64x));
|
|
309
323
|
}
|
|
310
324
|
|
|
311
325
|
// 1-row tail
|
|
@@ -324,16 +338,19 @@ __arm_locally_streaming static nk_f64_t nk_mahalanobis_f64_smef64_streaming_(nk_
|
|
|
324
338
|
predicate_b64x = svwhilelt_b64(j, dimensions);
|
|
325
339
|
}
|
|
326
340
|
|
|
327
|
-
nk_f64_t cb_j =
|
|
341
|
+
nk_f64_t cb_j = nk_svaddv_f64_(predicate_all_b64x, sum_f64x) +
|
|
342
|
+
nk_svaddv_f64_(predicate_all_b64x, compensation_f64x);
|
|
328
343
|
nk_f64_dot2_(&outer_sum, &outer_comp, diff_row, cb_j);
|
|
329
344
|
}
|
|
330
345
|
|
|
331
346
|
return outer_sum + outer_comp;
|
|
332
347
|
}
|
|
333
348
|
|
|
334
|
-
NK_PUBLIC void nk_mahalanobis_f64_smef64(
|
|
335
|
-
|
|
336
|
-
|
|
349
|
+
NK_PUBLIC void nk_mahalanobis_f64_smef64( //
|
|
350
|
+
nk_f64_t const *a, nk_f64_t const *b, nk_f64_t const *c, nk_size_t dimensions, nk_f64_t *result) {
|
|
351
|
+
nk_sme_start_streaming_();
|
|
352
|
+
nk_f64_t quadratic = nk_mahalanobis_f64_smef64_ssve_(a, b, c, dimensions);
|
|
353
|
+
nk_sme_stop_streaming_();
|
|
337
354
|
*result = nk_f64_sqrt_neon(quadratic > 0 ? quadratic : 0);
|
|
338
355
|
}
|
|
339
356
|
|
|
@@ -341,11 +358,9 @@ NK_PUBLIC void nk_mahalanobis_f64_smef64(nk_f64_t const *a, nk_f64_t const *b, n
|
|
|
341
358
|
* @brief f32c bilinear: complex GEMV via FMOPA (widening f32→f64).
|
|
342
359
|
* ZA0.D = C staging, ZA1.D = v_real accumulator, ZA2.D = v_imag accumulator.
|
|
343
360
|
*/
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
nk_size_t dimensions,
|
|
348
|
-
nk_f64c_t *results) {
|
|
361
|
+
__arm_new("za") static void nk_bilinear_f32c_smef64_streaming_( //
|
|
362
|
+
nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs, nk_size_t dimensions,
|
|
363
|
+
nk_f64c_t *results) NK_STREAMING_ {
|
|
349
364
|
svbool_t predicate_body_b64x = svptrue_b64();
|
|
350
365
|
nk_size_t tile_dimension = svcntd();
|
|
351
366
|
nk_f64_t outer_sum_real_f64 = 0.0, outer_sum_imag_f64 = 0.0;
|
|
@@ -407,10 +422,10 @@ __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_stre
|
|
|
407
422
|
svfloat64_t a_im_f64x = svcvt_f64_f32_x(row_predicate_b64x, svtrn2_f32(a_f32x, a_f32x));
|
|
408
423
|
|
|
409
424
|
// Complex dot: a × v
|
|
410
|
-
outer_sum_real_f64 +=
|
|
425
|
+
outer_sum_real_f64 += nk_svaddv_f64_(
|
|
411
426
|
predicate_body_b64x, svsub_f64_x(row_predicate_b64x, svmul_f64_x(row_predicate_b64x, a_re_f64x, v_re_f64x),
|
|
412
427
|
svmul_f64_x(row_predicate_b64x, a_im_f64x, v_im_f64x)));
|
|
413
|
-
outer_sum_imag_f64 +=
|
|
428
|
+
outer_sum_imag_f64 += nk_svaddv_f64_(
|
|
414
429
|
predicate_body_b64x, svadd_f64_x(row_predicate_b64x, svmul_f64_x(row_predicate_b64x, a_re_f64x, v_im_f64x),
|
|
415
430
|
svmul_f64_x(row_predicate_b64x, a_im_f64x, v_re_f64x)));
|
|
416
431
|
}
|
|
@@ -419,19 +434,21 @@ __arm_locally_streaming __arm_new("za") static void nk_bilinear_f32c_smef64_stre
|
|
|
419
434
|
results->imag = outer_sum_imag_f64;
|
|
420
435
|
}
|
|
421
436
|
|
|
422
|
-
NK_PUBLIC void nk_bilinear_f32c_smef64(
|
|
423
|
-
|
|
437
|
+
NK_PUBLIC void nk_bilinear_f32c_smef64( //
|
|
438
|
+
nk_f32c_t const *a_pairs, nk_f32c_t const *b_pairs, nk_f32c_t const *c_pairs, nk_size_t dimensions,
|
|
439
|
+
nk_f64c_t *results) {
|
|
440
|
+
nk_sme_start_streaming_();
|
|
424
441
|
nk_bilinear_f32c_smef64_streaming_(a_pairs, b_pairs, c_pairs, dimensions, results);
|
|
442
|
+
nk_sme_stop_streaming_();
|
|
425
443
|
}
|
|
426
444
|
|
|
427
445
|
/**
|
|
428
446
|
* @brief f64c bilinear: interleaved Dot2 with permute + deferred XOR sign-flip.
|
|
429
447
|
* 2 accumulators instead of 4, halving inner loop work (~15 vs ~28 SVE ops).
|
|
430
448
|
*/
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
nk_f64c_t *results) {
|
|
449
|
+
static void nk_bilinear_f64c_smef64_ssve_( //
|
|
450
|
+
nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_f64c_t const *c_pairs, nk_size_t dimensions,
|
|
451
|
+
nk_f64c_t *results) NK_STREAMING_ {
|
|
435
452
|
svbool_t predicate_all_b64x = svptrue_b64();
|
|
436
453
|
nk_f64_t outer_sum_real = 0.0, outer_comp_real = 0.0;
|
|
437
454
|
nk_f64_t outer_sum_imag = 0.0, outer_comp_imag = 0.0;
|
|
@@ -474,10 +491,10 @@ __arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t
|
|
|
474
491
|
sveor_u64_x(predicate_all_b64x, svreinterpret_u64_f64(sum_real_f64x), sign_mask_u64x));
|
|
475
492
|
comp_real_f64x = svreinterpret_f64_u64(
|
|
476
493
|
sveor_u64_x(predicate_all_b64x, svreinterpret_u64_f64(comp_real_f64x), sign_mask_u64x));
|
|
477
|
-
nk_f64_t inner_real =
|
|
478
|
-
|
|
479
|
-
nk_f64_t inner_imag =
|
|
480
|
-
|
|
494
|
+
nk_f64_t inner_real = nk_svaddv_f64_(predicate_all_b64x,
|
|
495
|
+
svadd_f64_x(predicate_all_b64x, sum_real_f64x, comp_real_f64x));
|
|
496
|
+
nk_f64_t inner_imag = nk_svaddv_f64_(predicate_all_b64x,
|
|
497
|
+
svadd_f64_x(predicate_all_b64x, sum_imag_f64x, comp_imag_f64x));
|
|
481
498
|
|
|
482
499
|
// Outer Dot2 complex multiply: a × inner
|
|
483
500
|
nk_f64_dot2_(&outer_sum_real, &outer_comp_real, a_real, inner_real);
|
|
@@ -490,9 +507,12 @@ __arm_locally_streaming static void nk_bilinear_f64c_smef64_streaming_(nk_f64c_t
|
|
|
490
507
|
results->imag = outer_sum_imag + outer_comp_imag;
|
|
491
508
|
}
|
|
492
509
|
|
|
493
|
-
NK_PUBLIC void nk_bilinear_f64c_smef64(
|
|
494
|
-
|
|
495
|
-
|
|
510
|
+
NK_PUBLIC void nk_bilinear_f64c_smef64( //
|
|
511
|
+
nk_f64c_t const *a_pairs, nk_f64c_t const *b_pairs, nk_f64c_t const *c_pairs, nk_size_t dimensions,
|
|
512
|
+
nk_f64c_t *results) {
|
|
513
|
+
nk_sme_start_streaming_();
|
|
514
|
+
nk_bilinear_f64c_smef64_ssve_(a_pairs, b_pairs, c_pairs, dimensions, results);
|
|
515
|
+
nk_sme_stop_streaming_();
|
|
496
516
|
}
|
|
497
517
|
|
|
498
518
|
#if defined(__clang__)
|
|
@@ -111,6 +111,7 @@ This processes 64 E4M3 bytes per iteration in u8, doubling the element density o
|
|
|
111
111
|
|
|
112
112
|
`nk_dot_e5m2_genoa` converts FP8 values to BF16, then accumulates via `VDPBF16PS`, reusing Genoa's BF16 dot-product instruction for FP8 types.
|
|
113
113
|
Each `VDPBF16PS` fuses two BF16 multiply-adds per 32-bit lane at 6-cycle throughput.
|
|
114
|
+
On Skylake-X–class CPUs without BF16 dot-product hardware, `nk_dot_e4m3_skylake` / `nk_dot_e5m2_skylake` (and their Haswell counterparts `nk_dot_e4m3_haswell` / `nk_dot_e5m2_haswell`) take a different route. Each FP8 byte is reinterpreted as an F16 bit pattern: a Giesen-style "fake F16" for E4M3, and an exact `byte << 8` shift for E5M2, which already shares F16's exponent bias. The result is widened via `VCVTPH2PS` and accumulated in F32 with two independent FMA chains that reduce into a single register. This avoids the scheduler stalls of the three-chain BF16 algebraic form on cores without native BF16 FMA.
|
|
114
115
|
`nk_dot_bf16c_genoa` uses the same instruction for complex BF16, preparing operands with `VPSHUFB` for lane swapping and `VPXORD` with `0x80000000` for sign flips before feeding into `VDPBF16PS`.
|
|
115
116
|
|
|
116
117
|
### Deferred Sign-Flip in Complex Dot Products
|
|
@@ -698,25 +698,39 @@ nk_dot_e4m3_haswell_cycle:
|
|
|
698
698
|
|
|
699
699
|
NK_PUBLIC void nk_dot_e5m2_haswell(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
|
|
700
700
|
nk_f32_t *result) {
|
|
701
|
-
|
|
702
|
-
__m256
|
|
701
|
+
// E5M2 shares F16 bias; inline the free-shift unpack for the two 8-lane halves.
|
|
702
|
+
__m256 first_chain_f32x8 = _mm256_setzero_ps();
|
|
703
|
+
__m256 second_chain_f32x8 = _mm256_setzero_ps();
|
|
704
|
+
__m128i const zero_u8x16 = _mm_setzero_si128();
|
|
705
|
+
__m128i a_u8x16, b_u8x16;
|
|
706
|
+
|
|
703
707
|
nk_dot_e5m2_haswell_cycle:
|
|
704
|
-
if (count_scalars <
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
708
|
+
if (count_scalars < 16) {
|
|
709
|
+
nk_b128_vec_t a_vec, b_vec;
|
|
710
|
+
nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
|
|
711
|
+
nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
|
|
712
|
+
a_u8x16 = a_vec.xmm;
|
|
713
|
+
b_u8x16 = b_vec.xmm;
|
|
710
714
|
count_scalars = 0;
|
|
711
715
|
}
|
|
712
716
|
else {
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
a_scalars +=
|
|
717
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a_scalars);
|
|
718
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b_scalars);
|
|
719
|
+
a_scalars += 16, b_scalars += 16, count_scalars -= 16;
|
|
716
720
|
}
|
|
717
|
-
|
|
721
|
+
__m128i a_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_u8x16);
|
|
722
|
+
__m128i a_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_u8x16);
|
|
723
|
+
__m128i b_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_u8x16);
|
|
724
|
+
__m128i b_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_u8x16);
|
|
725
|
+
__m256 a_first_f32x8 = _mm256_cvtph_ps(a_even_f16x8);
|
|
726
|
+
__m256 a_second_f32x8 = _mm256_cvtph_ps(a_odd_f16x8);
|
|
727
|
+
__m256 b_first_f32x8 = _mm256_cvtph_ps(b_even_f16x8);
|
|
728
|
+
__m256 b_second_f32x8 = _mm256_cvtph_ps(b_odd_f16x8);
|
|
729
|
+
first_chain_f32x8 = _mm256_fmadd_ps(a_first_f32x8, b_first_f32x8, first_chain_f32x8);
|
|
730
|
+
second_chain_f32x8 = _mm256_fmadd_ps(a_second_f32x8, b_second_f32x8, second_chain_f32x8);
|
|
718
731
|
if (count_scalars) goto nk_dot_e5m2_haswell_cycle;
|
|
719
|
-
|
|
732
|
+
|
|
733
|
+
*result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(_mm256_add_ps(first_chain_f32x8, second_chain_f32x8));
|
|
720
734
|
}
|
|
721
735
|
|
|
722
736
|
NK_PUBLIC void nk_dot_e2m3_haswell(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
|
|
@@ -910,6 +924,71 @@ NK_INTERNAL void nk_dot_through_f32_update_haswell_(nk_dot_through_f32_state_has
|
|
|
910
924
|
state->sum_f32x8 = _mm256_fmadd_ps(a.ymm_ps, b.ymm_ps, state->sum_f32x8);
|
|
911
925
|
}
|
|
912
926
|
|
|
927
|
+
/**
|
|
928
|
+
* @brief E5M2 byte-batched update: consumes 32 raw E5M2 bytes per call and widens inline.
|
|
929
|
+
* Two independent FMA chains (each 2-deep) merge into the single __m256 state accumulator.
|
|
930
|
+
*/
|
|
931
|
+
NK_INTERNAL void nk_dot_e5m2x32_update_haswell_(nk_dot_through_f32_state_haswell_t_ *state, nk_b256_vec_t a_bytes,
|
|
932
|
+
nk_b256_vec_t b_bytes, nk_size_t depth_offset,
|
|
933
|
+
nk_size_t active_dimensions) {
|
|
934
|
+
nk_unused_(depth_offset);
|
|
935
|
+
nk_unused_(active_dimensions);
|
|
936
|
+
__m128i const zero_u8x16 = _mm_setzero_si128();
|
|
937
|
+
__m128i a_low_u8x16 = _mm256_castsi256_si128(a_bytes.ymm);
|
|
938
|
+
__m128i a_high_u8x16 = _mm256_extracti128_si256(a_bytes.ymm, 1);
|
|
939
|
+
__m128i b_low_u8x16 = _mm256_castsi256_si128(b_bytes.ymm);
|
|
940
|
+
__m128i b_high_u8x16 = _mm256_extracti128_si256(b_bytes.ymm, 1);
|
|
941
|
+
__m128i a_first_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_low_u8x16);
|
|
942
|
+
__m128i a_second_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_low_u8x16);
|
|
943
|
+
__m128i a_third_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_high_u8x16);
|
|
944
|
+
__m128i a_fourth_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_high_u8x16);
|
|
945
|
+
__m128i b_first_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_low_u8x16);
|
|
946
|
+
__m128i b_second_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_low_u8x16);
|
|
947
|
+
__m128i b_third_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_high_u8x16);
|
|
948
|
+
__m128i b_fourth_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_high_u8x16);
|
|
949
|
+
__m256 a_first_f32x8 = _mm256_cvtph_ps(a_first_f16x8);
|
|
950
|
+
__m256 a_second_f32x8 = _mm256_cvtph_ps(a_second_f16x8);
|
|
951
|
+
__m256 a_third_f32x8 = _mm256_cvtph_ps(a_third_f16x8);
|
|
952
|
+
__m256 a_fourth_f32x8 = _mm256_cvtph_ps(a_fourth_f16x8);
|
|
953
|
+
__m256 b_first_f32x8 = _mm256_cvtph_ps(b_first_f16x8);
|
|
954
|
+
__m256 b_second_f32x8 = _mm256_cvtph_ps(b_second_f16x8);
|
|
955
|
+
__m256 b_third_f32x8 = _mm256_cvtph_ps(b_third_f16x8);
|
|
956
|
+
__m256 b_fourth_f32x8 = _mm256_cvtph_ps(b_fourth_f16x8);
|
|
957
|
+
__m256 first_chain_f32x8 = _mm256_mul_ps(a_first_f32x8, b_first_f32x8);
|
|
958
|
+
__m256 second_chain_f32x8 = _mm256_mul_ps(a_second_f32x8, b_second_f32x8);
|
|
959
|
+
first_chain_f32x8 = _mm256_fmadd_ps(a_third_f32x8, b_third_f32x8, first_chain_f32x8);
|
|
960
|
+
second_chain_f32x8 = _mm256_fmadd_ps(a_fourth_f32x8, b_fourth_f32x8, second_chain_f32x8);
|
|
961
|
+
state->sum_f32x8 = _mm256_add_ps(state->sum_f32x8, _mm256_add_ps(first_chain_f32x8, second_chain_f32x8));
|
|
962
|
+
}
|
|
963
|
+
|
|
964
|
+
/**
|
|
965
|
+
* @brief E4M3 byte-batched update: consumes 32 raw E4M3 bytes per call. Widens through the
|
|
966
|
+
* Giesen cast helper. Two independent FMA chains merge into the single state accumulator.
|
|
967
|
+
*/
|
|
968
|
+
NK_INTERNAL void nk_dot_e4m3x32_update_haswell_(nk_dot_through_f32_state_haswell_t_ *state, nk_b256_vec_t a_bytes,
|
|
969
|
+
nk_b256_vec_t b_bytes, nk_size_t depth_offset,
|
|
970
|
+
nk_size_t active_dimensions) {
|
|
971
|
+
nk_unused_(depth_offset);
|
|
972
|
+
nk_unused_(active_dimensions);
|
|
973
|
+
__m128i a_low_u8x16 = _mm256_castsi256_si128(a_bytes.ymm);
|
|
974
|
+
__m128i a_high_u8x16 = _mm256_extracti128_si256(a_bytes.ymm, 1);
|
|
975
|
+
__m128i b_low_u8x16 = _mm256_castsi256_si128(b_bytes.ymm);
|
|
976
|
+
__m128i b_high_u8x16 = _mm256_extracti128_si256(b_bytes.ymm, 1);
|
|
977
|
+
__m256 a_first_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_low_u8x16);
|
|
978
|
+
__m256 a_second_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_low_u8x16, a_low_u8x16));
|
|
979
|
+
__m256 a_third_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_high_u8x16);
|
|
980
|
+
__m256 a_fourth_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_high_u8x16, a_high_u8x16));
|
|
981
|
+
__m256 b_first_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_low_u8x16);
|
|
982
|
+
__m256 b_second_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_low_u8x16, b_low_u8x16));
|
|
983
|
+
__m256 b_third_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_high_u8x16);
|
|
984
|
+
__m256 b_fourth_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_high_u8x16, b_high_u8x16));
|
|
985
|
+
__m256 first_chain_f32x8 = _mm256_mul_ps(a_first_f32x8, b_first_f32x8);
|
|
986
|
+
__m256 second_chain_f32x8 = _mm256_mul_ps(a_second_f32x8, b_second_f32x8);
|
|
987
|
+
first_chain_f32x8 = _mm256_fmadd_ps(a_third_f32x8, b_third_f32x8, first_chain_f32x8);
|
|
988
|
+
second_chain_f32x8 = _mm256_fmadd_ps(a_fourth_f32x8, b_fourth_f32x8, second_chain_f32x8);
|
|
989
|
+
state->sum_f32x8 = _mm256_add_ps(state->sum_f32x8, _mm256_add_ps(first_chain_f32x8, second_chain_f32x8));
|
|
990
|
+
}
|
|
991
|
+
|
|
913
992
|
/**
|
|
914
993
|
* @brief Finalizes 4x low-precision dot-products placing them into 4x consecutive 32-bit slots.
|
|
915
994
|
* @sa nk_dot_f16x8_finalize_haswell, nk_dot_bf16x8_finalize_haswell
|
|
@@ -22,7 +22,7 @@
|
|
|
22
22
|
#if NK_TARGET_RVVBF16
|
|
23
23
|
|
|
24
24
|
#include "numkong/types.h"
|
|
25
|
-
#include "numkong/cast/rvv.h" // `nk_e4m3m1_to_bf16m2_rvv_`, `nk_e5m2m1_to_bf16m2_rvv_
|
|
25
|
+
#include "numkong/cast/rvv.h" // `nk_e4m3m1_to_bf16m2_rvv_`, `nk_e5m2m1_to_bf16m2_rvv_`
|
|
26
26
|
|
|
27
27
|
#if defined(__clang__)
|
|
28
28
|
#pragma clang attribute push(__attribute__((target("arch=+v,+zvfbfwma"))), apply_to = function)
|
|
@@ -23,7 +23,7 @@
|
|
|
23
23
|
#if NK_TARGET_RVVHALF
|
|
24
24
|
|
|
25
25
|
#include "numkong/types.h"
|
|
26
|
-
#include "numkong/cast/rvv.h" // `nk_e4m3m1_to_f16m2_rvv_`, `nk_e2m3m1_to_f16m2_rvv_
|
|
26
|
+
#include "numkong/cast/rvv.h" // `nk_e4m3m1_to_f16m2_rvv_`, `nk_e2m3m1_to_f16m2_rvv_`
|
|
27
27
|
|
|
28
28
|
#if defined(__clang__)
|
|
29
29
|
#pragma clang attribute push(__attribute__((target("arch=+v,+zvfh"))), apply_to = function)
|
|
@@ -139,6 +139,15 @@ extern "C" {
|
|
|
139
139
|
result->imag = sum_imag; \
|
|
140
140
|
}
|
|
141
141
|
|
|
142
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
143
|
+
* See dots/serial.h for rationale. */
|
|
144
|
+
#if defined(__clang__)
|
|
145
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
146
|
+
#elif defined(__GNUC__)
|
|
147
|
+
#pragma GCC push_options
|
|
148
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
149
|
+
#endif
|
|
150
|
+
|
|
142
151
|
#pragma region F32 and F64 Floats
|
|
143
152
|
|
|
144
153
|
nk_define_dot_(f32, f64, f64, nk_assign_from_to_) // nk_dot_f32_serial
|
|
@@ -867,6 +876,12 @@ NK_INTERNAL nk_i32_t nk_sum_i4x32_finalize_serial(nk_sum_i4x32_state_serial_t co
|
|
|
867
876
|
|
|
868
877
|
#pragma endregion Stateful Element Sum Helpers
|
|
869
878
|
|
|
879
|
+
#if defined(__clang__)
|
|
880
|
+
#pragma clang attribute pop
|
|
881
|
+
#elif defined(__GNUC__)
|
|
882
|
+
#pragma GCC pop_options
|
|
883
|
+
#endif
|
|
884
|
+
|
|
870
885
|
#if defined(__cplusplus)
|
|
871
886
|
} // extern "C"
|
|
872
887
|
#endif
|