numkong 7.4.5 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86):
  1. package/README.md +1 -0
  2. package/binding.gyp +99 -5
  3. package/c/dispatch_e5m2.c +23 -3
  4. package/c/dispatch_f16.c +23 -0
  5. package/c/numkong.c +0 -13
  6. package/include/numkong/attention/sme.h +34 -31
  7. package/include/numkong/capabilities.h +2 -15
  8. package/include/numkong/cast/README.md +3 -0
  9. package/include/numkong/cast/haswell.h +28 -64
  10. package/include/numkong/cast/neon.h +15 -0
  11. package/include/numkong/cast/serial.h +17 -0
  12. package/include/numkong/cast/skylake.h +67 -52
  13. package/include/numkong/cast.h +1 -0
  14. package/include/numkong/curved/smef64.h +82 -62
  15. package/include/numkong/dot/README.md +1 -0
  16. package/include/numkong/dot/haswell.h +92 -13
  17. package/include/numkong/dot/rvvbf16.h +1 -1
  18. package/include/numkong/dot/rvvhalf.h +1 -1
  19. package/include/numkong/dot/serial.h +15 -0
  20. package/include/numkong/dot/skylake.h +61 -14
  21. package/include/numkong/dot/sve.h +6 -5
  22. package/include/numkong/dot/svebfdot.h +2 -1
  23. package/include/numkong/dot/svehalf.h +6 -5
  24. package/include/numkong/dot/svesdot.h +3 -2
  25. package/include/numkong/dots/README.md +2 -0
  26. package/include/numkong/dots/graniteamx.h +1167 -0
  27. package/include/numkong/dots/haswell.h +28 -28
  28. package/include/numkong/dots/sapphireamx.h +1 -1
  29. package/include/numkong/dots/serial.h +33 -11
  30. package/include/numkong/dots/skylake.h +28 -23
  31. package/include/numkong/dots/sme.h +172 -140
  32. package/include/numkong/dots/smebi32.h +14 -11
  33. package/include/numkong/dots/smef64.h +31 -26
  34. package/include/numkong/dots.h +41 -3
  35. package/include/numkong/each/serial.h +39 -0
  36. package/include/numkong/geospatial/haswell.h +1 -1
  37. package/include/numkong/geospatial/neon.h +1 -1
  38. package/include/numkong/geospatial/serial.h +15 -4
  39. package/include/numkong/geospatial/skylake.h +1 -1
  40. package/include/numkong/maxsim/serial.h +15 -0
  41. package/include/numkong/maxsim/sme.h +34 -33
  42. package/include/numkong/mesh/README.md +50 -44
  43. package/include/numkong/mesh/genoa.h +462 -0
  44. package/include/numkong/mesh/haswell.h +806 -933
  45. package/include/numkong/mesh/neon.h +871 -943
  46. package/include/numkong/mesh/neonbfdot.h +382 -522
  47. package/include/numkong/mesh/neonfhm.h +676 -0
  48. package/include/numkong/mesh/rvv.h +404 -319
  49. package/include/numkong/mesh/serial.h +225 -161
  50. package/include/numkong/mesh/skylake.h +1029 -1585
  51. package/include/numkong/mesh/v128relaxed.h +403 -377
  52. package/include/numkong/mesh.h +38 -0
  53. package/include/numkong/reduce/neon.h +29 -0
  54. package/include/numkong/reduce/neonbfdot.h +2 -2
  55. package/include/numkong/reduce/neonfhm.h +4 -4
  56. package/include/numkong/reduce/serial.h +15 -1
  57. package/include/numkong/reduce/sve.h +52 -0
  58. package/include/numkong/reduce.h +4 -0
  59. package/include/numkong/set/sve.h +6 -5
  60. package/include/numkong/sets/smebi32.h +35 -30
  61. package/include/numkong/sparse/serial.h +17 -2
  62. package/include/numkong/sparse/sve2.h +3 -2
  63. package/include/numkong/spatial/genoa.h +0 -68
  64. package/include/numkong/spatial/haswell.h +98 -56
  65. package/include/numkong/spatial/serial.h +15 -0
  66. package/include/numkong/spatial/skylake.h +114 -54
  67. package/include/numkong/spatial/sve.h +7 -6
  68. package/include/numkong/spatial/svebfdot.h +7 -4
  69. package/include/numkong/spatial/svehalf.h +5 -4
  70. package/include/numkong/spatial/svesdot.h +9 -8
  71. package/include/numkong/spatial.h +0 -12
  72. package/include/numkong/spatials/graniteamx.h +301 -0
  73. package/include/numkong/spatials/serial.h +39 -0
  74. package/include/numkong/spatials/skylake.h +2 -2
  75. package/include/numkong/spatials/sme.h +391 -350
  76. package/include/numkong/spatials/smef64.h +79 -70
  77. package/include/numkong/spatials.h +54 -4
  78. package/include/numkong/tensor.hpp +107 -23
  79. package/include/numkong/types.h +59 -0
  80. package/javascript/dist/cjs/numkong.js +13 -0
  81. package/javascript/dist/esm/numkong.js +13 -0
  82. package/javascript/numkong.c +59 -14
  83. package/javascript/numkong.ts +13 -0
  84. package/package.json +7 -7
  85. package/probes/probe.js +2 -2
  86. package/wasm/numkong.wasm +0 -0
