numkong 7.5.0 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/binding.gyp +18 -0
  2. package/c/dispatch_e5m2.c +23 -3
  3. package/include/numkong/capabilities.h +1 -1
  4. package/include/numkong/cast/README.md +3 -0
  5. package/include/numkong/cast/haswell.h +28 -64
  6. package/include/numkong/cast/serial.h +17 -0
  7. package/include/numkong/cast/skylake.h +67 -52
  8. package/include/numkong/cast.h +1 -0
  9. package/include/numkong/dot/README.md +1 -0
  10. package/include/numkong/dot/haswell.h +92 -13
  11. package/include/numkong/dot/serial.h +15 -0
  12. package/include/numkong/dot/skylake.h +61 -14
  13. package/include/numkong/dots/README.md +2 -0
  14. package/include/numkong/dots/graniteamx.h +434 -0
  15. package/include/numkong/dots/haswell.h +28 -28
  16. package/include/numkong/dots/sapphireamx.h +1 -1
  17. package/include/numkong/dots/serial.h +23 -8
  18. package/include/numkong/dots/skylake.h +28 -23
  19. package/include/numkong/dots.h +12 -0
  20. package/include/numkong/each/serial.h +18 -1
  21. package/include/numkong/geospatial/serial.h +14 -3
  22. package/include/numkong/maxsim/serial.h +15 -0
  23. package/include/numkong/mesh/README.md +50 -44
  24. package/include/numkong/mesh/genoa.h +462 -0
  25. package/include/numkong/mesh/haswell.h +806 -933
  26. package/include/numkong/mesh/neon.h +871 -943
  27. package/include/numkong/mesh/neonbfdot.h +382 -522
  28. package/include/numkong/mesh/neonfhm.h +676 -0
  29. package/include/numkong/mesh/rvv.h +404 -319
  30. package/include/numkong/mesh/serial.h +204 -162
  31. package/include/numkong/mesh/skylake.h +1029 -1585
  32. package/include/numkong/mesh/v128relaxed.h +403 -377
  33. package/include/numkong/mesh.h +38 -0
  34. package/include/numkong/reduce/serial.h +15 -1
  35. package/include/numkong/sparse/serial.h +17 -2
  36. package/include/numkong/spatial/genoa.h +0 -68
  37. package/include/numkong/spatial/haswell.h +98 -56
  38. package/include/numkong/spatial/serial.h +15 -0
  39. package/include/numkong/spatial/skylake.h +114 -54
  40. package/include/numkong/spatial.h +0 -12
  41. package/include/numkong/spatials/graniteamx.h +128 -0
  42. package/include/numkong/spatials/serial.h +18 -1
  43. package/include/numkong/spatials/skylake.h +2 -2
  44. package/include/numkong/spatials.h +17 -0
  45. package/include/numkong/tensor.hpp +107 -23
  46. package/javascript/numkong.c +3 -2
  47. package/package.json +7 -7
  48. package/wasm/numkong.wasm +0 -0
package/binding.gyp CHANGED
@@ -70,6 +70,7 @@
  },
  "conditions": [
  # Pin TU baseline to each arch's ABI floor; SIMD kernels use per-function pragmas.
+ # Keep per-arch table in sync with cmake/nk_compiler_flags.cmake, build.rs, setup.py.
  [
  "OS!='win' and target_arch=='arm64'",
  {
@@ -94,6 +95,23 @@
  ]
  }
  ],
+ [
+ "OS!='win' and target_arch=='ppc64'",
+ {
+ "cflags": [
+ "-mcpu=power8"
+ ]
+ }
+ ],
+ [
+ "OS!='win' and target_arch=='loong64'",
+ {
+ "cflags": [
+ "-march=loongarch64",
+ "-mlasx"
+ ]
+ }
+ ],
  # Forbid auto-vectorization so serial fallbacks don't get silently
  # promoted to NEON/SSE2/VSX. SIMD kernels use explicit intrinsics
  # and per-function `target` pragmas; unaffected. MSVC has no
package/c/dispatch_e5m2.c CHANGED
@@ -113,6 +113,29 @@ void nk_dispatch_e5m2_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
  default: break;
  }
  #endif
+ #if NK_TARGET_GRANITEAMX
+ if (v & nk_cap_graniteamx_k) switch (k) {
+ case nk_kernel_dots_packed_size_k:
+ *m = (m_t)&nk_dots_packed_size_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+ return;
+ case nk_kernel_dots_pack_k: *m = (m_t)&nk_dots_pack_e5m2_graniteamx, *c = nk_cap_graniteamx_k; return;
+ case nk_kernel_dots_packed_k: *m = (m_t)&nk_dots_packed_e5m2_graniteamx, *c = nk_cap_graniteamx_k; return;
+ case nk_kernel_angulars_packed_k:
+ *m = (m_t)&nk_angulars_packed_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+ return;
+ case nk_kernel_euclideans_packed_k:
+ *m = (m_t)&nk_euclideans_packed_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+ return;
+ case nk_kernel_dots_symmetric_k: *m = (m_t)&nk_dots_symmetric_e5m2_graniteamx, *c = nk_cap_graniteamx_k; return;
+ case nk_kernel_angulars_symmetric_k:
+ *m = (m_t)&nk_angulars_symmetric_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+ return;
+ case nk_kernel_euclideans_symmetric_k:
+ *m = (m_t)&nk_euclideans_symmetric_e5m2_graniteamx, *c = nk_cap_graniteamx_k;
+ return;
+ default: break;
+ }
+ #endif
  #if NK_TARGET_SAPPHIREAMX
  if (v & nk_cap_sapphireamx_k) switch (k) {
  case nk_kernel_dots_packed_size_k:
@@ -162,9 +185,6 @@ void nk_dispatch_e5m2_find_(nk_capability_t v, nk_kernel_kind_t k, nk_kernel_pun
  #if NK_TARGET_GENOA
  if (v & nk_cap_genoa_k) switch (k) {
  case nk_kernel_dot_k: *m = (m_t)&nk_dot_e5m2_genoa, *c = nk_cap_genoa_k; return;
- case nk_kernel_euclidean_k: *m = (m_t)&nk_euclidean_e5m2_genoa, *c = nk_cap_genoa_k; return;
- case nk_kernel_sqeuclidean_k: *m = (m_t)&nk_sqeuclidean_e5m2_genoa, *c = nk_cap_genoa_k; return;
- case nk_kernel_angular_k: *m = (m_t)&nk_angular_e5m2_genoa, *c = nk_cap_genoa_k; return;
  case nk_kernel_dots_packed_size_k: *m = (m_t)&nk_dots_packed_size_e5m2_genoa, *c = nk_cap_genoa_k; return;
  case nk_kernel_dots_pack_k: *m = (m_t)&nk_dots_pack_e5m2_genoa, *c = nk_cap_genoa_k; return;
  case nk_kernel_dots_packed_k: *m = (m_t)&nk_dots_packed_e5m2_genoa, *c = nk_cap_genoa_k; return;
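The GRANITEAMX block added above slots in ahead of SAPPHIREAMX and GENOA because the resolver returns the first capability tier that is both reported by the CPU and compiled into the binary. A minimal sketch of that pattern, with hypothetical names rather than the actual NumKong API:

/* Hypothetical illustration of capability-gated dispatch; not NumKong code. */
typedef void (*kernel_fn_t)(void);
typedef unsigned long long caps_t;

typedef struct {
    caps_t required_cap; /* one CPU-feature bit, e.g. an AMX or AVX-512 tier */
    kernel_fn_t kernel;  /* implementation compiled for that tier */
} dispatch_entry_t;

/* Entries are ordered best-first; the first tier present in `available` wins. */
static kernel_fn_t pick_kernel(dispatch_entry_t const *entries, int count, caps_t available) {
    for (int i = 0; i < count; ++i)
        if (available & entries[i].required_cap) return entries[i].kernel;
    return 0; /* caller falls back to the portable serial kernel */
}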
package/include/numkong/capabilities.h CHANGED
@@ -95,7 +95,7 @@
  #include "numkong/types.h" // `nk_u64_t`, `NK_DEFINED_LINUX_`

  #define NK_VERSION_MAJOR 7
- #define NK_VERSION_MINOR 5
+ #define NK_VERSION_MINOR 6
  #define NK_VERSION_PATCH 0

  /**
package/include/numkong/cast/README.md CHANGED
@@ -93,6 +93,9 @@ NEON backend uses `vreinterpretq_u16_u8` + `vzip` for zero-extension; Haswell us
  `nk_f16_to_f32_haswell`, `nk_f32_to_f16_haswell` use the F16C extension instructions `VCVTPH2PS` / `VCVTPS2PH` — single-instruction conversion of 8 elements with correct denormal handling, NaN propagation, and RNE rounding.
  The serial fallback (`nk_f16_to_f32_serial`) must handle denormals via explicit exponent/mantissa extraction and conditional re-normalization — ~15 integer ops per element vs 1 instruction with F16C.
  AVX-512 (`nk_cast_skylake`) doubles throughput to 16 elements per instruction.
+ F16C also unlocks a cheaper FP8 → F32 path that bypasses i32-lane bit math: `nk_e5m2x16_to_f32x16_skylake_` and `nk_e5m2x8_to_f32x8_haswell_` widen u8 → u16 and left-shift by 8 (E5M2 shares F16's bias 15, so the result is a bit-exact F16 encoding of every input, including subnormals and NaN), then feed `VCVTPH2PS` — three ops total.
+ E4M3 can't use a plain shift (bias 7 vs 15), but the Giesen-style fake-F16 `((byte & 0x7F) << 7) | ((byte & 0x80) << 8)` gives an F16 whose value differs from the E4M3 magnitude by a factor of exactly 2⁸; `nk_e4m3x16_to_f32x16_skylake_` and `nk_e4m3x8_to_f32x8_haswell_` widen through `VCVTPH2PS`, multiply by 256 in F32 to correct, and blend in F32 NaN for the lone `|byte|==0x7F` encoding.
+ For E4M3 GEMM specifically, `nk_e4m3x16_to_f16x16_skylake_` produces TRUE F16 (bias-corrected, with a small subnormal LUT and NaN blend) so the packed buffer stores 2 bytes/element instead of 4 — the inner loop reads F16 and widens to F32 once per B-load, trading ~10% compute for 50% pack memory.

  ## Performance

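Both conversions are easy to sanity-check in scalar code. Below is a minimal sketch, not part of the package, using the scalar F16C intrinsic `_cvtsh_ss` (compile with `-mf16c`); the helper names `e5m2_to_f32` / `e4m3_to_f32` are illustrative only.

#include <immintrin.h>
#include <math.h>
#include <stdint.h>

/* E5M2 shares F16's exponent bias (15): the byte shifted to the top of a 16-bit
 * word *is* a valid F16 encoding, so one F16 -> F32 conversion finishes the job. */
static float e5m2_to_f32(uint8_t byte) { return _cvtsh_ss((uint16_t)((uint16_t)byte << 8)); }

/* E4M3 (bias 7): place the 7 magnitude bits at F16 positions 13..7 and the sign at
 * bit 15. The resulting "fake" F16 is 2^-8 times the true value (bias delta 15 - 7),
 * so multiply by 256 afterwards; |byte| == 0x7F is the single NaN encoding. */
static float e4m3_to_f32(uint8_t byte) {
    uint16_t magnitude = byte & 0x7F;
    if (magnitude == 0x7F) return NAN;
    uint16_t fake_f16 = (uint16_t)((magnitude << 7) | ((uint16_t)(byte & 0x80) << 8));
    return _cvtsh_ss(fake_f16) * 256.0f;
}

For example, `e4m3_to_f32(0x01)` (the smallest E4M3 subnormal) comes out as 1/512 with no special-casing, because F16 subnormal semantics absorb it before the ×256 correction.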
package/include/numkong/cast/haswell.h CHANGED
@@ -172,72 +172,36 @@ NK_INTERNAL __m128i nk_f32x8_to_u8x8_haswell_(__m256 f32x8) {
  return _mm_packus_epi16(u16x8, _mm_setzero_si128());
  }

- /** @brief Convert 8x e4m3 → 8x f32 via bit manipulation (AVX2).
- * E4M3 format: S EEEE MMM (bias=7). F32: sign<<31, (exp+120)<<23, mant<<20.
- * Subnormals (exp=0): looked up via vpermps from an 8-entry register LUT.
- * NaN detection uses a single comparison on the 7-bit magnitude (0x7F). */
+ /** @brief Convert 8x e4m3 → 8x f32 via Giesen-style fake-F16 cast (AVX2 + F16C).
+ * E4M3 `byte = S EEEE MMM` (bias 7). Shifting the magnitude into F16 positions
+ * `((byte & 0x7F) << 7) | ((byte & 0x80) << 8)` yields a fake F16 whose F16 value
+ * differs from the true E4M3 magnitude by a factor of exactly 2⁸ (bias delta 15 − 7). The
+ * fake F16 is widened via `vcvtph2ps` and corrected by ×256 in F32. Subnormal
+ * handling falls out for free via F16 subnormal semantics. NaN (|byte|==0x7F)
+ * is blended explicitly with F16 quiet-NaN bits. */
  NK_INTERNAL __m256 nk_e4m3x8_to_f32x8_haswell_(__m128i e4m3_i8x8) {
- __m256i e4m3_i32x8 = _mm256_cvtepu8_epi32(e4m3_i8x8);
-
- // Extract fields
- __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e4m3_i32x8, 3), _mm256_set1_epi32(0x0F));
- __m256i mant_i32x8 = _mm256_and_si256(e4m3_i32x8, _mm256_set1_epi32(0x07));
-
- // Build F32 sign bit
- __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e4m3_i32x8, 7), 31);
-
- // Normal path: sign | ((exp+120)<<23) | (mant<<20)
- __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(120)), 23);
- __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 20);
- __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
-
- // Subnormal path: vpermps from 8-entry register LUT (3 cy latency, no memory access)
- __m256 subnorm_lut_f32x8 = _mm256_setr_ps(0, 1.0f / 512, 2.0f / 512, 3.0f / 512, //
- 4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512);
- __m256i subnorm_bits_i32x8 = _mm256_or_si256( //
- _mm256_castps_si256(_mm256_permutevar8x32_ps(subnorm_lut_f32x8, mant_i32x8)), f32_sign_i32x8);
-
- // Bitwise select: if exp==0, use subnormal; otherwise use normal
- __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
- __m256i result_i32x8 = _mm256_or_si256( //
- _mm256_and_si256(exp_zero_mask, subnorm_bits_i32x8), //
- _mm256_andnot_si256(exp_zero_mask, normal_bits_i32x8));
-
- // NaN: E4M3FN has NaN only at magnitude 0x7F (exp=15, mant=7)
- __m256i lower7_i32x8 = _mm256_and_si256(e4m3_i32x8, _mm256_set1_epi32(0x7F));
- __m256i is_nan_mask = _mm256_cmpeq_epi32(lower7_i32x8, _mm256_set1_epi32(0x7F));
- __m256i nan_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_set1_epi32(0x7FC00000));
- result_i32x8 = _mm256_or_si256( //
- _mm256_and_si256(is_nan_mask, nan_i32x8), //
- _mm256_andnot_si256(is_nan_mask, result_i32x8));
- return _mm256_castsi256_ps(result_i32x8);
- }
-
- /** @brief Convert 8x e5m2 → 8x f32 via bit manipulation (AVX2).
- * E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mant<<21.
- * Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁵⁾ × 2⁻² = mantissa ÷ 65536. */
+ __m128i const magnitude_mask_u16x8 = _mm_set1_epi16(0x7F);
+ __m128i const sign_mask_u16x8 = _mm_set1_epi16((short)0x80);
+ __m128i const f16_nan_u16x8 = _mm_set1_epi16(0x7E00);
+ __m128i word_u16x8 = _mm_cvtepu8_epi16(e4m3_i8x8);
+ __m128i magnitude_u16x8 = _mm_and_si128(word_u16x8, magnitude_mask_u16x8);
+ __m128i is_nan_u16x8 = _mm_cmpeq_epi16(magnitude_u16x8, magnitude_mask_u16x8);
+ __m128i shifted_magnitude_u16x8 = _mm_slli_epi16(magnitude_u16x8, 7);
+ __m128i shifted_sign_u16x8 = _mm_slli_epi16(_mm_and_si128(word_u16x8, sign_mask_u16x8), 8);
+ __m128i f16_bits_u16x8 = _mm_or_si128(shifted_magnitude_u16x8, shifted_sign_u16x8);
+ f16_bits_u16x8 = _mm_blendv_epi8(f16_bits_u16x8, f16_nan_u16x8, is_nan_u16x8);
+ __m256 fake_f32x8 = _mm256_cvtph_ps(f16_bits_u16x8);
+ return _mm256_mul_ps(fake_f32x8, _mm256_set1_ps(256.0f));
+ }
+
+ /** @brief Convert 8x e5m2 → 8x f32 via free-shift widen (AVX2 + F16C).
+ * E5M2 shares F16's exponent bias (15): `(byte << 8)` is the matching F16 bit
+ * pattern for every E5M2 value (normals, subnormals, zero, ±Inf, NaN — all
+ * bit-exact). Widen u8 → u16, shift, then VCVTPH2PS to F32. Three ops total. */
  NK_INTERNAL __m256 nk_e5m2x8_to_f32x8_haswell_(__m128i e5m2_i8x8) {
- __m256i e5m2_i32x8 = _mm256_cvtepu8_epi32(e5m2_i8x8);
-
- // Extract fields
- __m256i exp_i32x8 = _mm256_and_si256(_mm256_srli_epi32(e5m2_i32x8, 2), _mm256_set1_epi32(0x1F));
- __m256i mant_i32x8 = _mm256_and_si256(e5m2_i32x8, _mm256_set1_epi32(0x03));
-
- // Build F32 sign bit
- __m256i f32_sign_i32x8 = _mm256_slli_epi32(_mm256_srli_epi32(e5m2_i32x8, 7), 31);
-
- // Normal path: sign | ((exp+112)<<23) | (mant<<21)
- __m256i f32_exp_i32x8 = _mm256_slli_epi32(_mm256_add_epi32(exp_i32x8, _mm256_set1_epi32(112)), 23);
- __m256i f32_mant_i32x8 = _mm256_slli_epi32(mant_i32x8, 21);
- __m256i normal_bits_i32x8 = _mm256_or_si256(f32_sign_i32x8, _mm256_or_si256(f32_exp_i32x8, f32_mant_i32x8));
-
- // Subnormal path: value = mantissa / 65536.0f, then apply sign
- __m256 subnorm_abs_f32x8 = _mm256_mul_ps(_mm256_cvtepi32_ps(mant_i32x8), _mm256_set1_ps(1.0f / 65536.0f));
- __m256 subnorm_f32x8 = _mm256_or_ps(subnorm_abs_f32x8, _mm256_castsi256_ps(f32_sign_i32x8));
-
- // Blend: if exp==0, use subnormal result; otherwise use normal bits
- __m256i exp_zero_mask = _mm256_cmpeq_epi32(exp_i32x8, _mm256_setzero_si256());
- return _mm256_blendv_ps(_mm256_castsi256_ps(normal_bits_i32x8), subnorm_f32x8, _mm256_castsi256_ps(exp_zero_mask));
+ __m128i e5m2_u16x8 = _mm_cvtepu8_epi16(e5m2_i8x8);
+ __m128i f16_bits_u16x8 = _mm_slli_epi16(e5m2_u16x8, 8);
+ return _mm256_cvtph_ps(f16_bits_u16x8);
  }

  /** @brief Convert 8x f32 → 8x e4m3 via bit manipulation (AVX2).
package/include/numkong/cast/serial.h CHANGED
@@ -13,6 +13,17 @@
  extern "C" {
  #endif

+ /* Keep the serial conversions below actually scalar, regardless of build type.
+ * Without this, -O3 + LTO can vectorize or clone the serial kernels under AVX-512
+ * callers in dispatch_*.c, which wastes binary and breaks the nk_*_serial-as-scalar-oracle
+ * contract. See dots/serial.h. */
+ #if defined(__clang__)
+ #pragma clang attribute push(__attribute__((noinline)), apply_to = function)
+ #elif defined(__GNUC__)
+ #pragma GCC push_options
+ #pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
+ #endif
+
  #pragma region Type Punned Loads and Stores

  /** @brief Type-agnostic 32-bit full load (scalar). */
@@ -2329,6 +2340,12 @@ NK_PUBLIC void nk_e3m2_to_bf16(nk_e3m2_t const *src, nk_bf16_t *dest) {

  #pragma endregion Public API

+ #if defined(__clang__)
+ #pragma clang attribute pop
+ #elif defined(__GNUC__)
+ #pragma GCC pop_options
+ #endif
+
  #if defined(__cplusplus)
  } // extern "C"
  #endif
package/include/numkong/cast/skylake.h CHANGED
@@ -175,60 +175,63 @@ NK_INTERNAL __m256i nk_f32x16_to_bf16x16_skylake_(__m512 a) {
  return _mm512_cvtepi32_epi16(x);
  }

- /** @brief Convert 16x e4m3 → 16x f32 via bit manipulation (AVX-512).
- * E4M3 format: S EEEE MMM (bias=7). F32: sign<<31, (exp+120)<<23, mantissa<<20.
- * Subnormals (exp=0): value = mantissa × 2⁽¹⁻⁷⁾ × 2⁻³ = mantissa ÷ 512. */
+ /** @brief Convert 16x e4m3 → 16x f32 via Giesen-style fake-F16 cast (AVX-512 + F16C).
+ * E4M3 `byte = S EEEE MMM` (bias 7). Shifting the magnitude into F16 positions
+ * `((byte & 0x7F) << 7) | ((byte & 0x80) << 8)` yields a fake F16 whose F16 value
+ * differs from the true E4M3 magnitude by a factor of exactly 2⁸ (bias delta 15 − 7). The
+ * fake F16 is widened via `vcvtph2ps` and corrected by ×256 in F32. Subnormal
+ * handling falls out for free via F16 subnormal semantics. NaN (|byte|==0x7F) is
+ * the sole E4M3 special value that would misinterpret as finite; blended
+ * explicitly with F32 quiet NaN bits. */
  NK_INTERNAL __m512 nk_e4m3x16_to_f32x16_skylake_(__m128i e4m3_i8x16) {
- __m512i e4m3_i32x16 = _mm512_cvtepu8_epi32(e4m3_i8x16);
-
- // Extract fields
- __m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e4m3_i32x16, 3), _mm512_set1_epi32(0x0F));
- __m512i mantissa_i32x16 = _mm512_and_si512(e4m3_i32x16, _mm512_set1_epi32(0x07));
- __m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e4m3_i32x16, 7), 31);
-
- // Normal path: sign | ((exp+120)<<23) | (mantissa<<20)
- __m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(120)), 23);
- __m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 20);
- __m512 result_f32x16 = _mm512_castsi512_ps(
- _mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
-
- // Subnormal fix: vpermps from 8-entry LUT (repeated to fill 16 lanes)
- __mmask16 is_subnormal = _mm512_testn_epi32_mask(e4m3_i32x16, _mm512_set1_epi32(0x78));
- __m512 subnorm_lut_f32x16 = _mm512_setr_ps( //
- 0, 1.0f / 512, 2.0f / 512, 3.0f / 512, 4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512, //
- 0, 1.0f / 512, 2.0f / 512, 3.0f / 512, 4.0f / 512, 5.0f / 512, 6.0f / 512, 7.0f / 512);
- __m512 subnorm_abs_f32x16 = _mm512_permutexvar_ps(mantissa_i32x16, subnorm_lut_f32x16);
- result_f32x16 = _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16,
- _mm512_castsi512_ps(sign_i32x16));
-
- // NaN: E4M3FN has NaN only at magnitude 0x7F (single mask comparison)
- __m512i lower7_i32x16 = _mm512_and_si512(e4m3_i32x16, _mm512_set1_epi32(0x7F));
- __mmask16 is_nan = _mm512_cmpeq_epi32_mask(lower7_i32x16, _mm512_set1_epi32(0x7F));
- __m512i nan_i32x16 = _mm512_or_si512(sign_i32x16, _mm512_set1_epi32(0x7FC00000));
- return _mm512_mask_blend_ps(is_nan, result_f32x16, _mm512_castsi512_ps(nan_i32x16));
- }
-
- /** @brief Convert 16x e5m2 → 16x f32 via bit manipulation (AVX-512).
- * E5M2 format: S EEEEE MM (bias=15). F32: sign<<31, (exp+112)<<23, mantissa<<21.
- * Subnormals (exp=0): value = mantissa × 2⁽¹⁻¹⁵⁾ × 2⁻² = mantissa ÷ 65536. */
+ __m256i const magnitude_mask_u16x16 = _mm256_set1_epi16(0x7F);
+ __m256i const sign_mask_u16x16 = _mm256_set1_epi16((short)0x80);
+ __m256i const f16_nan_u16x16 = _mm256_set1_epi16(0x7E00);
+ __m256i word_u16x16 = _mm256_cvtepu8_epi16(e4m3_i8x16);
+ __m256i magnitude_u16x16 = _mm256_and_si256(word_u16x16, magnitude_mask_u16x16);
+ __mmask16 is_nan = _mm256_cmpeq_epi16_mask(magnitude_u16x16, magnitude_mask_u16x16);
+ __m256i shifted_magnitude_u16x16 = _mm256_slli_epi16(magnitude_u16x16, 7);
+ __m256i shifted_sign_u16x16 = _mm256_slli_epi16(_mm256_and_si256(word_u16x16, sign_mask_u16x16), 8);
+ __m256i f16_bits_u16x16 = _mm256_or_si256(shifted_magnitude_u16x16, shifted_sign_u16x16);
+ f16_bits_u16x16 = _mm256_mask_mov_epi16(f16_bits_u16x16, is_nan, f16_nan_u16x16);
+ __m512 fake_f32x16 = _mm512_cvtph_ps(f16_bits_u16x16);
+ return _mm512_mul_ps(fake_f32x16, _mm512_set1_ps(256.0f));
+ }
+
+ /** @brief Convert 16x e4m3 → 16x f16 via arithmetic + 8-entry subnormal LUT (AVX-512BW + AVX-512VL).
+ * E4M3: S EEEE MMM (bias=7). F16: S EEEEE MMMMMMMMMM (bias=15).
+ * Normal (exp != 0): F16 = ((lower7 << 7) + 0x2000) | (sign << 8); the bias delta of 8 is added at the
+ * exp-position (8 << 10 = 0x2000) after placing magnitude bits at F16 positions 13..7.
+ * Subnormal (exp == 0): looked up from 8-entry F16 LUT — values 0, 1/512, 2/512, …, 7/512 encoded as
+ * F16 normals (the smallest E4M3 subnormal 1/512 = 2⁻⁹ is well within F16 normal range).
+ * NaN (|byte| == 0x7F): blended in as F16 quiet NaN with original sign. */
+ NK_INTERNAL __m256i nk_e4m3x16_to_f16x16_skylake_(__m128i e4m3_u8x16) {
+ __m256i e4m3_i16x16 = _mm256_cvtepu8_epi16(e4m3_u8x16);
+ __m256i sign_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16((short)0x80));
+ __m256i lower7_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x7F));
+ __m256i normal_abs_i16x16 = _mm256_add_epi16(_mm256_slli_epi16(lower7_i16x16, 7), _mm256_set1_epi16(0x2000));
+ __m256i subn_lut_i16x16 = _mm256_set_epi16( //
+ 0x2300, 0x2200, 0x2100, 0x2000, 0x1E00, 0x1C00, 0x1800, 0x0000, 0x2300, 0x2200, 0x2100, 0x2000, 0x1E00, 0x1C00,
+ 0x1800, 0x0000);
+ __m256i mant_idx_i16x16 = _mm256_and_si256(e4m3_i16x16, _mm256_set1_epi16(0x07));
+ __m256i subn_abs_i16x16 = _mm256_permutexvar_epi16(mant_idx_i16x16, subn_lut_i16x16);
+ __mmask16 is_subnormal = _mm256_testn_epi16_mask(e4m3_i16x16, _mm256_set1_epi16(0x78));
+ __m256i abs_i16x16 = _mm256_mask_blend_epi16(is_subnormal, normal_abs_i16x16, subn_abs_i16x16);
+ __m256i shifted_sign_i16x16 = _mm256_slli_epi16(sign_i16x16, 8);
+ __m256i result_i16x16 = _mm256_or_si256(abs_i16x16, shifted_sign_i16x16);
+ __mmask16 is_nan = _mm256_cmpeq_epi16_mask(lower7_i16x16, _mm256_set1_epi16(0x7F));
+ __m256i nan_i16x16 = _mm256_or_si256(shifted_sign_i16x16, _mm256_set1_epi16(0x7E00));
+ return _mm256_mask_blend_epi16(is_nan, result_i16x16, nan_i16x16);
+ }
+
+ /** @brief Convert 16x e5m2 → 16x f32 via free-shift widen (AVX-512 + F16C).
+ * E5M2 shares F16's exponent bias (15): `(byte << 8)` is the matching F16 bit
+ * pattern for every E5M2 value (normals, subnormals, zero, ±Inf, NaN — all
+ * bit-exact). Widen u8 → u16, shift, then VCVTPH2PS to F32. Three ops total. */
  NK_INTERNAL __m512 nk_e5m2x16_to_f32x16_skylake_(__m128i e5m2_i8x16) {
- __m512i e5m2_i32x16 = _mm512_cvtepu8_epi32(e5m2_i8x16);
-
- // Extract fields
- __m512i exp_i32x16 = _mm512_and_si512(_mm512_srli_epi32(e5m2_i32x16, 2), _mm512_set1_epi32(0x1F));
- __m512i mantissa_i32x16 = _mm512_and_si512(e5m2_i32x16, _mm512_set1_epi32(0x03));
- __m512i sign_i32x16 = _mm512_slli_epi32(_mm512_srli_epi32(e5m2_i32x16, 7), 31);
-
- // Normal path: sign | ((exp+112)<<23) | (mantissa<<21)
- __m512i f32_exp_i32x16 = _mm512_slli_epi32(_mm512_add_epi32(exp_i32x16, _mm512_set1_epi32(112)), 23);
- __m512i f32_mantissa_i32x16 = _mm512_slli_epi32(mantissa_i32x16, 21);
- __m512 result_f32x16 = _mm512_castsi512_ps(
- _mm512_ternarylogic_epi32(sign_i32x16, f32_exp_i32x16, f32_mantissa_i32x16, 0xFE));
-
- // Subnormal fix: for exp==0 lanes, replace with (mantissa / 65536) | sign using masked OR
- __mmask16 is_subnormal = _mm512_testn_epi32_mask(e5m2_i32x16, _mm512_set1_epi32(0x7C));
- __m512 subnorm_abs_f32x16 = _mm512_mul_ps(_mm512_cvtepi32_ps(mantissa_i32x16), _mm512_set1_ps(1.0f / 65536.0f));
- return _mm512_mask_or_ps(result_f32x16, is_subnormal, subnorm_abs_f32x16, _mm512_castsi512_ps(sign_i32x16));
+ __m256i e5m2_u16x16 = _mm256_cvtepu8_epi16(e5m2_i8x16);
+ __m256i f16_bits_u16x16 = _mm256_slli_epi16(e5m2_u16x16, 8);
+ return _mm512_cvtph_ps(f16_bits_u16x16);
  }

  /** @brief Convert 16x e2m3 → 16x f32 via bit manipulation (AVX-512).
@@ -660,6 +663,18 @@ NK_INTERNAL void nk_partial_load_e4m3x16_to_f32x16_skylake_(void const *src, nk_
  dst->zmm_ps = nk_e4m3x16_to_f32x16_skylake_(e4m3_partial.xmm);
  }

+ /** @brief Load 16 e4m3 values and convert to 16 f16 (Skylake AVX-512BW). */
+ NK_INTERNAL void nk_load_e4m3x16_to_f16x16_skylake_(void const *src, nk_b256_vec_t *dst) {
+ dst->ymm = nk_e4m3x16_to_f16x16_skylake_(_mm_loadu_si128((__m128i const *)src));
+ }
+
+ /** @brief Partial load of up to 16 e4m3 values with conversion to f16 (Skylake AVX-512BW). */
+ NK_INTERNAL void nk_partial_load_e4m3x16_to_f16x16_skylake_(void const *src, nk_b256_vec_t *dst, nk_size_t n) {
+ nk_b128_vec_t e4m3_partial;
+ nk_partial_load_b8x16_skylake_(src, &e4m3_partial, n);
+ dst->ymm = nk_e4m3x16_to_f16x16_skylake_(e4m3_partial.xmm);
+ }
+
  /** @brief Load 16 e5m2 values and convert to 16 f32 (Skylake AVX-512). */
  NK_INTERNAL void nk_load_e5m2x16_to_f32x16_skylake_(void const *src, nk_b512_vec_t *dst) {
  dst->zmm_ps = nk_e5m2x16_to_f32x16_skylake_(_mm_loadu_si128((__m128i const *)src));
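The lane-wise arithmetic in `nk_e4m3x16_to_f16x16_skylake_` is easier to audit one byte at a time. A scalar sketch under the same assumptions (illustrative only, not library code):

#include <stdint.h>

/* F16 encodings of 0, 1/512, 2/512, ..., 7/512 - the eight E4M3 subnormal magnitudes. */
static uint16_t const e4m3_subnormal_f16_lut[8] = {
    0x0000, 0x1800, 0x1C00, 0x1E00, 0x2000, 0x2100, 0x2200, 0x2300};

static uint16_t e4m3_to_f16_bits(uint8_t byte) {
    uint16_t sign = (uint16_t)(byte & 0x80) << 8;           /* sign moves from bit 7 to bit 15 */
    uint16_t lower7 = byte & 0x7F;
    if (lower7 == 0x7F) return (uint16_t)(sign | 0x7E00);   /* the single E4M3 NaN encoding */
    if ((lower7 & 0x78) == 0)                                /* exponent field == 0: subnormal */
        return (uint16_t)(sign | e4m3_subnormal_f16_lut[lower7 & 0x07]);
    /* Normal: magnitude lands at F16 bits 13..7, and +0x2000 (8 << 10) adds the bias delta of 8. */
    return (uint16_t)(sign | (uint16_t)((lower7 << 7) + 0x2000));
}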
package/include/numkong/cast.h CHANGED
@@ -175,6 +175,7 @@ NK_PUBLIC void nk_cast_v128relaxed(void const *from, nk_dtype_t from_type, nk_si
  #include "numkong/cast/icelake.h"
  #include "numkong/cast/sapphire.h"
  #include "numkong/cast/rvv.h"
+ #include "numkong/cast/v128relaxed.h"
  #include "numkong/cast/powervsx.h"
  #include "numkong/cast/loongsonasx.h"

package/include/numkong/dot/README.md CHANGED
@@ -111,6 +111,7 @@ This processes 64 E4M3 bytes per iteration in u8, doubling the element density o

  `nk_dot_e5m2_genoa` converts FP8 values to BF16, then accumulates via `VDPBF16PS`, reusing Genoa's BF16 dot-product instruction for FP8 types.
  Each `VDPBF16PS` fuses two BF16 multiply-adds per 32-bit lane at 6-cycle throughput.
+ On Skylake-X–class CPUs without BF16 dot-product hardware, `nk_dot_e4m3_skylake` / `nk_dot_e5m2_skylake` (and their Haswell twins `nk_dot_e4m3_haswell` / `nk_dot_e5m2_haswell`) instead route through the Giesen-style FP8 → F16 fake-bit-pattern cast, widen via `VCVTPH2PS`, and accumulate in F32 with two independent FMA chains reducing into a single register — avoiding the 3-chain scheduler-stall of the BF16 algebraic form on kernels without native BF16 FMA.
  `nk_dot_bf16c_genoa` uses the same instruction for complex BF16, preparing operands with `VPSHUFB` for lane swapping and `VPXORD` with `0x80000000` for sign flips before feeding into `VDPBF16PS`.

  ### Deferred Sign-Flip in Complex Dot Products
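The two-FMA-chain remark above is about instruction-level parallelism rather than the FP8 decode itself; a minimal scalar sketch of the idea, with a hypothetical function name:

#include <stddef.h>

/* Two independent accumulators let the out-of-order core overlap FMA latency;
 * a single accumulator would serialize every add behind the previous one. */
static float dot_two_chains(float const *a, float const *b, size_t n) {
    float chain0 = 0.0f, chain1 = 0.0f;
    size_t i = 0;
    for (; i + 2 <= n; i += 2) {
        chain0 += a[i] * b[i];         /* chain 0 depends only on chain 0 */
        chain1 += a[i + 1] * b[i + 1]; /* chain 1 depends only on chain 1 */
    }
    if (i < n) chain0 += a[i] * b[i];  /* odd-length tail */
    return chain0 + chain1;            /* merge once at the end */
}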
package/include/numkong/dot/haswell.h CHANGED
@@ -698,25 +698,39 @@ nk_dot_e4m3_haswell_cycle:

  NK_PUBLIC void nk_dot_e5m2_haswell(nk_e5m2_t const *a_scalars, nk_e5m2_t const *b_scalars, nk_size_t count_scalars,
  nk_f32_t *result) {
- __m256 a_f32x8, b_f32x8;
- __m256 sum_f32x8 = _mm256_setzero_ps();
+ // E5M2 shares F16 bias; inline the free-shift unpack for the two 8-lane halves.
+ __m256 first_chain_f32x8 = _mm256_setzero_ps();
+ __m256 second_chain_f32x8 = _mm256_setzero_ps();
+ __m128i const zero_u8x16 = _mm_setzero_si128();
+ __m128i a_u8x16, b_u8x16;
+
  nk_dot_e5m2_haswell_cycle:
- if (count_scalars < 8) {
- nk_b256_vec_t a_vec, b_vec;
- nk_partial_load_e5m2x8_to_f32x8_haswell_(a_scalars, &a_vec, count_scalars);
- nk_partial_load_e5m2x8_to_f32x8_haswell_(b_scalars, &b_vec, count_scalars);
- a_f32x8 = a_vec.ymm_ps;
- b_f32x8 = b_vec.ymm_ps;
+ if (count_scalars < 16) {
+ nk_b128_vec_t a_vec, b_vec;
+ nk_partial_load_b8x16_serial_(a_scalars, &a_vec, count_scalars);
+ nk_partial_load_b8x16_serial_(b_scalars, &b_vec, count_scalars);
+ a_u8x16 = a_vec.xmm;
+ b_u8x16 = b_vec.xmm;
  count_scalars = 0;
  }
  else {
- a_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)a_scalars));
- b_f32x8 = nk_e5m2x8_to_f32x8_haswell_(_mm_loadl_epi64((__m128i const *)b_scalars));
- a_scalars += 8, b_scalars += 8, count_scalars -= 8;
+ a_u8x16 = _mm_loadu_si128((__m128i const *)a_scalars);
+ b_u8x16 = _mm_loadu_si128((__m128i const *)b_scalars);
+ a_scalars += 16, b_scalars += 16, count_scalars -= 16;
  }
- sum_f32x8 = _mm256_fmadd_ps(a_f32x8, b_f32x8, sum_f32x8);
+ __m128i a_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_u8x16);
+ __m128i a_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_u8x16);
+ __m128i b_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_u8x16);
+ __m128i b_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_u8x16);
+ __m256 a_first_f32x8 = _mm256_cvtph_ps(a_even_f16x8);
+ __m256 a_second_f32x8 = _mm256_cvtph_ps(a_odd_f16x8);
+ __m256 b_first_f32x8 = _mm256_cvtph_ps(b_even_f16x8);
+ __m256 b_second_f32x8 = _mm256_cvtph_ps(b_odd_f16x8);
+ first_chain_f32x8 = _mm256_fmadd_ps(a_first_f32x8, b_first_f32x8, first_chain_f32x8);
+ second_chain_f32x8 = _mm256_fmadd_ps(a_second_f32x8, b_second_f32x8, second_chain_f32x8);
  if (count_scalars) goto nk_dot_e5m2_haswell_cycle;
- *result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(sum_f32x8);
+
+ *result = (nk_f32_t)nk_reduce_add_f32x8_haswell_(_mm256_add_ps(first_chain_f32x8, second_chain_f32x8));
  }

  NK_PUBLIC void nk_dot_e2m3_haswell(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars, nk_size_t count_scalars,
@@ -910,6 +924,71 @@ NK_INTERNAL void nk_dot_through_f32_update_haswell_(nk_dot_through_f32_state_has
  state->sum_f32x8 = _mm256_fmadd_ps(a.ymm_ps, b.ymm_ps, state->sum_f32x8);
  }

+ /**
+ * @brief E5M2 byte-batched update: consumes 32 raw E5M2 bytes per call and widens inline.
+ * Two independent FMA chains (each 2-deep) merge into the single __m256 state accumulator.
+ */
+ NK_INTERNAL void nk_dot_e5m2x32_update_haswell_(nk_dot_through_f32_state_haswell_t_ *state, nk_b256_vec_t a_bytes,
+ nk_b256_vec_t b_bytes, nk_size_t depth_offset,
+ nk_size_t active_dimensions) {
+ nk_unused_(depth_offset);
+ nk_unused_(active_dimensions);
+ __m128i const zero_u8x16 = _mm_setzero_si128();
+ __m128i a_low_u8x16 = _mm256_castsi256_si128(a_bytes.ymm);
+ __m128i a_high_u8x16 = _mm256_extracti128_si256(a_bytes.ymm, 1);
+ __m128i b_low_u8x16 = _mm256_castsi256_si128(b_bytes.ymm);
+ __m128i b_high_u8x16 = _mm256_extracti128_si256(b_bytes.ymm, 1);
+ __m128i a_first_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_low_u8x16);
+ __m128i a_second_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_low_u8x16);
+ __m128i a_third_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_high_u8x16);
+ __m128i a_fourth_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_high_u8x16);
+ __m128i b_first_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_low_u8x16);
+ __m128i b_second_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_low_u8x16);
+ __m128i b_third_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_high_u8x16);
+ __m128i b_fourth_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_high_u8x16);
+ __m256 a_first_f32x8 = _mm256_cvtph_ps(a_first_f16x8);
+ __m256 a_second_f32x8 = _mm256_cvtph_ps(a_second_f16x8);
+ __m256 a_third_f32x8 = _mm256_cvtph_ps(a_third_f16x8);
+ __m256 a_fourth_f32x8 = _mm256_cvtph_ps(a_fourth_f16x8);
+ __m256 b_first_f32x8 = _mm256_cvtph_ps(b_first_f16x8);
+ __m256 b_second_f32x8 = _mm256_cvtph_ps(b_second_f16x8);
+ __m256 b_third_f32x8 = _mm256_cvtph_ps(b_third_f16x8);
+ __m256 b_fourth_f32x8 = _mm256_cvtph_ps(b_fourth_f16x8);
+ __m256 first_chain_f32x8 = _mm256_mul_ps(a_first_f32x8, b_first_f32x8);
+ __m256 second_chain_f32x8 = _mm256_mul_ps(a_second_f32x8, b_second_f32x8);
+ first_chain_f32x8 = _mm256_fmadd_ps(a_third_f32x8, b_third_f32x8, first_chain_f32x8);
+ second_chain_f32x8 = _mm256_fmadd_ps(a_fourth_f32x8, b_fourth_f32x8, second_chain_f32x8);
+ state->sum_f32x8 = _mm256_add_ps(state->sum_f32x8, _mm256_add_ps(first_chain_f32x8, second_chain_f32x8));
+ }
+
+ /**
+ * @brief E4M3 byte-batched update: consumes 32 raw E4M3 bytes per call. Widens through the
+ * Giesen cast helper. Two independent FMA chains merge into the single state accumulator.
+ */
+ NK_INTERNAL void nk_dot_e4m3x32_update_haswell_(nk_dot_through_f32_state_haswell_t_ *state, nk_b256_vec_t a_bytes,
+ nk_b256_vec_t b_bytes, nk_size_t depth_offset,
+ nk_size_t active_dimensions) {
+ nk_unused_(depth_offset);
+ nk_unused_(active_dimensions);
+ __m128i a_low_u8x16 = _mm256_castsi256_si128(a_bytes.ymm);
+ __m128i a_high_u8x16 = _mm256_extracti128_si256(a_bytes.ymm, 1);
+ __m128i b_low_u8x16 = _mm256_castsi256_si128(b_bytes.ymm);
+ __m128i b_high_u8x16 = _mm256_extracti128_si256(b_bytes.ymm, 1);
+ __m256 a_first_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_low_u8x16);
+ __m256 a_second_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_low_u8x16, a_low_u8x16));
+ __m256 a_third_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_high_u8x16);
+ __m256 a_fourth_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_high_u8x16, a_high_u8x16));
+ __m256 b_first_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_low_u8x16);
+ __m256 b_second_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_low_u8x16, b_low_u8x16));
+ __m256 b_third_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_high_u8x16);
+ __m256 b_fourth_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_high_u8x16, b_high_u8x16));
+ __m256 first_chain_f32x8 = _mm256_mul_ps(a_first_f32x8, b_first_f32x8);
+ __m256 second_chain_f32x8 = _mm256_mul_ps(a_second_f32x8, b_second_f32x8);
+ first_chain_f32x8 = _mm256_fmadd_ps(a_third_f32x8, b_third_f32x8, first_chain_f32x8);
+ second_chain_f32x8 = _mm256_fmadd_ps(a_fourth_f32x8, b_fourth_f32x8, second_chain_f32x8);
+ state->sum_f32x8 = _mm256_add_ps(state->sum_f32x8, _mm256_add_ps(first_chain_f32x8, second_chain_f32x8));
+ }
+
  /**
  * @brief Finalizes 4x low-precision dot-products placing them into 4x consecutive 32-bit slots.
  * @sa nk_dot_f16x8_finalize_haswell, nk_dot_bf16x8_finalize_haswell
package/include/numkong/dot/serial.h CHANGED
@@ -139,6 +139,15 @@ extern "C" {
  result->imag = sum_imag; \
  }

+ /* Keep the serial instantiations below actually scalar, regardless of build type.
+ * See dots/serial.h for rationale. */
+ #if defined(__clang__)
+ #pragma clang attribute push(__attribute__((noinline)), apply_to = function)
+ #elif defined(__GNUC__)
+ #pragma GCC push_options
+ #pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
+ #endif
+
  #pragma region F32 and F64 Floats

  nk_define_dot_(f32, f64, f64, nk_assign_from_to_) // nk_dot_f32_serial
@@ -867,6 +876,12 @@ NK_INTERNAL nk_i32_t nk_sum_i4x32_finalize_serial(nk_sum_i4x32_state_serial_t co

  #pragma endregion Stateful Element Sum Helpers

+ #if defined(__clang__)
+ #pragma clang attribute pop
+ #elif defined(__GNUC__)
+ #pragma GCC pop_options
+ #endif
+
  #if defined(__cplusplus)
  } // extern "C"
  #endif