numkong 7.4.5 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +99 -5
- package/c/dispatch_e5m2.c +23 -3
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +1167 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +33 -11
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +41 -3
- package/include/numkong/each/serial.h +39 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +15 -4
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +225 -161
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +301 -0
- package/include/numkong/spatials/serial.h +39 -0
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +54 -4
- package/include/numkong/tensor.hpp +107 -23
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +59 -14
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -115,45 +115,45 @@ nk_define_cross_packed_(dots, bf16, haswell, bf16, bf16, f32, nk_b256_vec_t, nk_
|
|
|
115
115
|
nk_partial_store_b32x4_haswell_,
|
|
116
116
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
117
117
|
|
|
118
|
-
/* E4M3 GEMM: depth_simd_dimensions=
|
|
119
|
-
nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/
|
|
118
|
+
/* E4M3 GEMM: depth_simd_dimensions=32 (byte-level batch; widen inside the update helper) */
|
|
119
|
+
nk_define_cross_pack_size_(dots, e4m3, haswell, e4m3, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
|
|
120
120
|
/*dimensions_per_value=*/1)
|
|
121
|
-
nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t,
|
|
122
|
-
|
|
123
|
-
/*simd_width=*/
|
|
124
|
-
/*depth_simd_dimensions=*/
|
|
121
|
+
nk_define_cross_pack_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_load_b256_haswell_,
|
|
122
|
+
nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
|
|
123
|
+
/*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
|
|
124
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
125
125
|
nk_define_cross_symmetric_(dots, e4m3, haswell, e4m3, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
|
|
126
|
-
nk_b128_vec_t, nk_dot_through_f32_init_haswell_,
|
|
127
|
-
|
|
126
|
+
nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
|
|
127
|
+
nk_partial_load_b8x32_serial_, nk_dot_e4m3x32_update_haswell_,
|
|
128
128
|
nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
|
|
129
129
|
nk_partial_store_b32x4_haswell_,
|
|
130
|
-
/*depth_simd_dimensions=*/
|
|
130
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
131
131
|
nk_define_cross_packed_(dots, e4m3, haswell, e4m3, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
|
|
132
|
-
nk_b128_vec_t, nk_dot_through_f32_init_haswell_,
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
/*depth_simd_dimensions=*/
|
|
132
|
+
nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
|
|
133
|
+
nk_partial_load_b8x32_serial_, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
|
|
134
|
+
nk_dot_e4m3x32_update_haswell_, nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
|
|
135
|
+
nk_partial_store_b32x4_haswell_,
|
|
136
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
137
137
|
|
|
138
|
-
/* E5M2 GEMM: depth_simd_dimensions=
|
|
139
|
-
nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/
|
|
138
|
+
/* E5M2 GEMM: depth_simd_dimensions=32 (byte-level batch; widen inside the update helper) */
|
|
139
|
+
nk_define_cross_pack_size_(dots, e5m2, haswell, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
|
|
140
140
|
/*dimensions_per_value=*/1)
|
|
141
|
-
nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t,
|
|
142
|
-
|
|
143
|
-
/*simd_width=*/
|
|
144
|
-
/*depth_simd_dimensions=*/
|
|
141
|
+
nk_define_cross_pack_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_load_b256_haswell_,
|
|
142
|
+
nk_partial_load_b8x32_serial_, nk_store_b256_haswell_, nk_partial_store_b8x32_serial_,
|
|
143
|
+
/*simd_width=*/32, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
|
|
144
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
145
145
|
nk_define_cross_symmetric_(dots, e5m2, haswell, e5m2, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
|
|
146
|
-
nk_b128_vec_t, nk_dot_through_f32_init_haswell_,
|
|
147
|
-
|
|
146
|
+
nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
|
|
147
|
+
nk_partial_load_b8x32_serial_, nk_dot_e5m2x32_update_haswell_,
|
|
148
148
|
nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
|
|
149
149
|
nk_partial_store_b32x4_haswell_,
|
|
150
|
-
/*depth_simd_dimensions=*/
|
|
150
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
151
151
|
nk_define_cross_packed_(dots, e5m2, haswell, e5m2, f32, f32, nk_b256_vec_t, nk_dot_through_f32_state_haswell_t_,
|
|
152
|
-
nk_b128_vec_t, nk_dot_through_f32_init_haswell_,
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
/*depth_simd_dimensions=*/
|
|
152
|
+
nk_b128_vec_t, nk_dot_through_f32_init_haswell_, nk_load_b256_haswell_,
|
|
153
|
+
nk_partial_load_b8x32_serial_, nk_load_b256_haswell_, nk_partial_load_b8x32_serial_,
|
|
154
|
+
nk_dot_e5m2x32_update_haswell_, nk_dot_through_f32_finalize_haswell_, nk_store_b128_haswell_,
|
|
155
|
+
nk_partial_store_b32x4_haswell_,
|
|
156
|
+
/*depth_simd_dimensions=*/32, /*dimensions_per_value=*/1)
|
|
157
157
|
|
|
158
158
|
/* E2M3 GEMM: integer LUT path, depth_simd_dimensions=32 (32 e2m3s = 32 bytes = AVX2 register width) */
|
|
159
159
|
nk_define_cross_pack_size_(dots, e2m3, haswell, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/32,
|
|
@@ -73,7 +73,7 @@
|
|
|
73
73
|
#if NK_TARGET_SAPPHIREAMX
|
|
74
74
|
|
|
75
75
|
#include "numkong/cast/icelake.h" // For FP8 ↔ BF16 conversions
|
|
76
|
-
#include "numkong/dots/serial.h" //
|
|
76
|
+
#include "numkong/dots/serial.h" // `nk_dots_reduce_sumsq_bf16_`
|
|
77
77
|
|
|
78
78
|
#if defined(__cplusplus)
|
|
79
79
|
extern "C" {
|
|
@@ -522,7 +522,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
522
522
|
load_a_vec_fn, partial_load_a_vec_fn, load_b_vec_fn, partial_load_b_vec_fn, \
|
|
523
523
|
inner_product_fn, reduce_accumulators_fn, store_fn, partial_store_fn, \
|
|
524
524
|
depth_simd_dimensions, dimensions_per_value) \
|
|
525
|
-
|
|
525
|
+
NK_INTERNAL void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_( \
|
|
526
526
|
nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
|
|
527
527
|
nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
|
|
528
528
|
nk_size_t c_stride_in_bytes) { \
|
|
@@ -698,7 +698,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
698
698
|
} \
|
|
699
699
|
} \
|
|
700
700
|
} \
|
|
701
|
-
|
|
701
|
+
NK_INTERNAL void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
|
|
702
702
|
nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
|
|
703
703
|
nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
|
|
704
704
|
nk_size_t c_stride_in_bytes) { \
|
|
@@ -1090,7 +1090,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
1090
1090
|
norm_value_type, vec_type, state_type, result_vec_type, init_accumulator_fn, load_a_vec_fn, partial_load_a_vec_fn, \
|
|
1091
1091
|
load_b_vec_fn, partial_load_b_vec_fn, inner_product_fn, compensated_finalize_fn, store_fn, partial_store_fn, \
|
|
1092
1092
|
load_sum_fn, partial_load_sum_fn, compute_a_sum_fn, depth_simd_dimensions, dimensions_per_value) \
|
|
1093
|
-
|
|
1093
|
+
NK_INTERNAL void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_( \
|
|
1094
1094
|
nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
|
|
1095
1095
|
nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
|
|
1096
1096
|
nk_size_t c_stride_in_bytes) { \
|
|
@@ -1200,7 +1200,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
1200
1200
|
} \
|
|
1201
1201
|
} \
|
|
1202
1202
|
} \
|
|
1203
|
-
|
|
1203
|
+
NK_INTERNAL void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
|
|
1204
1204
|
nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
|
|
1205
1205
|
nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
|
|
1206
1206
|
nk_size_t c_stride_in_bytes) { \
|
|
@@ -2431,13 +2431,25 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
2431
2431
|
} \
|
|
2432
2432
|
}
|
|
2433
2433
|
|
|
2434
|
-
/*
|
|
2435
|
-
*
|
|
2436
|
-
* wastes ~1
|
|
2437
|
-
*
|
|
2438
|
-
*/
|
|
2434
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
2435
|
+
* Without this, -O3 + LTO can vectorize or clone the serial kernels under AVX-512
|
|
2436
|
+
* callers in dispatch_*.c, which wastes ~1 MB of binary and — more importantly —
|
|
2437
|
+
* breaks the nk_*_serial-as-scalar-oracle contract that tests and the numerical-
|
|
2438
|
+
* stability docs in this header rely on. */
|
|
2439
|
+
#if defined(__clang__)
|
|
2440
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
2441
|
+
#elif defined(__GNUC__)
|
|
2442
|
+
#pragma GCC push_options
|
|
2443
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
2444
|
+
#endif
|
|
2445
|
+
|
|
2446
|
+
/* Size bias for release. Gated on NDEBUG so Debug builds keep -O0 for stepping. */
|
|
2439
2447
|
#if defined(NDEBUG)
|
|
2440
|
-
#if defined(
|
|
2448
|
+
#if defined(_MSC_VER)
|
|
2449
|
+
#pragma optimize("s", on)
|
|
2450
|
+
#elif defined(__clang__)
|
|
2451
|
+
#pragma clang attribute push(__attribute__((minsize)), apply_to = function)
|
|
2452
|
+
#elif defined(__GNUC__)
|
|
2441
2453
|
#pragma GCC push_options
|
|
2442
2454
|
#pragma GCC optimize("Os")
|
|
2443
2455
|
#endif
|
|
@@ -2677,11 +2689,21 @@ nk_define_cross_packed_(dots, u1, serial, u1x8, u1x8, u32, nk_b128_vec_t, nk_dot
|
|
|
2677
2689
|
/*depth_simd_dimensions=*/128, /*dimensions_per_value=*/8)
|
|
2678
2690
|
|
|
2679
2691
|
#if defined(NDEBUG)
|
|
2680
|
-
#if defined(
|
|
2692
|
+
#if defined(_MSC_VER)
|
|
2693
|
+
#pragma optimize("", on)
|
|
2694
|
+
#elif defined(__clang__)
|
|
2695
|
+
#pragma clang attribute pop
|
|
2696
|
+
#elif defined(__GNUC__)
|
|
2681
2697
|
#pragma GCC pop_options
|
|
2682
2698
|
#endif
|
|
2683
2699
|
#endif
|
|
2684
2700
|
|
|
2701
|
+
#if defined(__clang__)
|
|
2702
|
+
#pragma clang attribute pop
|
|
2703
|
+
#elif defined(__GNUC__)
|
|
2704
|
+
#pragma GCC pop_options
|
|
2705
|
+
#endif
|
|
2706
|
+
|
|
2685
2707
|
/* BF16 compact: truncate F32 → BF16 in-place.
|
|
2686
2708
|
* Reads F32 matrix with c_stride_in_bytes, writes BF16 tightly packed (stride_in_bytes = column_count × sizeof(bf16)).
|
|
2687
2709
|
*/
|
|
@@ -114,45 +114,50 @@ nk_define_cross_packed_(dots, f16, skylake, f16, f32, f32, nk_b512_vec_t, nk_dot
|
|
|
114
114
|
nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_,
|
|
115
115
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
116
116
|
|
|
117
|
-
/* E4M3 GEMM:
|
|
118
|
-
|
|
117
|
+
/* E4M3 GEMM: F16-pack with asymmetric A/B representations at compute time. Pack converts
|
|
118
|
+
* E4M3 → F16 once (~10 ops/16 elements, 2 bytes/elt stored). A-stream uses the Giesen E4M3→F32
|
|
119
|
+
* cast (identical cost to F32-pack path). B-loader widens F16 → F32 inline (1 vcvtph2ps per 16
|
|
120
|
+
* lanes). Update takes both as F32 → plain fmadd. Saves 2 bytes/elt vs F32-pack; inner loop
|
|
121
|
+
* adds one cvtph2ps per B-read. Symmetric uses E4M3→F32 for both sides (no pack involved). */
|
|
122
|
+
nk_define_cross_pack_size_(dots, e4m3, skylake, e4m3, f16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
|
|
119
123
|
/*dimensions_per_value=*/1)
|
|
120
|
-
nk_define_cross_pack_(dots, e4m3, skylake, e4m3,
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
+
nk_define_cross_pack_(dots, e4m3, skylake, e4m3, f16, nk_b256_vec_t, nk_load_e4m3x16_to_f16x16_skylake_,
|
|
125
|
+
nk_partial_load_e4m3x16_to_f16x16_skylake_, nk_store_b256_haswell_,
|
|
126
|
+
nk_partial_store_b16x16_serial_,
|
|
127
|
+
/*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
|
|
128
|
+
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
124
129
|
nk_define_cross_symmetric_(dots, e4m3, skylake, e4m3, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
|
|
125
130
|
nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_e4m3x16_to_f32x16_skylake_,
|
|
126
131
|
nk_partial_load_e4m3x16_to_f32x16_skylake_, nk_dot_through_f32_update_skylake_,
|
|
127
132
|
nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_,
|
|
128
133
|
nk_partial_store_b32x4_skylake_,
|
|
129
134
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
130
|
-
nk_define_cross_packed_(dots, e4m3, skylake, e4m3,
|
|
135
|
+
nk_define_cross_packed_(dots, e4m3, skylake, e4m3, f16, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
|
|
131
136
|
nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_e4m3x16_to_f32x16_skylake_,
|
|
132
|
-
nk_partial_load_e4m3x16_to_f32x16_skylake_,
|
|
133
|
-
|
|
137
|
+
nk_partial_load_e4m3x16_to_f32x16_skylake_, nk_load_f16x16_to_f32x16_skylake_,
|
|
138
|
+
nk_partial_load_f16x16_to_f32x16_skylake_, nk_dot_through_f32_update_skylake_,
|
|
134
139
|
nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_,
|
|
135
140
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
136
141
|
|
|
137
|
-
/* E5M2 GEMM: depth_simd_dimensions=
|
|
138
|
-
nk_define_cross_pack_size_(dots, e5m2, skylake, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/
|
|
142
|
+
/* E5M2 GEMM: depth_simd_dimensions=64 (byte-level batch; widen inside the update helper) */
|
|
143
|
+
nk_define_cross_pack_size_(dots, e5m2, skylake, e5m2, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/64,
|
|
139
144
|
/*dimensions_per_value=*/1)
|
|
140
|
-
nk_define_cross_pack_(dots, e5m2, skylake, e5m2, f32, nk_b512_vec_t,
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
145
|
+
nk_define_cross_pack_(dots, e5m2, skylake, e5m2, f32, nk_b512_vec_t, nk_load_b512_skylake_,
|
|
146
|
+
nk_partial_load_b8x64_skylake_, nk_store_b512_skylake_, nk_partial_store_b8x64_skylake_,
|
|
147
|
+
/*simd_width=*/64, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
|
|
148
|
+
/*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
|
|
144
149
|
nk_define_cross_symmetric_(dots, e5m2, skylake, e5m2, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
|
|
145
|
-
nk_b128_vec_t, nk_dot_through_f32_init_skylake_,
|
|
146
|
-
|
|
150
|
+
nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_b512_skylake_,
|
|
151
|
+
nk_partial_load_b8x64_skylake_, nk_dot_e5m2x64_update_skylake_,
|
|
147
152
|
nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_,
|
|
148
153
|
nk_partial_store_b32x4_skylake_,
|
|
149
|
-
/*depth_simd_dimensions=*/
|
|
154
|
+
/*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
|
|
150
155
|
nk_define_cross_packed_(dots, e5m2, skylake, e5m2, f32, f32, nk_b512_vec_t, nk_dot_through_f32_state_skylake_t_,
|
|
151
|
-
nk_b128_vec_t, nk_dot_through_f32_init_skylake_,
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
/*depth_simd_dimensions=*/
|
|
156
|
+
nk_b128_vec_t, nk_dot_through_f32_init_skylake_, nk_load_b512_skylake_,
|
|
157
|
+
nk_partial_load_b8x64_skylake_, nk_load_b512_skylake_, nk_partial_load_b8x64_skylake_,
|
|
158
|
+
nk_dot_e5m2x64_update_skylake_, nk_dot_through_f32_finalize_skylake_, nk_store_b128_haswell_,
|
|
159
|
+
nk_partial_store_b32x4_skylake_,
|
|
160
|
+
/*depth_simd_dimensions=*/64, /*dimensions_per_value=*/1)
|
|
156
161
|
|
|
157
162
|
/* E2M3 GEMM: integer LUT path, depth_simd_dimensions=64 (64 e2m3s = 64 bytes = AVX-512 register width) */
|
|
158
163
|
nk_define_cross_pack_size_(dots, e2m3, skylake, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/64,
|