@@ -115,45 +115,45 @@ nk_define_cross_packed_(dots, bf16, haswell, bf16, bf16, f32, nk_b256_vec_t, nk_
115
115
  nk_partial_store_b32x4_haswell_,
116
116
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
117
117
 
118
- /* E4M3 GEMM: depth_simd_dimensions=8 (8 e4m3s = 8 bytes) upcasted to 8×f32 (256-bit) */
119
- nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
118
+ /* E4M3 GEMM: depth_simd_dimensions=32 (byte-level batch; widen inside the update helper) */
119
+ nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
120
120
  /*dimensions_per_value=*/1)
121
- nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_load_e4m3x8_to_f32x8_haswell_,
122
- nk_partial_load_e4m3x8_to_f32x8_haswell_, nk_store_b256_haswell_, nk_partial_store_b32x8_serial_,
123
- /*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
124
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
121
+ nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_load_b256_haswell_,
122
+ nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
123
+ /*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
124
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
125
125
  nk_define_cross_symmetric_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
126
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e4m3x8_to_f32x8_haswell_,
127
- nk_partial_load_e4m3x8_to_f32x8_haswell_, nk_dot_through_f32_update_haswell_,
126
+ nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
127
+ nk_partial_load_b8x32_serial_, nk_dot_e4m3x32_update_haswell_,
128
128
  nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
129
129
  nk_partial_store_b32x4_haswell_,
130
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
130
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
131
131
  nk_define_cross_packed_(dots, e4m3, haswell, e4m3, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
132
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e4m3x8_to_f32x8_haswell_,
133
- nk_partial_load_e4m3x8_to_f32x8_haswell_, nk_load_b256_haswell_, nk_partial_load_b32x8_serial_,
134
- nk_dot_through_f32_update_haswell_, nk_dot_through_f32_finalize_haswell_,
135
- nk_store_b128_haswell_, nk_partial_store_b32x4_haswell_,
136
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
132
+ nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
133
+ nk_partial_load_b8x32_serial_, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
134
+ nk_dot_e4m3x32_update_haswell_, nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
135
+ nk_partial_store_b32x4_haswell_,
136
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
137
137
 
138
- /* E5M2 GEMM: depth_simd_dimensions=8 (8 e5m2s = 8 bytes) upcasted to 8×f32 (256-bit) */
139
- nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
138
+ /* E5M2 GEMM: depth_simd_dimensions=32 (byte-level batch; widen inside the update helper) */
139
+ nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
140
140
  /*dimensions_per_value=*/1)
