whisper.rn 0.4.0-rc.3 → 0.4.0-rc.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +2 -0
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +3 -3
- package/android/src/main/jni.cpp +6 -2
- package/cpp/ggml-alloc.c +413 -280
- package/cpp/ggml-alloc.h +67 -8
- package/cpp/ggml-backend-impl.h +87 -0
- package/cpp/ggml-backend.c +950 -0
- package/cpp/ggml-backend.h +136 -0
- package/cpp/ggml-impl.h +243 -0
- package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +591 -121
- package/cpp/ggml-metal.h +21 -0
- package/cpp/ggml-metal.m +623 -234
- package/cpp/ggml-quants.c +7377 -0
- package/cpp/ggml-quants.h +224 -0
- package/cpp/ggml.c +3773 -4455
- package/cpp/ggml.h +279 -146
- package/cpp/whisper.cpp +182 -103
- package/cpp/whisper.h +48 -11
- package/ios/RNWhisper.mm +8 -2
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
- package/ios/RNWhisperContext.h +5 -1
- package/ios/RNWhisperContext.mm +76 -10
- package/jest/mock.js +1 -1
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +28 -9
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +28 -9
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +7 -1
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +7 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +1 -1
- package/src/NativeRNWhisper.ts +8 -1
- package/src/index.ts +29 -17
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -2
package/cpp/ggml-metal.m
CHANGED
@@ -1,5 +1,6 @@
 #import "ggml-metal.h"
 
+#import "ggml-backend-impl.h"
 #import "ggml.h"
 
 #import <Foundation/Foundation.h>
@@ -11,16 +12,19 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-// TODO: temporary - reuse llama.cpp logging
 #ifdef WSP_GGML_METAL_NDEBUG
-#define
+#define WSP_GGML_METAL_LOG_INFO(...)
+#define WSP_GGML_METAL_LOG_WARN(...)
+#define WSP_GGML_METAL_LOG_ERROR(...)
 #else
-#define
+#define WSP_GGML_METAL_LOG_INFO(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define WSP_GGML_METAL_LOG_WARN(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define WSP_GGML_METAL_LOG_ERROR(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
 #endif
 
 #define UNUSED(x) (void)(x)
 
-#define WSP_GGML_MAX_CONCUR (2*
+#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_DEFAULT_GRAPH_SIZE)
 
 struct wsp_ggml_metal_buffer {
 const char * name;
@@ -59,6 +63,7 @@ struct wsp_ggml_metal_context {
 WSP_GGML_METAL_DECL_KERNEL(mul);
 WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
 WSP_GGML_METAL_DECL_KERNEL(scale);
+WSP_GGML_METAL_DECL_KERNEL(scale_4);
 WSP_GGML_METAL_DECL_KERNEL(silu);
 WSP_GGML_METAL_DECL_KERNEL(relu);
 WSP_GGML_METAL_DECL_KERNEL(gelu);
@@ -70,6 +75,8 @@ struct wsp_ggml_metal_context {
 WSP_GGML_METAL_DECL_KERNEL(get_rows_f16);
 WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0);
 WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_1);
 WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0);
 WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K);
 WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -78,33 +85,40 @@ struct wsp_ggml_metal_context {
 WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
 WSP_GGML_METAL_DECL_KERNEL(rms_norm);
 WSP_GGML_METAL_DECL_KERNEL(norm);
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
-WSP_GGML_METAL_DECL_KERNEL(
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
 WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
-WSP_GGML_METAL_DECL_KERNEL(
+WSP_GGML_METAL_DECL_KERNEL(rope_f32);
+WSP_GGML_METAL_DECL_KERNEL(rope_f16);
 WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
 WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
 WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
 WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+WSP_GGML_METAL_DECL_KERNEL(concat);
+WSP_GGML_METAL_DECL_KERNEL(sqr);
 
 #undef WSP_GGML_METAL_DECL_KERNEL
 };
@@ -115,13 +129,42 @@ struct wsp_ggml_metal_context {
 static NSString * const msl_library_source = @"see metal.metal";
 
 // Here to assist with NSBundle Path Hack
-@interface
+@interface WSPGGMLMetalClass : NSObject
 @end
-@implementation
+@implementation WSPGGMLMetalClass
 @end
 
+wsp_ggml_log_callback wsp_ggml_metal_log_callback = NULL;
+void * wsp_ggml_metal_log_user_data = NULL;
+
+void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) {
+wsp_ggml_metal_log_callback = log_callback;
+wsp_ggml_metal_log_user_data = user_data;
+}
+
+static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format, ...){
+if (wsp_ggml_metal_log_callback != NULL) {
+va_list args;
+va_start(args, format);
+char buffer[128];
+int len = vsnprintf(buffer, 128, format, args);
+if (len < 128) {
+wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data);
+} else {
+char* buffer2 = malloc(len+1);
+vsnprintf(buffer2, len+1, format, args);
+buffer2[len] = 0;
+wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data);
+free(buffer2);
+}
+va_end(args);
+}
+}
+
+
+
 struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
-
+WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
 id <MTLDevice> device;
 NSString * s;
@@ -131,14 +174,14 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 NSArray * devices = MTLCopyAllDevices();
 for (device in devices) {
 s = [device name];
-
+WSP_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
 }
 #endif
 
 // Pick and show default Metal device
 device = MTLCreateSystemDefaultDevice();
 s = [device name];
-
+WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
 
 // Configure context
 struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
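The logging hook introduced above (wsp_ggml_metal_log_set_callback / wsp_ggml_metal_log) replaces the old metal_printf path. A minimal consumer sketch, not part of the package; the header names and the my_metal_log function are assumptions for illustration, while the callback shape (level, text, user_data) matches how wsp_ggml_metal_log invokes it in the hunk above:

#include <stdio.h>

// Assumed header locations; the typedefs/functions used below
// (wsp_ggml_log_callback, enum wsp_ggml_log_level, wsp_ggml_metal_log_set_callback)
// all appear in this diff.
#include "ggml.h"
#include "ggml-metal.h"

// Forward everything the Metal backend logs to stderr.
static void my_metal_log(enum wsp_ggml_log_level level, const char * text, void * user_data) {
    (void) level;
    (void) user_data;
    fputs(text, stderr);
}

// Usage (hypothetical): register before creating the Metal context.
//   wsp_ggml_metal_log_set_callback(my_metal_log, NULL);
//   struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(1);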
@@ -150,68 +193,69 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 
 ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
-
-// load the default.metallib file
+// load library
 {
-
-
-
-NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
-NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
-NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
-NSURL * libURL = [NSURL fileURLWithPath:libPath];
-
-// Load the metallib file into a Metal library
-ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
-
-if (error) {
-metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
-return NULL;
-}
-}
+NSBundle * bundle = nil;
+#ifdef SWIFT_PACKAGE
+bundle = SWIFTPM_MODULE_BUNDLE;
 #else
-
-
-// read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
-{
+bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]];
+#endif
 NSError * error = nil;
+NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
+if (libPath != nil) {
+NSURL * libURL = [NSURL fileURLWithPath:libPath];
+WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
+ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+} else {
+WSP_GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+NSString * sourcePath;
+NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"];
+if (ggmlMetalPathResources) {
+sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
+} else {
+sourcePath = [bundle pathForResource:@"ggml-metal-whisper" ofType:@"metal"];
+}
+if (sourcePath == nil) {
+WSP_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
+sourcePath = @"ggml-metal.metal";
+}
+WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
+NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
+if (error) {
+WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+return NULL;
+}
 
-
-NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
-
-NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
-if (error) {
-metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
-return NULL;
-}
-
+MTLCompileOptions* options = nil;
 #ifdef WSP_GGML_QKK_64
-
-
-ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
-#else
-ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+options = [MTLCompileOptions new];
+options.preprocessorMacros = @{ @"QK_K" : @(64) };
 #endif
+ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+}
+
 if (error) {
-
+WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
 return NULL;
 }
 }
-#endif
 
 // load kernels
 {
 NSError * error = nil;
+
+/*
+WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+(int) ctx->pipeline_##name.threadExecutionWidth); \
+*/
 #define WSP_GGML_METAL_ADD_KERNEL(name) \
 ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
 ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (__bridge void *) ctx->pipeline_##name, \
-(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
-(int) ctx->pipeline_##name.threadExecutionWidth); \
 if (error) { \
-
+WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
 return NULL; \
 }
 
@@ -220,6 +264,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 WSP_GGML_METAL_ADD_KERNEL(mul);
 WSP_GGML_METAL_ADD_KERNEL(mul_row);
 WSP_GGML_METAL_ADD_KERNEL(scale);
+WSP_GGML_METAL_ADD_KERNEL(scale_4);
 WSP_GGML_METAL_ADD_KERNEL(silu);
 WSP_GGML_METAL_ADD_KERNEL(relu);
 WSP_GGML_METAL_ADD_KERNEL(gelu);
@@ -231,6 +276,8 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 WSP_GGML_METAL_ADD_KERNEL(get_rows_f16);
 WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_0);
 WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_1);
 WSP_GGML_METAL_ADD_KERNEL(get_rows_q8_0);
 WSP_GGML_METAL_ADD_KERNEL(get_rows_q2_K);
 WSP_GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -239,44 +286,66 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
 WSP_GGML_METAL_ADD_KERNEL(rms_norm);
 WSP_GGML_METAL_ADD_KERNEL(norm);
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-WSP_GGML_METAL_ADD_KERNEL(
-
-
-
-
-
-
-
-
-
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+}
+WSP_GGML_METAL_ADD_KERNEL(rope_f32);
+WSP_GGML_METAL_ADD_KERNEL(rope_f16);
 WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
 WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
 WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
 WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+WSP_GGML_METAL_ADD_KERNEL(concat);
+WSP_GGML_METAL_ADD_KERNEL(sqr);
 
 #undef WSP_GGML_METAL_ADD_KERNEL
 }
 
-metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 #if TARGET_OS_OSX
-
+// print MTL GPU family:
+WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+// determine max supported GPU family
+// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+if ([ctx->device supportsFamily:i]) {
+WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+break;
+}
+}
+
+WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 if (ctx->device.maxTransferRate != 0) {
-
+WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
 } else {
-
+WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
 }
 #endif
 
@@ -284,7 +353,7 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 }
 
 void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
-
+WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
 #define WSP_GGML_METAL_DEL_KERNEL(name) \
 
 WSP_GGML_METAL_DEL_KERNEL(add);
@@ -292,6 +361,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
 WSP_GGML_METAL_DEL_KERNEL(mul);
 WSP_GGML_METAL_DEL_KERNEL(mul_row);
 WSP_GGML_METAL_DEL_KERNEL(scale);
+WSP_GGML_METAL_DEL_KERNEL(scale_4);
 WSP_GGML_METAL_DEL_KERNEL(silu);
 WSP_GGML_METAL_DEL_KERNEL(relu);
 WSP_GGML_METAL_DEL_KERNEL(gelu);
@@ -303,6 +373,8 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
 WSP_GGML_METAL_DEL_KERNEL(get_rows_f16);
 WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0);
 WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_1);
 WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0);
 WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K);
 WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -311,33 +383,42 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
 WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
 WSP_GGML_METAL_DEL_KERNEL(rms_norm);
 WSP_GGML_METAL_DEL_KERNEL(norm);
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-WSP_GGML_METAL_DEL_KERNEL(
-
-
-
-
-
-
-
-
-
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+}
+WSP_GGML_METAL_DEL_KERNEL(rope_f32);
+WSP_GGML_METAL_DEL_KERNEL(rope_f16);
 WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
 WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
 WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
 WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+WSP_GGML_METAL_DEL_KERNEL(concat);
+WSP_GGML_METAL_DEL_KERNEL(sqr);
 
 #undef WSP_GGML_METAL_DEL_KERNEL
 
@@ -348,7 +429,7 @@ void * wsp_ggml_metal_host_malloc(size_t n) {
 void * data = NULL;
 const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
 if (result != 0) {
-
+WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
 return NULL;
 }
 
@@ -376,7 +457,7 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t, size_t * offs) {
-//
+//WSP_GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
 const int64_t tsize = wsp_ggml_nbytes(t);
 
@@ -384,17 +465,17 @@ static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * c
 for (int i = 0; i < ctx->n_buffers; ++i) {
 const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
-//
+//WSP_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
 if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
 *offs = (size_t) ioffs;
 
-//
+//WSP_GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
 
 return ctx->buffers[i].metal;
 }
 }
 
-
+WSP_GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
 
 return nil;
 }
@@ -406,7 +487,7 @@ bool wsp_ggml_metal_add_buffer(
 size_t size,
 size_t max_size) {
 if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) {
-
+WSP_GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
 return false;
 }
 
@@ -416,7 +497,7 @@ bool wsp_ggml_metal_add_buffer(
 const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
 
 if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-
+WSP_GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
 return false;
 }
 }
@@ -437,11 +518,11 @@ bool wsp_ggml_metal_add_buffer(
 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-
+WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
 return false;
 }
 
-
+WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
 ++ctx->n_buffers;
 } else {
@@ -461,13 +542,13 @@ bool wsp_ggml_metal_add_buffer(
 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-
+WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
 return false;
 }
 
-
+WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
 if (i + size_step < size) {
-
+WSP_GGML_METAL_LOG_INFO("\n");
 }
 
 ++ctx->n_buffers;
@@ -475,17 +556,17 @@ bool wsp_ggml_metal_add_buffer(
 }
 
 #if TARGET_OS_OSX
-
+WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
 if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-
+WSP_GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__);
 } else {
-
+WSP_GGML_METAL_LOG_INFO("\n");
 }
 #else
-
+WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
 #endif
 }
 
@@ -598,7 +679,7 @@ void wsp_ggml_metal_graph_find_concurrency(
 }
 
 if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) {
-
+WSP_GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
 }
 }
 
@@ -652,12 +733,26 @@ void wsp_ggml_metal_graph_compute(
 continue;
 }
 
-//
+//WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
 
 struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
 struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
 struct wsp_ggml_tensor * dst = gf->nodes[i];
 
+switch (dst->op) {
+case WSP_GGML_OP_NONE:
+case WSP_GGML_OP_RESHAPE:
+case WSP_GGML_OP_VIEW:
+case WSP_GGML_OP_TRANSPOSE:
+case WSP_GGML_OP_PERMUTE:
+{
+// noop -> next node
+} continue;
+default:
+{
+} break;
+}
+
 const int64_t ne00 = src0 ? src0->ne[0] : 0;
 const int64_t ne01 = src0 ? src0->ne[1] : 0;
 const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -696,53 +791,117 @@ void wsp_ggml_metal_graph_compute(
 id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
 id<MTLBuffer> id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
 
-//
+//WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
 //if (src0) {
-//
+// WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
 // wsp_ggml_is_contiguous(src0), src0->name);
 //}
 //if (src1) {
-//
+// WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
 // wsp_ggml_is_contiguous(src1), src1->name);
 //}
 //if (dst) {
-//
+// WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
 // dst->name);
 //}
 
 switch (dst->op) {
-case
-case WSP_GGML_OP_RESHAPE:
-case WSP_GGML_OP_VIEW:
-case WSP_GGML_OP_TRANSPOSE:
-case WSP_GGML_OP_PERMUTE:
+case WSP_GGML_OP_CONCAT:
 {
-
+const int64_t nb = ne00;
+
+[encoder setComputePipelineState:ctx->pipeline_concat];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+const int nth = MIN(1024, ne0);
+
+[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
 case WSP_GGML_OP_ADD:
 {
 WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
 WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
 
-
-WSP_GGML_ASSERT(ne00 % 4 == 0);
-const int64_t nb = ne00/4;
+bool bcast_row = false;
 
-
+int64_t nb = ne00;
+
+if (wsp_ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
 // src1 is a row
 WSP_GGML_ASSERT(ne11 == 1);
+
+nb = ne00 / 4;
 [encoder setComputePipelineState:ctx->pipeline_add_row];
+
+bcast_row = true;
 } else {
 [encoder setComputePipelineState:ctx->pipeline_add];
 }
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-[encoder setBytes:&
-
-
+[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+if (bcast_row) {
+const int64_t n = wsp_ggml_nelements(dst)/4;
+
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} else {
+const int nth = MIN(1024, ne0);
 
-
+[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+}
 } break;
 case WSP_GGML_OP_MUL:
 {
@@ -775,13 +934,19 @@ void wsp_ggml_metal_graph_compute(
 
 const float scale = *(const float *) src1->data;
 
-
+int64_t n = wsp_ggml_nelements(dst);
+
+if (n % 4 == 0) {
+n /= 4;
+[encoder setComputePipelineState:ctx->pipeline_scale_4];
+} else {
+[encoder setComputePipelineState:ctx->pipeline_scale];
+}
+
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-const int64_t n = wsp_ggml_nelements(dst)/4;
-
 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
 case WSP_GGML_OP_UNARY:
@@ -792,9 +957,10 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-const int64_t n = wsp_ggml_nelements(dst)
+const int64_t n = wsp_ggml_nelements(dst);
+WSP_GGML_ASSERT(n % 4 == 0);
 
-[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
 case WSP_GGML_UNARY_OP_RELU:
 {
@@ -812,23 +978,39 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-const int64_t n = wsp_ggml_nelements(dst)
+const int64_t n = wsp_ggml_nelements(dst);
+WSP_GGML_ASSERT(n % 4 == 0);
 
-[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
 } break;
 default:
 {
-
+WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
 WSP_GGML_ASSERT(false);
 }
 } break;
+case WSP_GGML_OP_SQR:
+{
+WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
+
+[encoder setComputePipelineState:ctx->pipeline_sqr];
+[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+const int64_t n = wsp_ggml_nelements(dst);
+[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+} break;
 case WSP_GGML_OP_SOFT_MAX:
 {
-
+int nth = 32; // SIMD width
 
 if (ne00%4 == 0) {
 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
 } else {
+do {
+nth *= 2;
+} while (nth <= ne00 && nth <= 1024);
+nth /= 2;
 [encoder setComputePipelineState:ctx->pipeline_soft_max];
 }
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -836,8 +1018,9 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
 [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
 [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+[encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
-[encoder dispatchThreadgroups:MTLSizeMake(ne01,
+[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
 case WSP_GGML_OP_DIAG_MASK_INF:
 {
@@ -863,26 +1046,53 @@ void wsp_ggml_metal_graph_compute(
 } break;
 case WSP_GGML_OP_MUL_MAT:
 {
-// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
 WSP_GGML_ASSERT(ne00 == ne10);
-// WSP_GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
-uint gqa = ne12/ne02;
 WSP_GGML_ASSERT(ne03 == ne13);
 
+const uint gqa = ne12/ne02;
+
+// find the break-even point where the matrix-matrix kernel becomes more efficient compared
+// to the matrix-vector kernel
+int ne11_mm_min = 1;
+
+#if 0
+// the numbers below are measured on M2 Ultra for 7B and 13B models
+// these numbers do not translate to other devices or model sizes
+// TODO: need to find a better approach
+if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+switch (src0t) {
+case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break;
+case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
+case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
+case WSP_GGML_TYPE_Q4_0:
+case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+case WSP_GGML_TYPE_Q5_0: // not tested yet
+case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
+case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
+default: ne11_mm_min = 1; break;
+}
+}
+#endif
+
 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-if (
+if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+!wsp_ggml_is_transposed(src0) &&
 !wsp_ggml_is_transposed(src1) &&
 src1t == WSP_GGML_TYPE_F32 &&
-
-
-ne11
+ne00 % 32 == 0 && ne00 >= 64 &&
+ne11 > ne11_mm_min) {
+//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 switch (src0->type) {
 case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
 case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
 case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
 case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
 case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
 case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
 case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
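The MUL_MAT hunk above gates the matrix-matrix (mul_mm) kernels behind a device-family check and a break-even threshold (ne11_mm_min), falling back to the matrix-vector (mul_mv) kernels otherwise. As a standalone restatement sketch in plain C (the function name and boolean parameters are illustrative, not part of the package), the gate reduces to:

// Illustrative restatement of the kernel-selection gate added above.
// supports_apple7 stands in for [ctx->device supportsFamily:MTLGPUFamilyApple7].
static int prefer_mul_mm(int supports_apple7,
                         int src0_transposed, int src1_transposed,
                         int src1_is_f32,
                         long ne00, long ne11, int ne11_mm_min) {
    return supports_apple7 &&
           !src0_transposed && !src1_transposed &&
           src1_is_f32 &&
           ne00 % 32 == 0 && ne00 >= 64 &&
           ne11 > ne11_mm_min;
}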
@@ -906,17 +1116,18 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63)
+[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
 } else {
 int nth0 = 32;
 int nth1 = 1;
 int nrows = 1;
+//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
 // use custom matrix x vector kernel
 switch (src0t) {
 case WSP_GGML_TYPE_F32:
 {
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
 nrows = 4;
 } break;
 case WSP_GGML_TYPE_F16:
@@ -924,12 +1135,12 @@ void wsp_ggml_metal_graph_compute(
 nth0 = 32;
 nth1 = 1;
 if (ne11 * ne12 < 4) {
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
 } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
 nrows = ne11;
 } else {
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
 nrows = 4;
 }
 } break;
@@ -940,7 +1151,7 @@ void wsp_ggml_metal_graph_compute(
 
 nth0 = 8;
 nth1 = 8;
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
 } break;
 case WSP_GGML_TYPE_Q4_1:
 {
@@ -949,7 +1160,25 @@ void wsp_ggml_metal_graph_compute(
 
 nth0 = 8;
 nth1 = 8;
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
+} break;
+case WSP_GGML_TYPE_Q5_0:
+{
+WSP_GGML_ASSERT(ne02 == 1);
+WSP_GGML_ASSERT(ne12 == 1);
+
+nth0 = 8;
+nth1 = 8;
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+} break;
+case WSP_GGML_TYPE_Q5_1:
+{
+WSP_GGML_ASSERT(ne02 == 1);
+WSP_GGML_ASSERT(ne12 == 1);
+
+nth0 = 8;
+nth1 = 8;
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
 } break;
 case WSP_GGML_TYPE_Q8_0:
 {
@@ -958,7 +1187,7 @@ void wsp_ggml_metal_graph_compute(
 
 nth0 = 8;
 nth1 = 8;
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
 } break;
 case WSP_GGML_TYPE_Q2_K:
 {
@@ -967,7 +1196,7 @@ void wsp_ggml_metal_graph_compute(
 
 nth0 = 2;
 nth1 = 32;
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
 } break;
 case WSP_GGML_TYPE_Q3_K:
 {
@@ -976,7 +1205,7 @@ void wsp_ggml_metal_graph_compute(
 
 nth0 = 2;
 nth1 = 32;
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
 } break;
 case WSP_GGML_TYPE_Q4_K:
 {
@@ -985,7 +1214,7 @@ void wsp_ggml_metal_graph_compute(
 
 nth0 = 4; //1;
 nth1 = 8; //32;
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
 } break;
 case WSP_GGML_TYPE_Q5_K:
 {
@@ -994,7 +1223,7 @@ void wsp_ggml_metal_graph_compute(
 
 nth0 = 2;
 nth1 = 32;
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
 } break;
 case WSP_GGML_TYPE_Q6_K:
 {
@@ -1003,11 +1232,11 @@ void wsp_ggml_metal_graph_compute(
 
 nth0 = 2;
 nth1 = 32;
-[encoder setComputePipelineState:ctx->
+[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
 } break;
 default:
 {
-
+WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
 WSP_GGML_ASSERT(false && "not implemented");
 }
 };
@@ -1031,8 +1260,9 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
-if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
-src0t ==
+if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
+src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
+src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 }
 else if (src0t == WSP_GGML_TYPE_Q4_K) {
@@ -1063,6 +1293,8 @@ void wsp_ggml_metal_graph_compute(
 case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
 case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
 case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
 case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
 case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
 case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1085,10 +1317,12 @@ void wsp_ggml_metal_graph_compute(
 } break;
 case WSP_GGML_OP_RMS_NORM:
 {
+WSP_GGML_ASSERT(ne00 % 4 == 0);
+
 float eps;
 memcpy(&eps, dst->op_params, sizeof(float));
 
-const int nth = 512;
+const int nth = MIN(512, ne00);
 
 [encoder setComputePipelineState:ctx->pipeline_rms_norm];
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1096,7 +1330,7 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
 [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
 [encoder setBytes:&eps length:sizeof( float) atIndex:4];
-[encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+[encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
 const int64_t nrows = wsp_ggml_nrows(src0);
 
@@ -1107,7 +1341,7 @@ void wsp_ggml_metal_graph_compute(
 float eps;
 memcpy(&eps, dst->op_params, sizeof(float));
 
-const int nth = 256;
+const int nth = MIN(256, ne00);
 
 [encoder setComputePipelineState:ctx->pipeline_norm];
 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1115,7 +1349,7 @@ void wsp_ggml_metal_graph_compute(
 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
 [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
 [encoder setBytes:&eps length:sizeof( float) atIndex:4];
-[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+[encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0];
 
 const int64_t nrows = wsp_ggml_nrows(src0);
 
@@ -1125,17 +1359,16 @@ void wsp_ggml_metal_graph_compute(
|
|
|
1125
1359
|
{
|
|
1126
1360
|
WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));
|
|
1127
1361
|
|
|
1128
|
-
const int
|
|
1362
|
+
const int nth = MIN(1024, ne00);
|
|
1363
|
+
|
|
1364
|
+
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
1129
1365
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
1130
1366
|
float max_bias;
|
|
1131
1367
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
1132
1368
|
|
|
1133
|
-
if (__builtin_popcount(n_head) != 1) {
|
|
1134
|
-
WSP_GGML_ASSERT(false && "only power-of-two n_head implemented");
|
|
1135
|
-
}
|
|
1136
|
-
|
|
1137
1369
|
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
|
|
1138
1370
|
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
|
1371
|
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
|
1139
1372
|
|
|
1140
1373
|
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
|
|
1141
1374
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
@@ -1156,55 +1389,74 @@ void wsp_ggml_metal_graph_compute(
  [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&m0
-
-
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];

  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case WSP_GGML_OP_ROPE:
  {
-
- const int n_dims = ((int32_t *) dst->op_params)[1];
- const int mode = ((int32_t *) dst->op_params)[2];
+ WSP_GGML_ASSERT(ne10 == ne02);

-
- float freq_scale;
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
+ const int nth = MIN(1024, ne00);

-
-
- [
-
-
-
-
-
-
-
-
-
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
- [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
- [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
+ const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

-
+ switch (src0->type) {
+ case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
+ case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+ default: WSP_GGML_ASSERT(false);
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:20];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:21];
+ [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+ [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+ [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+ [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+ [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+ [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+ [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
  } break;
  case WSP_GGML_OP_DUP:
  case WSP_GGML_OP_CPY:
  case WSP_GGML_OP_CONT:
  {
- const int nth =
+ const int nth = MIN(1024, ne00);

  switch (src0t) {
  case WSP_GGML_TYPE_F32:
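The new ROPE case reads far more parameters out of `dst->op_params`: integer fields by index (`n_past`, `n_dims`, `mode`, `n_orig_ctx`) and float fields via `memcpy` from slots 5 through 10 (`freq_base` through `beta_slow`). The `memcpy` is deliberate: the floats are stored bit-for-bit inside an `int32_t` array, and reading them back through a cast would violate strict aliasing. A small sketch of that packing/unpacking pattern (slot numbers mirror the hunk; slot 4 is left untouched because its meaning is not visible here):

```c
// Sketch (not package code): op_params as an int32_t scratch area on the tensor.
// Integers are stored directly; floats are stored bit-for-bit with memcpy, so
// the reader memcpy's them back out, exactly as the new ROPE case does.
#include <stdint.h>
#include <string.h>
#include <stdio.h>

int main(void) {
    int32_t op_params[16] = {0};

    // writer side: what the graph-building code is expected to do
    op_params[1] = 128;                                  // n_dims
    float freq_base = 10000.0f;
    memcpy(op_params + 5, &freq_base, sizeof(float));    // slot 5, stored bit-for-bit

    // reader side: the pattern used in the Metal ROPE case
    const int n_dims = op_params[1];
    float freq_base_rd;
    memcpy(&freq_base_rd, op_params + 5, sizeof(float));

    printf("n_dims = %d, freq_base = %.1f\n", n_dims, freq_base_rd);
    return 0;
}
```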
@@ -1249,7 +1501,7 @@ void wsp_ggml_metal_graph_compute(
  } break;
  default:
  {
-
+ WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
  WSP_GGML_ASSERT(false);
  }
  }
@@ -1274,10 +1526,147 @@ void wsp_ggml_metal_graph_compute(

  MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
  if (status != MTLCommandBufferStatusCompleted) {
-
+ WSP_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
  WSP_GGML_ASSERT(false);
  }
  }

  }
  }
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ // backend interface
+
+ static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
+ return "Metal";
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+ wsp_ggml_metal_free(ctx);
+ free(backend);
+ }
+
+ static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
+ return (void *)buffer->context;
+ }
+
+ static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
+ free(buffer->context);
+ UNUSED(buffer);
+ }
+
+ static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
+ /* .free_buffer    = */ wsp_ggml_backend_metal_buffer_free_buffer,
+ /* .get_base       = */ wsp_ggml_backend_metal_buffer_get_base,
+ /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes
+ /* .init_tensor    = */ NULL, // no initialization required
+ /* .free_tensor    = */ NULL, // no cleanup required
+ };
+
+ static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_alloc_buffer(wsp_ggml_backend_t backend, size_t size) {
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+ void * data = wsp_ggml_metal_host_malloc(size);
+
+ // TODO: set proper name of the buffers
+ wsp_ggml_metal_add_buffer(ctx, "backend", data, size, 0);
+
+ return wsp_ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size);
+ }
+
+ static size_t wsp_ggml_backend_metal_get_alignment(wsp_ggml_backend_t backend) {
+ return 32;
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_set_tensor_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy((char *)tensor->data + offset, data, size);
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_get_tensor_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
+ WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+ memcpy(data, (const char *)tensor->data + offset, size);
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_cpy_tensor_from(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+ wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_cpy_tensor_to(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+ wsp_ggml_backend_tensor_set_async(dst, src->data, 0, wsp_ggml_nbytes(src));
+
+ UNUSED(backend);
+ }
+
+ static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
+ struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+ wsp_ggml_metal_graph_compute(metal_ctx, cgraph);
+ }
+
+ static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
+ return true;
+ UNUSED(backend);
+ UNUSED(op);
+ }
+
+ static struct wsp_ggml_backend_i metal_backend_i = {
+ /* .get_name           = */ wsp_ggml_backend_metal_name,
+ /* .free               = */ wsp_ggml_backend_metal_free,
+ /* .alloc_buffer       = */ wsp_ggml_backend_metal_alloc_buffer,
+ /* .get_alignment      = */ wsp_ggml_backend_metal_get_alignment,
+ /* .set_tensor_async   = */ wsp_ggml_backend_metal_set_tensor_async,
+ /* .get_tensor_async   = */ wsp_ggml_backend_metal_get_tensor_async,
+ /* .synchronize        = */ wsp_ggml_backend_metal_synchronize,
+ /* .cpy_tensor_from    = */ wsp_ggml_backend_metal_cpy_tensor_from,
+ /* .cpy_tensor_to      = */ wsp_ggml_backend_metal_cpy_tensor_to,
+ /* .graph_plan_create  = */ NULL, // the metal implementation does not require creating graph plans atm
+ /* .graph_plan_free    = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute      = */ wsp_ggml_backend_metal_graph_compute,
+ /* .supports_op        = */ wsp_ggml_backend_metal_supports_op,
+ };
+
+ wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
+ struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context));
+
+ ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
+
+ wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));
+
+ *metal_backend = (struct wsp_ggml_backend) {
+ /* .interface = */ metal_backend_i,
+ /* .context   = */ ctx,
+ };
+
+ return metal_backend;
+ }
+
+ bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
+ return backend->iface.get_name == wsp_ggml_backend_metal_name;
+ }
+
+ void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
+ struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+ wsp_ggml_metal_set_n_cb(ctx, n_cb);
+ }
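The block above is the new ggml-backend glue for Metal: a `wsp_ggml_backend_buffer_i` / `wsp_ggml_backend_i` vtable pair plus the public `wsp_ggml_backend_metal_init`, `wsp_ggml_backend_is_metal`, and `wsp_ggml_backend_metal_set_n_cb` entry points. A hypothetical usage sketch follows; it assumes the generic helpers `wsp_ggml_backend_graph_compute` and `wsp_ggml_backend_free` (the wsp_-prefixed counterparts of the upstream ggml-backend API) are declared in the newly added ggml-backend.h, which is not shown in this hunk:

```c
// Hypothetical sketch, not package code: driving the Metal backend through the
// generic ggml-backend interface added in this release. Assumes ggml-backend.h
// declares wsp_ggml_backend_graph_compute() and wsp_ggml_backend_free().
#include "ggml-metal.h"
#include "ggml-backend.h"

static void compute_on_metal(struct wsp_ggml_cgraph * gf) {
    wsp_ggml_backend_t backend = wsp_ggml_backend_metal_init();   // wraps wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS)
    if (backend == NULL) {
        return;
    }

    // sanity check: the vtable's get_name pointer identifies the backend
    if (!wsp_ggml_backend_is_metal(backend)) {
        wsp_ggml_backend_free(backend);
        return;
    }

    wsp_ggml_backend_metal_set_n_cb(backend, 1);   // single command buffer
    wsp_ggml_backend_graph_compute(backend, gf);   // ends up in wsp_ggml_metal_graph_compute()

    wsp_ggml_backend_free(backend);                // invokes wsp_ggml_backend_metal_free()
}
```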