whisper.rn 0.5.0-rc.9 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/cpp/ggml-alloc.c +1 -15
- package/cpp/ggml-backend-reg.cpp +17 -8
- package/cpp/ggml-backend.cpp +15 -22
- package/cpp/ggml-common.h +17 -0
- package/cpp/ggml-cpu/arch/arm/quants.c +132 -596
- package/cpp/ggml-cpu/arch/arm/repack.cpp +14 -286
- package/cpp/ggml-cpu/arch/x86/quants.c +184 -675
- package/cpp/ggml-cpu/arch/x86/repack.cpp +4679 -1657
- package/cpp/ggml-cpu/arch-fallback.h +34 -0
- package/cpp/ggml-cpu/ggml-cpu.c +22 -1
- package/cpp/ggml-cpu/ggml-cpu.cpp +21 -24
- package/cpp/ggml-cpu/ops.cpp +870 -211
- package/cpp/ggml-cpu/ops.h +3 -8
- package/cpp/ggml-cpu/quants.c +35 -0
- package/cpp/ggml-cpu/quants.h +8 -0
- package/cpp/ggml-cpu/repack.cpp +458 -47
- package/cpp/ggml-cpu/repack.h +22 -0
- package/cpp/ggml-cpu/simd-mappings.h +1 -1
- package/cpp/ggml-cpu/traits.cpp +2 -2
- package/cpp/ggml-cpu/traits.h +1 -1
- package/cpp/ggml-cpu/vec.cpp +12 -9
- package/cpp/ggml-cpu/vec.h +107 -13
- package/cpp/ggml-impl.h +77 -0
- package/cpp/ggml-metal-impl.h +51 -12
- package/cpp/ggml-metal.m +610 -115
- package/cpp/ggml-opt.cpp +97 -41
- package/cpp/ggml-opt.h +25 -6
- package/cpp/ggml-quants.c +110 -16
- package/cpp/ggml-quants.h +6 -0
- package/cpp/ggml-whisper-sim.metallib +0 -0
- package/cpp/ggml-whisper.metallib +0 -0
- package/cpp/ggml.c +314 -88
- package/cpp/ggml.h +137 -11
- package/cpp/gguf.cpp +8 -1
- package/cpp/jsi/RNWhisperJSI.cpp +23 -6
- package/cpp/whisper.cpp +15 -6
- package/ios/RNWhisper.mm +6 -6
- package/ios/RNWhisperContext.mm +2 -0
- package/ios/RNWhisperVadContext.mm +2 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/ios-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/ggml-whisper.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64/rnwhisper.framework/rnwhisper +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-common.h +17 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-impl.h +77 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-metal-impl.h +51 -12
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-opt.h +25 -6
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml-quants.h +6 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/Headers/ggml.h +137 -11
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/ggml-whisper-sim.metallib +0 -0
- package/ios/rnwhisper.xcframework/tvos-arm64_x86_64-simulator/rnwhisper.framework/rnwhisper +0 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/commonjs/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/module/realtime-transcription/RealtimeTranscriber.js +13 -0
- package/lib/module/realtime-transcription/RealtimeTranscriber.js.map +1 -1
- package/lib/typescript/realtime-transcription/RealtimeTranscriber.d.ts.map +1 -1
- package/lib/typescript/realtime-transcription/types.d.ts +6 -0
- package/lib/typescript/realtime-transcription/types.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/realtime-transcription/RealtimeTranscriber.ts +17 -0
- package/src/realtime-transcription/types.ts +6 -0
package/cpp/ggml-cpu/repack.h
CHANGED
|
@@ -44,7 +44,14 @@ struct block_q4_Kx8 {
|
|
|
44
44
|
};
|
|
45
45
|
|
|
46
46
|
static_assert(sizeof(block_q4_Kx8) == sizeof(wsp_ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
|
47
|
+
struct block_q2_Kx8 {
|
|
48
|
+
wsp_ggml_half d[8]; // super-block scale for quantized scales
|
|
49
|
+
wsp_ggml_half dmin[8]; // super-block scale for quantized mins
|
|
50
|
+
uint8_t scales[128]; // scales and mins, quantized with 4 bits
|
|
51
|
+
uint8_t qs[512]; // 2--bit quants
|
|
52
|
+
};
|
|
47
53
|
|
|
54
|
+
static_assert(sizeof(block_q2_Kx8) == sizeof(wsp_ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
|
48
55
|
struct block_q8_Kx4 {
|
|
49
56
|
float d[4]; // delta
|
|
50
57
|
int8_t qs[QK_K * 4]; // quants
|
|
@@ -60,6 +67,13 @@ struct block_iq4_nlx4 {
|
|
|
60
67
|
|
|
61
68
|
static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(wsp_ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
|
|
62
69
|
|
|
70
|
+
struct block_iq4_nlx8 {
|
|
71
|
+
wsp_ggml_half d[8]; // deltas for 8 iq4_nl blocks
|
|
72
|
+
uint8_t qs[QK4_NL * 4]; // nibbles / quants for 8 iq4_nl blocks
|
|
73
|
+
};
|
|
74
|
+
|
|
75
|
+
static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(wsp_ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
|
|
76
|
+
|
|
63
77
|
#if defined(__cplusplus)
|
|
64
78
|
extern "C" {
|
|
65
79
|
#endif
|
|
@@ -71,12 +85,16 @@ void wsp_ggml_gemv_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs,
|
|
|
71
85
|
void wsp_ggml_gemv_q4_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
72
86
|
void wsp_ggml_gemv_q4_0_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
73
87
|
void wsp_ggml_gemv_q4_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
88
|
+
void wsp_ggml_gemv_q2_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
74
89
|
void wsp_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
90
|
+
void wsp_ggml_gemv_iq4_nl_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
75
91
|
void wsp_ggml_gemm_q4_0_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
76
92
|
void wsp_ggml_gemm_q4_0_4x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
77
93
|
void wsp_ggml_gemm_q4_0_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
78
94
|
void wsp_ggml_gemm_q4_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
95
|
+
void wsp_ggml_gemm_q2_K_8x8_q8_K(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
79
96
|
void wsp_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
97
|
+
void wsp_ggml_gemm_iq4_nl_8x8_q8_0(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
80
98
|
|
|
81
99
|
// Native implementations
|
|
82
100
|
void wsp_ggml_wsp_quantize_mat_q8_0_4x4_generic(const float * WSP_GGML_RESTRICT x, void * WSP_GGML_RESTRICT vy, int64_t k);
|
|
@@ -86,12 +104,16 @@ void wsp_ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, siz
|
|
|
86
104
|
void wsp_ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
87
105
|
void wsp_ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
88
106
|
void wsp_ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
107
|
+
void wsp_ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
89
108
|
void wsp_ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
109
|
+
void wsp_ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
90
110
|
void wsp_ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
91
111
|
void wsp_ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
92
112
|
void wsp_ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
93
113
|
void wsp_ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
114
|
+
void wsp_ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
94
115
|
void wsp_ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
116
|
+
void wsp_ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * WSP_GGML_RESTRICT s, size_t bs, const void * WSP_GGML_RESTRICT vx, const void * WSP_GGML_RESTRICT vy, int nr, int nc);
|
|
95
117
|
|
|
96
118
|
#if defined(__cplusplus)
|
|
97
119
|
} // extern "C"
|
|
package/cpp/ggml-cpu/simd-mappings.h
CHANGED
|
@@ -189,7 +189,7 @@ inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) {
|
|
|
189
189
|
#define WSP_GGML_F32xt_LOAD(...) WSP_GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
|
190
190
|
#define WSP_GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b)
|
|
191
191
|
#define WSP_GGML_F32xt_STORE(...) WSP_GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__)
|
|
192
|
-
#define WSP_GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg,
|
|
192
|
+
#define WSP_GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a)
|
|
193
193
|
#define WSP_GGML_F32xt_FMA(...) WSP_GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__)
|
|
194
194
|
#define WSP_GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b)
|
|
195
195
|
#define WSP_GGML_F32xt_ADD(...) WSP_GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__)
|
package/cpp/ggml-cpu/traits.cpp
CHANGED
|
@@ -10,7 +10,7 @@ extra_buffer_type::~extra_buffer_type() {}
|
|
|
10
10
|
} // namespace ggml::cpu
|
|
11
11
|
|
|
12
12
|
bool wsp_ggml_cpu_extra_compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * op) {
|
|
13
|
-
for (auto extra :
|
|
13
|
+
for (auto extra : wsp_ggml_backend_cpu_get_extra_buffer_types()) {
|
|
14
14
|
if (extra && extra->context) {
|
|
15
15
|
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
|
|
16
16
|
auto tensor_traits = buf_extra->get_tensor_traits(op);
|
|
@@ -23,7 +23,7 @@ bool wsp_ggml_cpu_extra_compute_forward(struct wsp_ggml_compute_params * params,
|
|
|
23
23
|
}
|
|
24
24
|
|
|
25
25
|
bool wsp_ggml_cpu_extra_work_size(int n_threads, const struct wsp_ggml_tensor * op, size_t * size) {
|
|
26
|
-
for (auto extra :
|
|
26
|
+
for (auto extra : wsp_ggml_backend_cpu_get_extra_buffer_types()) {
|
|
27
27
|
if (extra && extra->context) {
|
|
28
28
|
auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
|
|
29
29
|
auto tensor_traits = buf_extra->get_tensor_traits(op);
|
package/cpp/ggml-cpu/traits.h
CHANGED
|
@@ -33,6 +33,6 @@ class extra_buffer_type {
|
|
|
33
33
|
} // namespace ggml::cpu
|
|
34
34
|
|
|
35
35
|
// implemented in ggml-cpu.cpp.
|
|
36
|
-
std::vector<wsp_ggml_backend_buffer_type_t> &
|
|
36
|
+
std::vector<wsp_ggml_backend_buffer_type_t> & wsp_ggml_backend_cpu_get_extra_buffer_types();
|
|
37
37
|
|
|
38
38
|
#endif
|
package/cpp/ggml-cpu/vec.cpp
CHANGED
|
@@ -37,35 +37,35 @@ void wsp_ggml_vec_dot_f32(int n, float * WSP_GGML_RESTRICT s, size_t bs, const f
|
|
|
37
37
|
for (int i = 0; i < np; i += wsp_ggml_f32_step) {
|
|
38
38
|
ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
|
|
39
39
|
ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
|
|
40
|
-
sum1 = WSP_GGML_F32_VEC_FMA(ax1, ay1
|
|
40
|
+
sum1 = WSP_GGML_F32_VEC_FMA(sum1, ax1, ay1);
|
|
41
41
|
|
|
42
42
|
ax2 = WSP_GGML_F32_VEC_LOAD(x + i + 1*wsp_ggml_f32_epr);
|
|
43
43
|
ay2 = WSP_GGML_F32_VEC_LOAD(y + i + 1*wsp_ggml_f32_epr);
|
|
44
|
-
sum2 = WSP_GGML_F32_VEC_FMA(ax2, ay2
|
|
44
|
+
sum2 = WSP_GGML_F32_VEC_FMA(sum2, ax2, ay2);
|
|
45
45
|
|
|
46
46
|
ax3 = WSP_GGML_F32_VEC_LOAD(x + i + 2*wsp_ggml_f32_epr);
|
|
47
47
|
ay3 = WSP_GGML_F32_VEC_LOAD(y + i + 2*wsp_ggml_f32_epr);
|
|
48
|
-
sum3 = WSP_GGML_F32_VEC_FMA(ax3, ay3
|
|
48
|
+
sum3 = WSP_GGML_F32_VEC_FMA(sum3, ax3, ay3);
|
|
49
49
|
|
|
50
50
|
ax4 = WSP_GGML_F32_VEC_LOAD(x + i + 3*wsp_ggml_f32_epr);
|
|
51
51
|
ay4 = WSP_GGML_F32_VEC_LOAD(y + i + 3*wsp_ggml_f32_epr);
|
|
52
|
-
sum4 = WSP_GGML_F32_VEC_FMA(ax4, ay4
|
|
52
|
+
sum4 = WSP_GGML_F32_VEC_FMA(sum4, ax4, ay4);
|
|
53
53
|
|
|
54
54
|
ax5 = WSP_GGML_F32_VEC_LOAD(x + i + 4*wsp_ggml_f32_epr);
|
|
55
55
|
ay5 = WSP_GGML_F32_VEC_LOAD(y + i + 4*wsp_ggml_f32_epr);
|
|
56
|
-
sum5 = WSP_GGML_F32_VEC_FMA(ax5, ay5
|
|
56
|
+
sum5 = WSP_GGML_F32_VEC_FMA(sum5, ax5, ay5);
|
|
57
57
|
|
|
58
58
|
ax6 = WSP_GGML_F32_VEC_LOAD(x + i + 5*wsp_ggml_f32_epr);
|
|
59
59
|
ay6 = WSP_GGML_F32_VEC_LOAD(y + i + 5*wsp_ggml_f32_epr);
|
|
60
|
-
sum6 = WSP_GGML_F32_VEC_FMA(ax6, ay6
|
|
60
|
+
sum6 = WSP_GGML_F32_VEC_FMA(sum6, ax6, ay6);
|
|
61
61
|
|
|
62
62
|
ax7 = WSP_GGML_F32_VEC_LOAD(x + i + 6*wsp_ggml_f32_epr);
|
|
63
63
|
ay7 = WSP_GGML_F32_VEC_LOAD(y + i + 6*wsp_ggml_f32_epr);
|
|
64
|
-
sum7 = WSP_GGML_F32_VEC_FMA(ax7, ay7
|
|
64
|
+
sum7 = WSP_GGML_F32_VEC_FMA(sum7, ax7, ay7);
|
|
65
65
|
|
|
66
66
|
ax8 = WSP_GGML_F32_VEC_LOAD(x + i + 7*wsp_ggml_f32_epr);
|
|
67
67
|
ay8 = WSP_GGML_F32_VEC_LOAD(y + i + 7*wsp_ggml_f32_epr);
|
|
68
|
-
sum8 = WSP_GGML_F32_VEC_FMA(ax8, ay8
|
|
68
|
+
sum8 = WSP_GGML_F32_VEC_FMA(sum8, ax8, ay8);
|
|
69
69
|
}
|
|
70
70
|
// leftovers
|
|
71
71
|
// Since 8 unrolls are done in above loop, leftovers lie in range [0, wsp_ggml_f32_step] which is handled in below loop
|
|
@@ -73,7 +73,7 @@ void wsp_ggml_vec_dot_f32(int n, float * WSP_GGML_RESTRICT s, size_t bs, const f
|
|
|
73
73
|
for (int i = np; i < np2; i += wsp_ggml_f32_epr) {
|
|
74
74
|
ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
|
|
75
75
|
ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
|
|
76
|
-
sum1 = WSP_GGML_F32_VEC_FMA(ax1, ay1
|
|
76
|
+
sum1 = WSP_GGML_F32_VEC_FMA(sum1, ax1, ay1);
|
|
77
77
|
}
|
|
78
78
|
// maximum number of leftover elements will be less that wsp_ggml_f32_epr. Apply predicated svmad on available elements only
|
|
79
79
|
if (np2 < n) {
|
|
@@ -221,6 +221,9 @@ void wsp_ggml_vec_dot_f16(int n, float * WSP_GGML_RESTRICT s, size_t bs, wsp_ggm
|
|
|
221
221
|
for (int i = np; i < n; ++i) {
|
|
222
222
|
sumf += (wsp_ggml_float)(WSP_GGML_CPU_FP16_TO_FP32(x[i])*WSP_GGML_CPU_FP16_TO_FP32(y[i]));
|
|
223
223
|
}
|
|
224
|
+
|
|
225
|
+
// if you hit this, you are likely running outside the FP range
|
|
226
|
+
assert(!isnan(sumf) && !isinf(sumf));
|
|
224
227
|
#else
|
|
225
228
|
for (int i = 0; i < n; ++i) {
|
|
226
229
|
sumf += (wsp_ggml_float)(WSP_GGML_CPU_FP16_TO_FP32(x[i])*WSP_GGML_CPU_FP16_TO_FP32(y[i]));
|
package/cpp/ggml-cpu/vec.h
CHANGED
|
@@ -55,7 +55,22 @@ inline static void wsp_ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t
|
|
|
55
55
|
|
|
56
56
|
inline static void wsp_ggml_vec_set_f16(const int n, wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
|
57
57
|
inline static void wsp_ggml_vec_set_bf16(const int n, wsp_ggml_bf16_t * x, const wsp_ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
|
|
58
|
-
|
|
58
|
+
|
|
59
|
+
// Element-wise sum: z[k] = x[k] + y[k] for k in [0, n).
// When AVX2 is available, full groups of 8 floats are handled with unaligned
// 256-bit loads/stores; the scalar loop finishes the remainder (or does all
// of the work when AVX2 is not enabled).
inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
    int k = 0;
#if defined(__AVX2__)
    for (; k + 8 <= n; k += 8) {
        _mm256_storeu_ps(z + k, _mm256_add_ps(_mm256_loadu_ps(x + k), _mm256_loadu_ps(y + k)));
    }
#endif
    while (k < n) {
        z[k] = x[k] + y[k];
        ++k;
    }
}
|
|
73
|
+
|
|
59
74
|
inline static void wsp_ggml_vec_add_f16 (const int n, wsp_ggml_fp16_t * z, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * y) {
|
|
60
75
|
for (int i = 0; i < n; ++i) {
|
|
61
76
|
z[i] = WSP_GGML_CPU_FP32_TO_FP16(WSP_GGML_CPU_FP16_TO_FP32(x[i]) + WSP_GGML_CPU_FP16_TO_FP32(y[i]));
|
|
@@ -163,49 +178,49 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
|
|
|
163
178
|
|
|
164
179
|
ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
|
|
165
180
|
ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
|
|
166
|
-
ay1 = WSP_GGML_F32_VEC_FMA(ax1, vx
|
|
181
|
+
ay1 = WSP_GGML_F32_VEC_FMA(ay1, ax1, vx);
|
|
167
182
|
|
|
168
183
|
WSP_GGML_F32_VEC_STORE(y + i, ay1);
|
|
169
184
|
|
|
170
185
|
ax2 = WSP_GGML_F32_VEC_LOAD(x + i + 1*wsp_ggml_f32_epr);
|
|
171
186
|
ay2 = WSP_GGML_F32_VEC_LOAD(y + i + 1*wsp_ggml_f32_epr);
|
|
172
|
-
ay2 = WSP_GGML_F32_VEC_FMA(ax2, vx
|
|
187
|
+
ay2 = WSP_GGML_F32_VEC_FMA(ay2, ax2, vx);
|
|
173
188
|
|
|
174
189
|
WSP_GGML_F32_VEC_STORE(y + i + 1*wsp_ggml_f32_epr, ay2);
|
|
175
190
|
|
|
176
191
|
ax3 = WSP_GGML_F32_VEC_LOAD(x + i + 2*wsp_ggml_f32_epr);
|
|
177
192
|
ay3 = WSP_GGML_F32_VEC_LOAD(y + i + 2*wsp_ggml_f32_epr);
|
|
178
|
-
ay3 = WSP_GGML_F32_VEC_FMA(ax3, vx
|
|
193
|
+
ay3 = WSP_GGML_F32_VEC_FMA(ay3, ax3, vx);
|
|
179
194
|
|
|
180
195
|
WSP_GGML_F32_VEC_STORE(y + i + 2*wsp_ggml_f32_epr, ay3);
|
|
181
196
|
|
|
182
197
|
ax4 = WSP_GGML_F32_VEC_LOAD(x + i + 3*wsp_ggml_f32_epr);
|
|
183
198
|
ay4 = WSP_GGML_F32_VEC_LOAD(y + i + 3*wsp_ggml_f32_epr);
|
|
184
|
-
ay4 = WSP_GGML_F32_VEC_FMA(ax4, vx
|
|
199
|
+
ay4 = WSP_GGML_F32_VEC_FMA(ay4, ax4, vx);
|
|
185
200
|
|
|
186
201
|
WSP_GGML_F32_VEC_STORE(y + i + 3*wsp_ggml_f32_epr, ay4);
|
|
187
202
|
|
|
188
203
|
ax5 = WSP_GGML_F32_VEC_LOAD(x + i + 4*wsp_ggml_f32_epr);
|
|
189
204
|
ay5 = WSP_GGML_F32_VEC_LOAD(y + i + 4*wsp_ggml_f32_epr);
|
|
190
|
-
ay5 = WSP_GGML_F32_VEC_FMA(ax5, vx
|
|
205
|
+
ay5 = WSP_GGML_F32_VEC_FMA(ay5, ax5, vx);
|
|
191
206
|
|
|
192
207
|
WSP_GGML_F32_VEC_STORE(y + i + 4*wsp_ggml_f32_epr, ay5);
|
|
193
208
|
|
|
194
209
|
ax6 = WSP_GGML_F32_VEC_LOAD(x + i + 5*wsp_ggml_f32_epr);
|
|
195
210
|
ay6 = WSP_GGML_F32_VEC_LOAD(y + i + 5*wsp_ggml_f32_epr);
|
|
196
|
-
ay6 = WSP_GGML_F32_VEC_FMA(ax6, vx
|
|
211
|
+
ay6 = WSP_GGML_F32_VEC_FMA(ay6, ax6, vx);
|
|
197
212
|
|
|
198
213
|
WSP_GGML_F32_VEC_STORE(y + i + 5*wsp_ggml_f32_epr, ay6);
|
|
199
214
|
|
|
200
215
|
ax7 = WSP_GGML_F32_VEC_LOAD(x + i + 6*wsp_ggml_f32_epr);
|
|
201
216
|
ay7 = WSP_GGML_F32_VEC_LOAD(y + i + 6*wsp_ggml_f32_epr);
|
|
202
|
-
ay7 = WSP_GGML_F32_VEC_FMA(ax7, vx
|
|
217
|
+
ay7 = WSP_GGML_F32_VEC_FMA(ay7, ax7, vx);
|
|
203
218
|
|
|
204
219
|
WSP_GGML_F32_VEC_STORE(y + i + 6*wsp_ggml_f32_epr, ay7);
|
|
205
220
|
|
|
206
221
|
ax8 = WSP_GGML_F32_VEC_LOAD(x + i + 7*wsp_ggml_f32_epr);
|
|
207
222
|
ay8 = WSP_GGML_F32_VEC_LOAD(y + i + 7*wsp_ggml_f32_epr);
|
|
208
|
-
ay8 = WSP_GGML_F32_VEC_FMA(ax8, vx
|
|
223
|
+
ay8 = WSP_GGML_F32_VEC_FMA(ay8, ax8, vx);
|
|
209
224
|
|
|
210
225
|
WSP_GGML_F32_VEC_STORE(y + i + 7*wsp_ggml_f32_epr, ay8);
|
|
211
226
|
}
|
|
@@ -215,7 +230,7 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * WSP_GGML_RESTRICT y
|
|
|
215
230
|
for (int i = np; i < np2; i += wsp_ggml_f32_epr) {
|
|
216
231
|
ax1 = WSP_GGML_F32_VEC_LOAD(x + i);
|
|
217
232
|
ay1 = WSP_GGML_F32_VEC_LOAD(y + i);
|
|
218
|
-
ay1 = WSP_GGML_F32_VEC_FMA(ax1, vx
|
|
233
|
+
ay1 = WSP_GGML_F32_VEC_FMA(ay1, ax1, vx);
|
|
219
234
|
|
|
220
235
|
WSP_GGML_F32_VEC_STORE(y + i, ay1);
|
|
221
236
|
}
|
|
@@ -351,6 +366,45 @@ inline static void wsp_ggml_vec_mad_f32_unroll(const int n, const int xs, const
|
|
|
351
366
|
#endif
|
|
352
367
|
}
|
|
353
368
|
|
|
369
|
+
// Scale-and-shift: y[i] = x[i]*s + b for i in [0, n).
//
// SIMD path note: WSP_GGML_F32_VEC_FMA(a, b, c) evaluates a + b*c with the
// accumulator first. This is how the other kernels use it, e.g.
// WSP_GGML_F32_VEC_FMA(sum1, ax1, ay1) in the f32 dot product and
// WSP_GGML_F32_VEC_FMA(ay1, ax1, vx) in wsp_ggml_vec_mad_f32. The bias must
// therefore be the first operand. The former FMA(ay, vs, vb) computed
// x + s*b, disagreeing with the Accelerate and scalar paths.
inline static void wsp_ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(WSP_GGML_USE_ACCELERATE)
    // vDSP_vsmsa: vector * scalar + scalar, i.e. y = x*s + b
    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(WSP_GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        // scalar ; TODO: Write SVE code
        for (int i = 0; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #else
        // largest multiple of WSP_GGML_F32_STEP that fits in n
        const int np = (n & ~(WSP_GGML_F32_STEP - 1));

        WSP_GGML_F32_VEC vs = WSP_GGML_F32_VEC_SET1(s);
        WSP_GGML_F32_VEC vb = WSP_GGML_F32_VEC_SET1(b);

        WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR];

        for (int i = 0; i < np; i += WSP_GGML_F32_STEP) {
            for (int j = 0; j < WSP_GGML_F32_ARR; j++) {
                ay[j] = WSP_GGML_F32_VEC_LOAD(x + i + j*WSP_GGML_F32_EPR);
                ay[j] = WSP_GGML_F32_VEC_FMA(vb, ay[j], vs); // b + x*s

                WSP_GGML_F32_VEC_STORE(y + i + j*WSP_GGML_F32_EPR, ay[j]);
            }
        }

        // leftovers
        for (int i = np; i < n; ++i) {
            y[i] = x[i]*s + b;
        }
    #endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
        y[i] = x[i]*s + b;
    }
#endif
}
|
|
407
|
+
|
|
354
408
|
//inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
|
|
355
409
|
inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) {
|
|
356
410
|
#if defined(WSP_GGML_USE_ACCELERATE)
|
|
@@ -953,9 +1007,49 @@ void wsp_ggml_vec_swiglu_f32(const int n, float * y, const float * x, const floa
|
|
|
953
1007
|
|
|
954
1008
|
inline static void wsp_ggml_vec_swiglu_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
    // SwiGLU evaluated in fp32 and rounded back to fp16:
    // y = silu(x) * g, with silu(v) = v / (1 + e^-v)
    int k = 0;
    while (k < n) {
        const float xv = WSP_GGML_CPU_FP16_TO_FP32(x[k]);
        const float gv = WSP_GGML_CPU_FP16_TO_FP32(g[k]);
        y[k] = WSP_GGML_CPU_FP32_TO_FP16((xv/(1.0f + expf(-xv))) * gv);
        ++k;
    }
}
|
|
1015
|
+
|
|
1016
|
+
inline static void wsp_ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
|
|
1017
|
+
for (int i = 0; i < n; ++i) {
|
|
1018
|
+
float xi = x[i];
|
|
1019
|
+
y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
|
|
1020
|
+
}
|
|
1021
|
+
}
|
|
1022
|
+
|
|
1023
|
+
// fp16 variant of the erf-based GEGLU: inputs are widened to fp32, the gate
// 0.5 * x * (1 + erf(x * SQRT_2_INV)) is multiplied by g, and the product is
// rounded back to fp16.
inline static void wsp_ggml_vec_geglu_erf_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
    int k = 0;
    while (k < n) {
        const float xv = WSP_GGML_CPU_FP16_TO_FP32(x[k]);
        const float gv = WSP_GGML_CPU_FP16_TO_FP32(g[k]);
        y[k] = WSP_GGML_CPU_FP32_TO_FP16(0.5f * xv * (1.0f + erff(xv*SQRT_2_INV)) * gv);
        ++k;
    }
}
|
|
1030
|
+
|
|
1031
|
+
#ifdef WSP_GGML_GELU_QUICK_FP16
// Lookup variant: each x is rounded to fp16 and its raw 16-bit pattern is
// used as the index into wsp_ggml_table_gelu_quick_f16.
inline static void wsp_ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    for (int k = 0; k < n; ++k) {
        const wsp_ggml_fp16_t xh = WSP_GGML_CPU_FP32_TO_FP16(x[k]);
        uint16_t idx;
        memcpy(&idx, &xh, sizeof(uint16_t)); // bit-cast fp16 -> table index
        y[k] = WSP_GGML_CPU_FP16_TO_FP32(wsp_ggml_table_gelu_quick_f16[idx]) * g[k];
    }
}
#else
// Direct variant: evaluates wsp_ggml_gelu_quick_f32 for every element.
inline static void wsp_ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
    for (int k = 0; k < n; ++k) {
        y[k] = wsp_ggml_gelu_quick_f32(x[k]) * g[k];
    }
}
#endif
|
|
1047
|
+
|
|
1048
|
+
// fp16 quick-GELU GEGLU: the raw 16-bit pattern of each x[k] indexes
// wsp_ggml_table_gelu_quick_f16 directly; the looked-up gate is multiplied by
// g[k] in fp32 and rounded back to fp16.
inline static void wsp_ggml_vec_geglu_quick_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x, const wsp_ggml_fp16_t * g) {
    const uint16_t * xbits = (const uint16_t *) x;
    int k = 0;
    while (k < n) {
        const float gv   = WSP_GGML_CPU_FP16_TO_FP32(g[k]);
        const float gate = WSP_GGML_CPU_FP16_TO_FP32(wsp_ggml_table_gelu_quick_f16[xbits[k]]);
        y[k] = WSP_GGML_CPU_FP32_TO_FP16(gate * gv);
        ++k;
    }
}
|
|
961
1055
|
|
package/cpp/ggml-impl.h
CHANGED
|
@@ -73,6 +73,22 @@ static inline int wsp_ggml_up(int n, int m) {
|
|
|
73
73
|
return (n + m - 1) & ~(m - 1);
|
|
74
74
|
}
|
|
75
75
|
|
|
76
|
+
// TODO: move to ggml.h?
|
|
77
|
+
static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) {
|
|
78
|
+
if (a->type != b->type) {
|
|
79
|
+
return false;
|
|
80
|
+
}
|
|
81
|
+
for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) {
|
|
82
|
+
if (a->ne[i] != b->ne[i]) {
|
|
83
|
+
return false;
|
|
84
|
+
}
|
|
85
|
+
if (a->nb[i] != b->nb[i]) {
|
|
86
|
+
return false;
|
|
87
|
+
}
|
|
88
|
+
}
|
|
89
|
+
return true;
|
|
90
|
+
}
|
|
91
|
+
|
|
76
92
|
//
|
|
77
93
|
// logging
|
|
78
94
|
//
|
|
@@ -394,6 +410,67 @@ static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) {
|
|
|
394
410
|
#define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x)
|
|
395
411
|
#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x)
|
|
396
412
|
|
|
413
|
+
// Decodes an E8M0 scale byte into the float 2^(x - 127).
//
// x == 0 has no normal encoding, so it is returned as the subnormal 2^-127
// (bit pattern 0x00400000, i.e. mantissa 0.5 with exponent -126). Any other
// x is written straight into the IEEE-754 exponent field with a zero
// mantissa. x == 0xFF is intentionally not mapped to NaN; it produces
// 0x7F800000 (+inf).
static inline float wsp_ggml_e8m0_to_fp32(uint8_t x) {
    const uint32_t bits = (x == 0) ? 0x00400000u               // 2^-127 (subnormal)
                                   : ((uint32_t) x << 23);     // exponent field only

    float out;
    memcpy(&out, &bits, sizeof(float)); // bit-cast without strict-aliasing UB
    return out;
}
|
|
448
|
+
|
|
449
|
+
// Equal to wsp_ggml_e8m0_to_fp32/2, i.e. 2^(x - 128), built directly from bits.
// Useful with MXFP4 quantization since the FP4 (E2M1) element values are doubled.
static inline float wsp_ggml_e8m0_to_fp32_half(uint8_t x) {
    uint32_t bits;

    if (x < 2) {
        // Result is subnormal: 0x00200000 = 2^-128 (x == 0), 0x00400000 = 2^-127 (x == 1)
        bits = UINT32_C(0x00200000) << x;
    } else {
        // 0.5 * 2^(x-127) = 2^(x-128): a normal float with exponent field (x - 1)
        bits = (uint32_t)(x - 1) << 23;
    }
    // Note: NaNs are not handled here; x == 0xFF yields 2^127.

    float result;
    memcpy(&result, &bits, sizeof(float)); // bit-cast without strict-aliasing UB
    return result;
}
|
|
470
|
+
|
|
471
|
+
// Macro spellings of the E8M0 scale decoders above.
#define WSP_GGML_E8M0_TO_FP32(x)      wsp_ggml_e8m0_to_fp32(x)
#define WSP_GGML_E8M0_TO_FP32_HALF(x) wsp_ggml_e8m0_to_fp32_half(x)
|
|
473
|
+
|
|
397
474
|
/**
|
|
398
475
|
* Converts brain16 to float32.
|
|
399
476
|
*
|
package/cpp/ggml-metal-impl.h
CHANGED
|
@@ -23,6 +23,9 @@
|
|
|
23
23
|
#define N_R0_Q8_0 4
|
|
24
24
|
#define N_SG_Q8_0 2
|
|
25
25
|
|
|
26
|
+
// MXFP4 kernel tuning, following the same N_R0_* / N_SG_* scheme as the
// neighbouring Q8_0 and Q2_K entries.
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
|
|
28
|
+
|
|
26
29
|
#define N_R0_Q2_K 4
|
|
27
30
|
#define N_SG_Q2_K 2
|
|
28
31
|
|
|
@@ -126,8 +129,18 @@ typedef struct {
|
|
|
126
129
|
uint64_t nb2;
|
|
127
130
|
uint64_t nb3;
|
|
128
131
|
uint64_t offs;
|
|
132
|
+
uint64_t o1[8];
|
|
129
133
|
} wsp_ggml_metal_kargs_bin;
|
|
130
134
|
|
|
135
|
+
// Kernel arguments for the add_id op.
// Byte strides use uint64_t like every other kargs struct in this header,
// so the field size does not depend on the host's size_t. On 64-bit Apple
// targets this is layout-identical to the previous size_t version.
typedef struct {
    int64_t  ne0;
    int64_t  ne1;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb11;
    uint64_t nb21;
} wsp_ggml_metal_kargs_add_id;
|
|
143
|
+
|
|
131
144
|
typedef struct {
|
|
132
145
|
int32_t ne00;
|
|
133
146
|
int32_t ne01;
|
|
@@ -229,14 +242,18 @@ typedef struct {
|
|
|
229
242
|
uint64_t nb21;
|
|
230
243
|
uint64_t nb22;
|
|
231
244
|
uint64_t nb23;
|
|
245
|
+
int32_t ne32;
|
|
246
|
+
int32_t ne33;
|
|
232
247
|
uint64_t nb31;
|
|
248
|
+
uint64_t nb32;
|
|
249
|
+
uint64_t nb33;
|
|
233
250
|
int32_t ne1;
|
|
234
251
|
int32_t ne2;
|
|
235
252
|
float scale;
|
|
236
253
|
float max_bias;
|
|
237
254
|
float m0;
|
|
238
255
|
float m1;
|
|
239
|
-
|
|
256
|
+
int32_t n_head_log2;
|
|
240
257
|
float logit_softcap;
|
|
241
258
|
} wsp_ggml_metal_kargs_flash_attn_ext;
|
|
242
259
|
|
|
@@ -373,8 +390,16 @@ typedef struct {
|
|
|
373
390
|
// Kernel arguments for RMS norm. Field order and types define the layout
// passed to the kernel — do not reorder.
typedef struct {
    // source row length (ne00) and its companion ne00_4 (presumably ne00/4)
    int32_t  ne00;
    int32_t  ne00_4;

    // byte strides (ggml nb* convention)
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;

    float    eps;

    // three-entry extent / stride tables for dims 1..3
    // (presumably one entry per fused operand — TODO confirm against the kernel)
    int32_t  nef1[3];
    int32_t  nef2[3];
    int32_t  nef3[3];
    uint64_t nbf1[3];
    uint64_t nbf2[3];
    uint64_t nbf3[3];
} wsp_ggml_metal_kargs_rms_norm;
|
|
379
404
|
|
|
380
405
|
typedef struct {
|
|
@@ -431,6 +456,8 @@ typedef struct{
|
|
|
431
456
|
uint64_t nb1;
|
|
432
457
|
int32_t i00;
|
|
433
458
|
int32_t i10;
|
|
459
|
+
float alpha;
|
|
460
|
+
float limit;
|
|
434
461
|
} wsp_ggml_metal_kargs_glu;
|
|
435
462
|
|
|
436
463
|
typedef struct {
|
|
@@ -461,14 +488,26 @@ typedef struct {
|
|
|
461
488
|
} wsp_ggml_metal_kargs_sum_rows;
|
|
462
489
|
|
|
463
490
|
// Kernel arguments for soft_max. Field order and types define the layout
// passed to the kernel — do not reorder.
typedef struct {
    // src0 extents and byte strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;

    // src1 extents and byte strides
    int32_t  ne11;
    int32_t  ne12;
    int32_t  ne13;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;

    // dst byte strides
    uint64_t nb1;
    uint64_t nb2;
    uint64_t nb3;

    // scaling / bias parameters
    float    scale;
    float    max_bias;
    float    m0;
    float    m1;
    int32_t  n_head_log2;
} wsp_ggml_metal_kargs_soft_max;
|
|
473
512
|
|
|
474
513
|
typedef struct {
|
|
@@ -499,26 +538,26 @@ typedef struct {
|
|
|
499
538
|
// Kernel arguments for the SSM scan. Field order and types define the layout
// passed to the kernel — do not reorder.
typedef struct {
    // problem sizes
    int64_t  d_state;
    int64_t  d_inner;
    int64_t  n_head;
    int64_t  n_group;
    int64_t  n_seq_tokens;
    int64_t  n_seqs;
    int64_t  s_off;   // offset value used by the kernel (semantics — TODO confirm)

    // per-source byte strides, grouped by source index (nb<src><dim>)
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    uint64_t nb21;
    uint64_t nb22;
    uint64_t nb31;
    uint64_t nb41;
    uint64_t nb42;
    uint64_t nb43;
    uint64_t nb51;
    uint64_t nb52;
    uint64_t nb53;
} wsp_ggml_metal_kargs_ssm_scan;
|
|
523
562
|
|
|
524
563
|
typedef struct {
|