141
- nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_load_e5m2x8_to_f32x8_haswell_,
142
- nk_partial_load_e5m2x8_to_f32x8_haswell_, nk_store_b256_haswell_, nk_partial_store_b32x8_serial_,
143
- /*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
144
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
141
+ nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_load_b256_haswell_,
142
+ nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
143
+ /*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
144
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
145
145
  nk_define_cross_symmetric_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
146
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e5m2x8_to_f32x8_haswell_,
147
- nk_partial_load_e5m2x8_to_f32x8_haswell_, nk_dot_through_f32_update_haswell_,
146
+ nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
147
+ nk_partial_load_b8x32_serial_, nk_dot_e5m2x32_update_haswell_,
148
148
  nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
149
149
  nk_partial_store_b32x4_haswell_,
150
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
150
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
151
151
  nk_define_cross_packed_(dots, e5m2, haswell, e5m2, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
152
- nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_e5m2x8_to_f32x8_haswell_,
153
- nk_partial_load_e5m2x8_to_f32x8_haswell_, nk_load_b256_haswell_, nk_partial_load_b32x8_serial_,
154
- nk_dot_through_f32_update_haswell_, nk_dot_through_f32_finalize_haswell_,
155
- nk_store_b128_haswell_, nk_partial_store_b32x4_haswell_,
156
- /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
152
+ nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
153
+ nk_partial_load_b8x32_serial_, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
154
+ nk_dot_e5m2x32_update_haswell_, nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
155
+ nk_partial_store_b32x4_haswell_,
156
+ /*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
157
157
 
158
158
  /* E2M3 GEMM: integer LUT path, depth_simd_dimensions=32 (32 e2m3s = 32 bytes = AVX2 register width) */
159
159
  nk_define_cross_pack_size_(dots, e2m3, haswell, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
@@ -73,7 +73,7 @@
73
73
  #if NK_TARGET_SAPPHIREAMX
74
74
 
75
75
  #include "numkong/cast/icelake.h" // For FP8 ↔ BF16 conversions
76
- #include "numkong/dots/serial.h" // For nk_dots_reduce_sumsq_bf16_
76
+ #include "numkong/dots/serial.h" // `nk_dots_reduce_sumsq_bf16_`
77
77
 
78
78
  #if defined(__cplusplus)
79
79
  extern "C" {
@@ -522,7 +522,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
522
522
  load_a_vec_fn, partial_load_a_vec_fn, load_b_vec_fn, partial_load_b_vec_fn, \
523
523
  inner_product_fn, reduce_accumulators_fn, store_fn, partial_store_fn, \
524
524
  depth_simd_dimensions, dimensions_per_value) \
525
- NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_( \
525
+ NK_INTERNAL void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_( \
526
526
  nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
527
527
  nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
528
528
  nk_size_t c_stride_in_bytes) { \
@@ -698,7 +698,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
698
698
  } \
699
699
  } \
700
700
  } \
701
- NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
701
+ NK_INTERNAL void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
702
702
  nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
703
703
  nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
704
704
  nk_size_t c_stride_in_bytes) { \
@@ -1090,7 +1090,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
1090
1090
  norm_value_type, vec_type, state_type, result_vec_type, init_accumulator_fn, load_a_vec_fn, partial_load_a_vec_fn, \
1091
1091
  load_b_vec_fn, partial_load_b_vec_fn, inner_product_fn, compensated_finalize_fn, store_fn, partial_store_fn, \
1092
1092
  load_sum_fn, partial_load_sum_fn, compute_a_sum_fn, depth_simd_dimensions, dimensions_per_value) \
1093
- NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_( \
1093
+ NK_INTERNAL void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_( \
1094
1094
  nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
1095
1095
  nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
1096
1096
  nk_size_t c_stride_in_bytes) { \
@@ -1200,7 +1200,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
1200
1200
  } \
1201
1201
  } \
1202
1202
  } \
1203
- NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
1203
+ NK_INTERNAL void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
1204
1204
  nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
1205
1205
  nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
1206
1206
  nk_size_t c_stride_in_bytes) { \
@@ -2431,13 +2431,25 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
2431
2431
  } \
2432
2432
  }
2433
2433
 
2434
- /* Optimize serial GEMM instantiations for size rather than speed.
2435
- * These fallback kernels are only used when no SIMD backend is available, so aggressive inlining/unrolling from -O3
2436
- * wastes ~1.3 MB of binary space with negligible performance benefit on the serial path. Sadly, a scoped application
2437
- * of `__attribute__((optimize("Os"))` isn't supported on Clang, so this flag only applies to GCC builds.
2438
- */
2434
+ /* Keep the serial instantiations below actually scalar, regardless of build type.
2435
+ * Without this, -O3 + LTO can vectorize or clone the serial kernels under AVX-512
2436
+ * callers in dispatch_*.c, which wastes ~1 MB of binary and more importantly
2437
+ * breaks the nk_*_serial-as-scalar-oracle contract that tests and the numerical-
2438
+ * stability docs in this header rely on. */
2439
+ #if defined(__clang__)
2440
+ #pragma clang attribute push(__attribute__((noinline)), apply_to = function)
2441
+ #elif defined(__GNUC__)
2442
+ #pragma GCC push_options
2443
+ #pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
2444
+ #endif
2445
+
2446
+ /* Size bias for release. Gated on NDEBUG so Debug builds keep -O0 for stepping. */
2439
2447
  #if defined(NDEBUG)
2440
- #if defined(__GNUC__) && !defined(__clang__)
2448
+ #if defined(_MSC_VER)
2449
+ #pragma optimize("s", on)
2450
+ #elif defined(__clang__)
2451
+ #pragma clang attribute push(__attribute__((minsize)), apply_to = function)
2452
+ #elif defined(__GNUC__)
2441
2453
  #pragma GCC push_options
2442
2454
  #pragma GCC optimize("Os")
2443
2455
  #endif
@@ -2677,11 +2689,21 @@ nk_define_cross_packed_(dots, u1, serial, u1x8, u1x8, u32, nk_b128_vec_t, nk_dot
2677
2689
  /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/8)
2678
2690
 
2679
2691
  #if defined(NDEBUG)
2680
- #if defined(__GNUC__) && !defined(__clang__)
2692
+ #if defined(_MSC_VER)
2693
+ #pragma optimize("", on)
2694
+ #elif defined(__clang__)
2695
+ #pragma clang attribute pop
2696
+ #elif defined(__GNUC__)
2681
2697
  #pragma GCC pop_options
2682
2698
  #endif
2683
2699
  #endif
2684
2700
 
2701
+ #if defined(__clang__)
2702
+ #pragma clang attribute pop
2703
+ #elif defined(__GNUC__)
2704
+ #pragma GCC pop_options
2705
+ #endif
2706
+
2685
2707
  /* BF16 compact: truncate F32 → BF16 in-place.
2686
2708
  * Reads F32 matrix with c_stride_in_bytes, writes BF16 tightly packed (stride_in_bytes = column_count × sizeof(bf16)).
2687
2709
  */
@@ -114,45 +114,50 @@ nk_define_cross_packed_(dots, f16, skylake, f16, f32, f32, nk_b512_vec_t, nk_dot
114
114
  nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_,
115
115
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
116
116
 
117
- /* E4M3 GEMM: depth_simd_dimensions=16 (16 e4m3s = 16 bytes = quarter cache line), F32 accumulator */
118
- nk_define_cross_pack_size_(dots, e4m3, skylake, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
117
+ /* E4M3 GEMM: F16-pack with asymmetric A/B representations at compute time. Pack converts
118
+ * E4M3 → F16 once (~10 ops/16 elements, 2 bytes/elt stored). A-stream uses the Giesen E4M3→F32
119
+ * cast (identical cost to F32-pack path). B-loader widens F16 → F32 inline (1 vcvtph2ps per 16
120
+ * lanes). Update takes both as F32 → plain fmadd. Saves 2 bytes/elt vs F32-pack; inner loop
121
+ * adds one cvtph2ps per B-read. Symmetric uses E4M3→F32 for both sides (no pack involved). */
122
+ nk_define_cross_pack_size_(dots, e4m3, skylake, e4m3, f16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
119
123
  /*dimensions_per_value=*/1)
120
- nk_define_cross_pack_(dots, e4m3, skylake, e4m3, f32, nk_b512_vec_t, nk_load_e4m3x16_to_f32x16_skylake_,
121
- nk_partial_load_e4m3x16_to_f32x16_skylake_, nk_store_b512_skylake_,
122
- nk_partial_store_b32x16_skylake_, /*simd_width=*/16, /*norm_value_type=*/f32,
123
- nk_dots_reduce_sumsq_e4m3_, /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
124
+ nk_define_cross_pack_(dots, e4m3, skylake, e4m3, f16, nk_b256_vec_t, nk_load_e4m3x16_to_f16x16_skylake_,
125
+ nk_partial_load_e4m3x16_to_f16x16_skylake_, nk_store_b256_haswell_,
126
+ nk_partial_store_b16x16_serial_,
127
+ /*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
128
+ /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
124
129
  nk_define_cross_symmetric_(dots, e4m3, skylake, e4m3, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
125
130
  nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_e4m3x16_to_f32x16_skylake_,
126
131
  nk_partial_load_e4m3x16_to_f32x16_skylake_, nk_dot_through_f32_update_skylake_,
127
132
  nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_,
128
133
  nk_partial_store_b32x4_skylake_,
129
134
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
130
- nk_define_cross_packed_(dots, e4m3, skylake, e4m3, f32, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
135
+ nk_define_cross_packed_(dots, e4m3, skylake, e4m3, f16, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
131
136
  nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_e4m3x16_to_f32x16_skylake_,
132
- nk_partial_load_e4m3x16_to_f32x16_skylake_, nk_load_b512_skylake_,
133
- nk_partial_load_b32x16_skylake_, nk_dot_through_f32_update_skylake_,
137
+ nk_partial_load_e4m3x16_to_f32x16_skylake_, nk_load_f16x16_to_f32x16_skylake_,
138
+ nk_partial_load_f16x16_to_f32x16_skylake_, nk_dot_through_f32_update_skylake_,
134
139
  nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_,
135
140
  /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
136
141
 
137
- /* E5M2 GEMM: depth_simd_dimensions=16 (16 e5m2s = 16 bytes = quarter cache line), F32 accumulator */
138
- nk_define_cross_pack_size_(dots, e5m2, skylake, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
142
+ /* E5M2 GEMM: depth_simd_dimensions=64 (byte-level batch; widen inside the update helper) */
143
+ nk_define_cross_pack_size_(dots, e5m2, skylake, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/64,
139
144
  /*dimensions_per_value=*/1)
140
- nk_define_cross_pack_(dots, e5m2, skylake, e5m2, f32, nk_b512_vec_t, nk_load_e5m2x16_to_f32x16_skylake_,
141
- nk_partial_load_e5m2x16_to_f32x16_skylake_, nk_store_b512_skylake_,
142
- nk_partial_store_b32x16_skylake_, /*simd_width=*/16, /*norm_value_type=*/f32,
143
- nk_dots_reduce_sumsq_e5m2_, /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
145
+ nk_define_cross_pack_(dots, e5m2, skylake, e5m2, f32, nk_b512_vec_t, nk_load_b512_skylake_,
146
+ nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_, nk_partial_store_b8x64_skylake_,
147
+ /*simd_width=*/64, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
148
+ /*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
144
149
  nk_define_cross_symmetric_(dots, e5m2, skylake, e5m2, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
145
- nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_e5m2x16_to_f32x16_skylake_,
146
- nk_partial_load_e5m2x16_to_f32x16_skylake_, nk_dot_through_f32_update_skylake_,
150
+ nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_b512_skylake_,
151
+ nk_partial_load_b8x64_skylake_, nk_dot_e5m2x64_update_skylake_,
147
152
  nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_,
148
153
  nk_partial_store_b32x4_skylake_,
149
- /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
154
+ /*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
150
155
  nk_define_cross_packed_(dots, e5m2, skylake, e5m2, f32, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
151
- nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_e5m2x16_to_f32x16_skylake_,
152
- nk_partial_load_e5m2x16_to_f32x16_skylake_, nk_load_b512_skylake_,
153
- nk_partial_load_b32x16_skylake_, nk_dot_through_f32_update_skylake_,
154
- nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_,
155
- /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
156
+ nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_b512_skylake_,
157
+ nk_partial_load_b8x64_skylake_, nk_load_b512_skylake_, nk_partial_load_b8x64_skylake_,
158
+ nk_dot_e5m2x64_update_skylake_, nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_,
159
+ nk_partial_store_b32x4_skylake_,
160
+ /*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
156
161
 
157
162
  /* E2M3 GEMM: integer LUT path, depth_simd_dimensions=64 (64 e2m3s = 64 bytes = AVX-512 register width) */
158
163
  nk_define_cross_pack_size_(dots, e2m3, skylake, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/64,