whisper.rn 0.4.0-rc.3 → 0.4.0-rc.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/android/build.gradle +4 -0
- package/android/src/main/CMakeLists.txt +7 -0
- package/android/src/main/java/com/rnwhisper/AudioUtils.java +0 -80
- package/android/src/main/java/com/rnwhisper/RNWhisper.java +6 -1
- package/android/src/main/java/com/rnwhisper/WhisperContext.java +53 -135
- package/android/src/main/jni-utils.h +76 -0
- package/android/src/main/jni.cpp +188 -109
- package/cpp/README.md +1 -1
- package/cpp/coreml/whisper-encoder-impl.h +1 -1
- package/cpp/coreml/whisper-encoder.h +4 -0
- package/cpp/coreml/whisper-encoder.mm +4 -2
- package/cpp/ggml-alloc.c +451 -282
- package/cpp/ggml-alloc.h +74 -8
- package/cpp/ggml-backend-impl.h +112 -0
- package/cpp/ggml-backend.c +1357 -0
- package/cpp/ggml-backend.h +181 -0
- package/cpp/ggml-impl.h +243 -0
- package/cpp/{ggml-metal.metal → ggml-metal-whisper.metal} +1556 -329
- package/cpp/ggml-metal.h +28 -1
- package/cpp/ggml-metal.m +1128 -308
- package/cpp/ggml-quants.c +7382 -0
- package/cpp/ggml-quants.h +224 -0
- package/cpp/ggml.c +3848 -5245
- package/cpp/ggml.h +353 -155
- package/cpp/rn-audioutils.cpp +68 -0
- package/cpp/rn-audioutils.h +14 -0
- package/cpp/rn-whisper-log.h +11 -0
- package/cpp/rn-whisper.cpp +141 -59
- package/cpp/rn-whisper.h +47 -15
- package/cpp/whisper.cpp +1750 -964
- package/cpp/whisper.h +97 -15
- package/ios/RNWhisper.mm +15 -9
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/contents.xcworkspacedata +4 -0
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist +8 -0
- package/ios/RNWhisper.xcodeproj/project.xcworkspace/xcuserdata/jhen.xcuserdatad/UserInterfaceState.xcuserstate +0 -0
- package/ios/RNWhisper.xcodeproj/xcuserdata/jhen.xcuserdatad/xcschemes/xcschememanagement.plist +19 -0
- package/ios/RNWhisperAudioUtils.h +0 -2
- package/ios/RNWhisperAudioUtils.m +0 -56
- package/ios/RNWhisperContext.h +8 -12
- package/ios/RNWhisperContext.mm +132 -138
- package/jest/mock.js +1 -1
- package/lib/commonjs/NativeRNWhisper.js.map +1 -1
- package/lib/commonjs/index.js +28 -9
- package/lib/commonjs/index.js.map +1 -1
- package/lib/commonjs/version.json +1 -1
- package/lib/module/NativeRNWhisper.js.map +1 -1
- package/lib/module/index.js +28 -9
- package/lib/module/index.js.map +1 -1
- package/lib/module/version.json +1 -1
- package/lib/typescript/NativeRNWhisper.d.ts +7 -1
- package/lib/typescript/NativeRNWhisper.d.ts.map +1 -1
- package/lib/typescript/index.d.ts +7 -2
- package/lib/typescript/index.d.ts.map +1 -1
- package/package.json +6 -5
- package/src/NativeRNWhisper.ts +8 -1
- package/src/index.ts +29 -17
- package/src/version.json +1 -1
- package/whisper-rn.podspec +1 -2
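The most visible API change in package/cpp/ggml-metal.m below is that the old metal_printf logging is replaced by WSP_GGML_METAL_LOG_INFO/WARN/ERROR macros routed through a user-settable callback registered via wsp_ggml_metal_log_set_callback. A minimal sketch of hooking that callback from host code, assuming only the callback signature visible in the diff (the handler name and the fprintf body are illustrative, not part of the package):

    #include <stdio.h>
    #include "ggml-metal.h" // assumed to declare wsp_ggml_metal_log_set_callback and wsp_ggml_log_level

    // Hypothetical handler: forward Metal backend log lines to stderr.
    static void my_metal_log(enum wsp_ggml_log_level level, const char * text, void * user_data) {
        (void) user_data;
        fprintf(stderr, "[metal:%d] %s", (int) level, text);
    }

    // Call once during setup; NULL user_data because the handler ignores it:
    //   wsp_ggml_metal_log_set_callback(my_metal_log, NULL);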
package/cpp/ggml-metal.m
CHANGED
@@ -1,5 +1,6 @@
 #import "ggml-metal.h"
 
+#import "ggml-backend-impl.h"
 #import "ggml.h"
 
 #import <Foundation/Foundation.h>
@@ -11,16 +12,19 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-// TODO: temporary - reuse llama.cpp logging
 #ifdef WSP_GGML_METAL_NDEBUG
-#define
+#define WSP_GGML_METAL_LOG_INFO(...)
+#define WSP_GGML_METAL_LOG_WARN(...)
+#define WSP_GGML_METAL_LOG_ERROR(...)
 #else
-#define
+#define WSP_GGML_METAL_LOG_INFO(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define WSP_GGML_METAL_LOG_WARN(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define WSP_GGML_METAL_LOG_ERROR(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
 #endif
 
 #define UNUSED(x) (void)(x)
 
-#define WSP_GGML_MAX_CONCUR (2*
+#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_DEFAULT_GRAPH_SIZE)
 
 struct wsp_ggml_metal_buffer {
     const char * name;
@@ -58,7 +62,10 @@ struct wsp_ggml_metal_context {
     WSP_GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
     WSP_GGML_METAL_DECL_KERNEL(mul);
     WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+    WSP_GGML_METAL_DECL_KERNEL(div);
+    WSP_GGML_METAL_DECL_KERNEL(div_row);
     WSP_GGML_METAL_DECL_KERNEL(scale);
+    WSP_GGML_METAL_DECL_KERNEL(scale_4);
     WSP_GGML_METAL_DECL_KERNEL(silu);
     WSP_GGML_METAL_DECL_KERNEL(relu);
     WSP_GGML_METAL_DECL_KERNEL(gelu);
@@ -70,6 +77,8 @@ struct wsp_ggml_metal_context {
     WSP_GGML_METAL_DECL_KERNEL(get_rows_f16);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+    WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_1);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -78,33 +87,62 @@ struct wsp_ggml_metal_context {
     WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     WSP_GGML_METAL_DECL_KERNEL(rms_norm);
     WSP_GGML_METAL_DECL_KERNEL(norm);
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
-    WSP_GGML_METAL_DECL_KERNEL(
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
     WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
-    WSP_GGML_METAL_DECL_KERNEL(
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
+    WSP_GGML_METAL_DECL_KERNEL(rope_f32);
+    WSP_GGML_METAL_DECL_KERNEL(rope_f16);
     WSP_GGML_METAL_DECL_KERNEL(alibi_f32);
+    WSP_GGML_METAL_DECL_KERNEL(im2col_f16);
+    WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
+    WSP_GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
     WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+    WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
+    WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
+    WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
+    //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
+    //WSP_GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    WSP_GGML_METAL_DECL_KERNEL(concat);
+    WSP_GGML_METAL_DECL_KERNEL(sqr);
+    WSP_GGML_METAL_DECL_KERNEL(sum_rows);
 
 #undef WSP_GGML_METAL_DECL_KERNEL
 };
@@ -112,18 +150,48 @@ struct wsp_ggml_metal_context {
 // MSL code
 // TODO: move the contents here when ready
 // for now it is easier to work in a separate file
-static NSString * const msl_library_source = @"see metal.metal";
+//static NSString * const msl_library_source = @"see metal.metal";
 
 // Here to assist with NSBundle Path Hack
-@interface
+@interface WSPGGMLMetalClass : NSObject
 @end
-@implementation
+@implementation WSPGGMLMetalClass
 @end
 
+wsp_ggml_log_callback wsp_ggml_metal_log_callback = NULL;
+void * wsp_ggml_metal_log_user_data = NULL;
+
+void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) {
+    wsp_ggml_metal_log_callback = log_callback;
+    wsp_ggml_metal_log_user_data = user_data;
+}
+
+WSP_GGML_ATTRIBUTE_FORMAT(2, 3)
+static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char * format, ...){
+    if (wsp_ggml_metal_log_callback != NULL) {
+        va_list args;
+        va_start(args, format);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data);
+        } else {
+            char* buffer2 = malloc(len+1);
+            va_end(args);
+            va_start(args, format);
+            vsnprintf(buffer2, len+1, format, args);
+            buffer2[len] = 0;
+            wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data);
+            free(buffer2);
+        }
+        va_end(args);
+    }
+}
+
 struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
-
+    WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
-    id
+    id<MTLDevice> device;
     NSString * s;
 
 #if TARGET_OS_OSX
@@ -131,17 +199,17 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
     NSArray * devices = MTLCopyAllDevices();
     for (device in devices) {
         s = [device name];
-
+        WSP_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]);
     }
 #endif
 
     // Pick and show default Metal device
     device = MTLCreateSystemDefaultDevice();
     s = [device name];
-
+    WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]);
 
     // Configure context
-    struct wsp_ggml_metal_context * ctx =
+    struct wsp_ggml_metal_context * ctx = calloc(1, sizeof(struct wsp_ggml_metal_context));
     ctx->device = device;
     ctx->n_cb = MIN(n_cb, WSP_GGML_METAL_MAX_BUFFERS);
     ctx->queue = [ctx->device newCommandQueue];
@@ -150,68 +218,95 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
 
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
-
-    // load the default.metallib file
+    // load library
     {
+        NSBundle * bundle = nil;
+#ifdef SWIFT_PACKAGE
+        bundle = SWIFTPM_MODULE_BUNDLE;
+#else
+        bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]];
+#endif
         NSError * error = nil;
+        NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
+        if (libPath != nil) {
+            NSURL * libURL = [NSURL fileURLWithPath:libPath];
+            WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
+            ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+        } else {
+            WSP_GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
 
-
-
-        NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
-        NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
-        NSURL * libURL = [NSURL fileURLWithPath:libPath];
-
-        // Load the metallib file into a Metal library
-        ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+            NSString * sourcePath;
+            NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"];
 
-
-            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
-            return NULL;
-        }
-    }
-#else
-    UNUSED(msl_library_source);
+            WSP_GGML_METAL_LOG_INFO("%s: WSP_GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
 
-
-
-
+            if (ggmlMetalPathResources) {
+                sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
+            } else {
+                sourcePath = [bundle pathForResource:@"ggml-metal-whisper" ofType:@"metal"];
+            }
+            if (sourcePath == nil) {
+                WSP_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
+                sourcePath = @"ggml-metal.metal";
+            }
+            WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
+            NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
+            if (error) {
+                WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
 
-
-
-
-
+            MTLCompileOptions* options = nil;
+#ifdef WSP_GGML_QKK_64
+            options = [MTLCompileOptions new];
+            options.preprocessorMacros = @{ @"QK_K" : @(64) };
+#endif
+            ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+        }
 
-        NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
-
+            WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
            return NULL;
         }
+    }
 
-#
-
-
-
-
-
-
-
-
-
+#if TARGET_OS_OSX
+    // print MTL GPU family:
+    WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+        if ([ctx->device supportsFamily:i]) {
+            WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+            break;
         }
     }
+
+    WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
+    if (ctx->device.maxTransferRate != 0) {
+        WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
+    } else {
+        WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__);
+    }
 #endif
 
     // load kernels
     {
         NSError * error = nil;
+
+        /*
+        WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
+                (int) ctx->pipeline_##name.threadExecutionWidth); \
+        */
 #define WSP_GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (__bridge void *) ctx->pipeline_##name, \
-                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
-                (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
-
+            WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
            return NULL; \
         }
 
@@ -219,7 +314,10 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
         WSP_GGML_METAL_ADD_KERNEL(add_row);
         WSP_GGML_METAL_ADD_KERNEL(mul);
         WSP_GGML_METAL_ADD_KERNEL(mul_row);
+        WSP_GGML_METAL_ADD_KERNEL(div);
+        WSP_GGML_METAL_ADD_KERNEL(div_row);
         WSP_GGML_METAL_ADD_KERNEL(scale);
+        WSP_GGML_METAL_ADD_KERNEL(scale_4);
         WSP_GGML_METAL_ADD_KERNEL(silu);
         WSP_GGML_METAL_ADD_KERNEL(relu);
         WSP_GGML_METAL_ADD_KERNEL(gelu);
@@ -231,6 +329,8 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
         WSP_GGML_METAL_ADD_KERNEL(get_rows_f16);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+        WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_1);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -239,59 +339,83 @@ struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) {
         WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         WSP_GGML_METAL_ADD_KERNEL(rms_norm);
         WSP_GGML_METAL_ADD_KERNEL(norm);
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-        WSP_GGML_METAL_ADD_KERNEL(
-
-
-
-
-
-
-
-
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+        WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
+            WSP_GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
+        }
+        WSP_GGML_METAL_ADD_KERNEL(rope_f32);
+        WSP_GGML_METAL_ADD_KERNEL(rope_f16);
         WSP_GGML_METAL_ADD_KERNEL(alibi_f32);
+        WSP_GGML_METAL_ADD_KERNEL(im2col_f16);
+        WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
+        WSP_GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
         WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+        WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
+        WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
+        WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
+        //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
+        //WSP_GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
         WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        WSP_GGML_METAL_ADD_KERNEL(concat);
+        WSP_GGML_METAL_ADD_KERNEL(sqr);
+        WSP_GGML_METAL_ADD_KERNEL(sum_rows);
 
 #undef WSP_GGML_METAL_ADD_KERNEL
     }
 
-    metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
-#if TARGET_OS_OSX
-    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    if (ctx->device.maxTransferRate != 0) {
-        metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
-    } else {
-        metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
-    }
-#endif
-
     return ctx;
 }
 
 void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
-
+    WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
 #define WSP_GGML_METAL_DEL_KERNEL(name) \
 
     WSP_GGML_METAL_DEL_KERNEL(add);
     WSP_GGML_METAL_DEL_KERNEL(add_row);
     WSP_GGML_METAL_DEL_KERNEL(mul);
     WSP_GGML_METAL_DEL_KERNEL(mul_row);
+    WSP_GGML_METAL_DEL_KERNEL(div);
+    WSP_GGML_METAL_DEL_KERNEL(div_row);
     WSP_GGML_METAL_DEL_KERNEL(scale);
+    WSP_GGML_METAL_DEL_KERNEL(scale_4);
    WSP_GGML_METAL_DEL_KERNEL(silu);
    WSP_GGML_METAL_DEL_KERNEL(relu);
    WSP_GGML_METAL_DEL_KERNEL(gelu);
@@ -303,6 +427,8 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
    WSP_GGML_METAL_DEL_KERNEL(get_rows_f16);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+   WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+   WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_1);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K);
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -311,33 +437,64 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) {
    WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K);
    WSP_GGML_METAL_DEL_KERNEL(rms_norm);
    WSP_GGML_METAL_DEL_KERNEL(norm);
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-   WSP_GGML_METAL_DEL_KERNEL(
-
-
-
-
-
-
-
-
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+   WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+   if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
+       WSP_GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
+   }
+   WSP_GGML_METAL_DEL_KERNEL(rope_f32);
+   WSP_GGML_METAL_DEL_KERNEL(rope_f16);
    WSP_GGML_METAL_DEL_KERNEL(alibi_f32);
+   WSP_GGML_METAL_DEL_KERNEL(im2col_f16);
+   WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
+   WSP_GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16);
    WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32);
+   WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
+   WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
+   WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
+   //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
+   //WSP_GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
    WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+   WSP_GGML_METAL_DEL_KERNEL(concat);
+   WSP_GGML_METAL_DEL_KERNEL(sqr);
+   WSP_GGML_METAL_DEL_KERNEL(sum_rows);
 
 #undef WSP_GGML_METAL_DEL_KERNEL
 
@@ -348,7 +505,7 @@ void * wsp_ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
     const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
     if (result != 0) {
-
+        WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
         return NULL;
     }
 
@@ -371,30 +528,50 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx) {
     return ctx->concur_list;
 }
 
+// temporarily defined here for compatibility between ggml-backend and the old API
+struct wsp_ggml_backend_metal_buffer_context {
+    void * data;
+
+    id<MTLBuffer> metal;
+};
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t, size_t * offs) {
-    //
+    //WSP_GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
     const int64_t tsize = wsp_ggml_nbytes(t);
 
+    // compatibility with ggml-backend
+    if (t->buffer && t->buffer->buft == wsp_ggml_backend_metal_buffer_type()) {
+        struct wsp_ggml_backend_metal_buffer_context * buf_ctx = (struct wsp_ggml_backend_metal_buffer_context *) t->buffer->context;
+
+        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->data;
+
+        WSP_GGML_ASSERT(ioffs >= 0 && ioffs + tsize <= (int64_t) t->buffer->size);
+
+        *offs = (size_t) ioffs;
+
+        return buf_ctx->metal;
+    }
+
     // find the view that contains the tensor fully
     for (int i = 0; i < ctx->n_buffers; ++i) {
        const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
-       //
+       //WSP_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
        if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
            *offs = (size_t) ioffs;
 
-           //
+           //WSP_GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
 
            return ctx->buffers[i].metal;
        }
     }
 
-
+    WSP_GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__);
 
     return nil;
 }
@@ -406,7 +583,7 @@ bool wsp_ggml_metal_add_buffer(
        size_t size,
        size_t max_size) {
     if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) {
-
+        WSP_GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__);
         return false;
     }
 
@@ -416,7 +593,7 @@ bool wsp_ggml_metal_add_buffer(
        const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
 
        if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-
+           WSP_GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
           return false;
        }
     }
@@ -437,11 +614,11 @@ bool wsp_ggml_metal_add_buffer(
        ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
        if (ctx->buffers[ctx->n_buffers].metal == nil) {
-
+           WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
           return false;
        }
 
-
+       WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
       ++ctx->n_buffers;
    } else {
@@ -461,13 +638,13 @@ bool wsp_ggml_metal_add_buffer(
       ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
       if (ctx->buffers[ctx->n_buffers].metal == nil) {
-
+          WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MiB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
          return false;
      }
 
-
+      WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MiB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
      if (i + size_step < size) {
-
+          WSP_GGML_METAL_LOG_INFO("\n");
      }
 
      ++ctx->n_buffers;
@@ -475,17 +652,17 @@ bool wsp_ggml_metal_add_buffer(
    }
  }
 
#if TARGET_OS_OSX
-
+   WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
       ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
       ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
   if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-
+       WSP_GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
   } else {
-
+       WSP_GGML_METAL_LOG_INFO("\n");
   }
#else
-
+   WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
#endif
 }
 
@@ -598,10 +775,55 @@ void wsp_ggml_metal_graph_find_concurrency(
     }
 
     if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) {
-
+        WSP_GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
 
+static bool wsp_ggml_metal_supports_op(const struct wsp_ggml_tensor * op) {
+    switch (op->op) {
+        case WSP_GGML_OP_UNARY:
+            switch (wsp_ggml_get_unary_op(op)) {
+                case WSP_GGML_UNARY_OP_SILU:
+                case WSP_GGML_UNARY_OP_RELU:
+                case WSP_GGML_UNARY_OP_GELU:
+                    return true;
+                default:
+                    return false;
+            }
+        case WSP_GGML_OP_NONE:
+        case WSP_GGML_OP_RESHAPE:
+        case WSP_GGML_OP_VIEW:
+        case WSP_GGML_OP_TRANSPOSE:
+        case WSP_GGML_OP_PERMUTE:
+        case WSP_GGML_OP_CONCAT:
+        case WSP_GGML_OP_ADD:
+        case WSP_GGML_OP_MUL:
+        case WSP_GGML_OP_DIV:
+        case WSP_GGML_OP_SCALE:
+        case WSP_GGML_OP_SQR:
+        case WSP_GGML_OP_SUM_ROWS:
+        case WSP_GGML_OP_SOFT_MAX:
+        case WSP_GGML_OP_RMS_NORM:
+        case WSP_GGML_OP_NORM:
+        case WSP_GGML_OP_ALIBI:
+        case WSP_GGML_OP_ROPE:
+        case WSP_GGML_OP_IM2COL:
+        case WSP_GGML_OP_ARGSORT:
+        case WSP_GGML_OP_DUP:
+        case WSP_GGML_OP_CPY:
+        case WSP_GGML_OP_CONT:
+        case WSP_GGML_OP_MUL_MAT:
+        case WSP_GGML_OP_MUL_MAT_ID:
+            return true;
+        case WSP_GGML_OP_DIAG_MASK_INF:
+        case WSP_GGML_OP_GET_ROWS:
+            {
+                return op->ne[0] % 4 == 0;
+            }
+        default:
+            return false;
+    }
+}
 void wsp_ggml_metal_graph_compute(
        struct wsp_ggml_metal_context * ctx,
        struct wsp_ggml_cgraph * gf) {
@@ -652,12 +874,28 @@ void wsp_ggml_metal_graph_compute(
                continue;
            }
 
-           //
+           //WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op));
 
            struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0];
            struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1];
            struct wsp_ggml_tensor * dst  = gf->nodes[i];
 
+           switch (dst->op) {
+               case WSP_GGML_OP_NONE:
+               case WSP_GGML_OP_RESHAPE:
+               case WSP_GGML_OP_VIEW:
+               case WSP_GGML_OP_TRANSPOSE:
+               case WSP_GGML_OP_PERMUTE:
+                   {
+                       // noop -> next node
+                   } continue;
+               default:
+                   {
+                   } break;
+           }
+
+           WSP_GGML_ASSERT(wsp_ggml_metal_supports_op(dst));
+
            const int64_t ne00 = src0 ? src0->ne[0] : 0;
            const int64_t ne01 = src0 ? src0->ne[1] : 0;
            const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -696,78 +934,129 @@ void wsp_ggml_metal_graph_compute(
            id<MTLBuffer> id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
            id<MTLBuffer> id_dst  = dst  ? wsp_ggml_metal_get_buffer(ctx, dst,  &offs_dst)  : nil;
 
-           //
+           //WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op));
            //if (src0) {
-           //
+           //    WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02,
            //        wsp_ggml_is_contiguous(src0), src0->name);
            //}
            //if (src1) {
-           //
+           //    WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12,
            //        wsp_ggml_is_contiguous(src1), src1->name);
            //}
            //if (dst) {
-           //
+           //    WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2,
            //        dst->name);
            //}
 
            switch (dst->op) {
-               case
-               case WSP_GGML_OP_RESHAPE:
-               case WSP_GGML_OP_VIEW:
-               case WSP_GGML_OP_TRANSPOSE:
-               case WSP_GGML_OP_PERMUTE:
+               case WSP_GGML_OP_CONCAT:
                   {
-
-                   } break;
-               case WSP_GGML_OP_ADD:
-                   {
-                       WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
-                       WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
-
-                       // utilize float4
-                       WSP_GGML_ASSERT(ne00 % 4 == 0);
-                       const int64_t nb = ne00/4;
+                       const int64_t nb = ne00;
 
-
-                       // src1 is a row
-                       WSP_GGML_ASSERT(ne11 == 1);
-                       [encoder setComputePipelineState:ctx->pipeline_add_row];
-                       } else {
-                       [encoder setComputePipelineState:ctx->pipeline_add];
-                       }
+                       [encoder setComputePipelineState:ctx->pipeline_concat];
                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                       [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                       [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                      [encoder setBytes:&
-
-
-
-                      [encoder
+                       [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                       [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                       [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                       [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                       [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                       [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                       [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                       [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                       [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                       [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                       [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                       [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                       [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                       [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                       [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                       [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                       [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+                       [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+                       [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+                       [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+                       [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+                       [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+                       [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+                       [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+                       [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+                       const int nth = MIN(1024, ne0);
+
+                       [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                   } break;
+               case WSP_GGML_OP_ADD:
               case WSP_GGML_OP_MUL:
+               case WSP_GGML_OP_DIV:
                   {
                       WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
                       WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1));
 
-
-
-
+                       bool bcast_row = false;
+
+                       int64_t nb = ne00;
 
-                      if (wsp_ggml_nelements(src1) == ne10) {
+                       if (wsp_ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
                           // src1 is a row
                           WSP_GGML_ASSERT(ne11 == 1);
-
+
+                           nb = ne00 / 4;
+                           switch (dst->op) {
+                               case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
+                               case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
+                               case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+                               default: WSP_GGML_ASSERT(false);
+                           }
+
+                           bcast_row = true;
                       } else {
-
+                           switch (dst->op) {
+                               case WSP_GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
+                               case WSP_GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
+                               case WSP_GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+                               default: WSP_GGML_ASSERT(false);
+                           }
                       }
                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                       [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                       [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                      [encoder setBytes:&
-
-
+                       [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                       [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                       [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                       [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                       [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                       [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                       [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                       [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                       [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                       [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                       [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                       [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                       [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                       [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                       [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                       [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                       [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+                       [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+                       [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+                       [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+                       [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+                       [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
+                       [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
+                       [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
+                       [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+
+                       if (bcast_row) {
+                           const int64_t n = wsp_ggml_nelements(dst)/4;
+
+                           [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                       } else {
+                           const int nth = MIN(1024, ne0);
 
-
+                           [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                       }
                   } break;
               case WSP_GGML_OP_SCALE:
                   {
@@ -775,13 +1064,19 @@ void wsp_ggml_metal_graph_compute(
 
                       const float scale = *(const float *) src1->data;
 
-
+                       int64_t n = wsp_ggml_nelements(dst);
+
+                       if (n % 4 == 0) {
+                           n /= 4;
+                           [encoder setComputePipelineState:ctx->pipeline_scale_4];
+                       } else {
+                           [encoder setComputePipelineState:ctx->pipeline_scale];
+                       }
+
                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                       [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                      const int64_t n = wsp_ggml_nelements(dst)/4;
-
                       [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                   } break;
               case WSP_GGML_OP_UNARY:
@@ -792,9 +1087,10 @@ void wsp_ggml_metal_graph_compute(
                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                      const int64_t n = wsp_ggml_nelements(dst)
+                       const int64_t n = wsp_ggml_nelements(dst);
+                       WSP_GGML_ASSERT(n % 4 == 0);
 
-                      [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                       [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                   } break;
               case WSP_GGML_UNARY_OP_RELU:
                   {
@@ -812,32 +1108,92 @@ void wsp_ggml_metal_graph_compute(
                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                      const int64_t n = wsp_ggml_nelements(dst)
+                       const int64_t n = wsp_ggml_nelements(dst);
+                       WSP_GGML_ASSERT(n % 4 == 0);
 
-                      [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                       [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                   } break;
               default:
                   {
-
+                       WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
                       WSP_GGML_ASSERT(false);
                   }
               } break;
+               case WSP_GGML_OP_SQR:
+                   {
+                       WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0));
+
+                       [encoder setComputePipelineState:ctx->pipeline_sqr];
+                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                       const int64_t n = wsp_ggml_nelements(dst);
+                       [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                   } break;
+               case WSP_GGML_OP_SUM_ROWS:
+                   {
+                       WSP_GGML_ASSERT(src0->nb[0] == wsp_ggml_type_size(src0->type));
+
+                       [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                       [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                       [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                       [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                       [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                       [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                       [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                       [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                       [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                       [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                       [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                       [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                       [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                       [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                       [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                       [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                       [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                       [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                       [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+                       [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19];
+                       [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20];
+                       [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21];
+                       [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22];
+                       [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23];
+                       [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24];
+                       [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25];
+
+                       [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                   } break;
               case WSP_GGML_OP_SOFT_MAX:
                   {
-
+                       int nth = 32; // SIMD width
 
                       if (ne00%4 == 0) {
+                           while (nth < ne00/4 && nth < 256) {
+                               nth *= 2;
+                           }
                           [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                       } else {
+                           while (nth < ne00 && nth < 1024) {
+                               nth *= 2;
+                           }
                           [encoder setComputePipelineState:ctx->pipeline_soft_max];
                       }
-                      [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                      [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                      [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                      [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                      [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
 
-
+                       const float scale = ((float *) dst->op_params)[0];
+
+                       [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                       if (id_src1) {
+                           [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                       }
+                       [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                       [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                       [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                       [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                       [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                       [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                       [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                   } break;
               case WSP_GGML_OP_DIAG_MASK_INF:
                   {
@@ -863,26 +1219,57 @@ void wsp_ggml_metal_graph_compute(
             } break;
         case WSP_GGML_OP_MUL_MAT:
             {
-                // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
                 WSP_GGML_ASSERT(ne00 == ne10);
-
-
-                WSP_GGML_ASSERT(
+
+                // TODO: assert that dim2 and dim3 are contiguous
+                WSP_GGML_ASSERT(ne12 % ne02 == 0);
+                WSP_GGML_ASSERT(ne13 % ne03 == 0);
+
+                const uint r2 = ne12/ne02;
+                const uint r3 = ne13/ne03;
+
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel
+                int ne11_mm_min = 1;
+
+#if 0
+                // the numbers below are measured on M2 Ultra for 7B and 13B models
+                // these numbers do not translate to other devices or model sizes
+                // TODO: need to find a better approach
+                if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+                    switch (src0t) {
+                        case WSP_GGML_TYPE_F16:  ne11_mm_min = 2;  break;
+                        case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
+                        case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+                        case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
+                        case WSP_GGML_TYPE_Q4_0:
+                        case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+                        case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+                        case WSP_GGML_TYPE_Q5_0:                          // not tested yet
+                        case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+                        case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
+                        case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
+                        default:                 ne11_mm_min = 1;  break;
+                    }
+                }
+#endif

                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                if (
+                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                    !wsp_ggml_is_transposed(src0) &&
                     !wsp_ggml_is_transposed(src1) &&
                     src1t == WSP_GGML_TYPE_F32 &&
-
-
-                    ne11
+                    ne00 % 32 == 0 && ne00 >= 64 &&
+                    (ne11 > ne11_mm_min || (wsp_ggml_is_quantized(src0t) && ne12 > 1))) {
+                    //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                     switch (src0->type) {
                         case WSP_GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
                         case WSP_GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
                         case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                         case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                        case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+                        case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
                         case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
                         case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
                         case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
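The MUL_MAT changes above introduce batch broadcasting: `ne12`/`ne13` must be multiples of `ne02`/`ne03`, and the ratios `r2`/`r3` are passed to the kernels. A rough C sketch of how such ratios let a batch index of `src1` locate its `src0` slice (the exact indexing inside the Metal kernels is assumed here, inferred from the asserts and the computed ratios):

    #include <stdint.h>
    #include <stdio.h>

    // Map an (i12, i13) batch of src1/dst back to the broadcast src0 slice.
    // Valid only when ne12 % ne02 == 0 and ne13 % ne03 == 0, as asserted in the diff.
    static void src0_batch(int64_t i12, int64_t i13,
                           int64_t ne02, int64_t ne03,
                           int64_t ne12, int64_t ne13,
                           int64_t * i02, int64_t * i03) {
        const int64_t r2 = ne12/ne02;
        const int64_t r3 = ne13/ne03;
        *i02 = i12/r2;
        *i03 = i13/r3;
    }

    int main(void) {
        int64_t i02, i03;
        src0_batch(5, 0, 2, 1, 6, 1, &i02, &i03); // r2 = 3, so src1 batch 5 maps to src0 batch 1
        printf("%lld %lld\n", (long long) i02, (long long) i03);
        return 0;
    }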
@@ -904,110 +1291,106 @@ void wsp_ggml_metal_graph_compute(
                     [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
                     [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:11];
                     [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:12];
-                    [encoder setBytes:&
+                    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:13];
+                    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:14];
                     [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63)
+                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
                     int nth0 = 32;
                     int nth1 = 1;
                     int nrows = 1;
+                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

                     // use custom matrix x vector kernel
                     switch (src0t) {
                         case WSP_GGML_TYPE_F32:
                             {
-
+                                WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                 nrows = 4;
                             } break;
                         case WSP_GGML_TYPE_F16:
                             {
                                 nth0 = 32;
                                 nth1 = 1;
-                                if (
-
-
-
-
+                                if (src1t == WSP_GGML_TYPE_F32) {
+                                    if (ne11 * ne12 < 4) {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                        nrows = ne11;
+                                    } else {
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                        nrows = 4;
+                                    }
                                 } else {
-                                    [encoder setComputePipelineState:ctx->
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
                                     nrows = 4;
                                 }
                             } break;
                         case WSP_GGML_TYPE_Q4_0:
                             {
-                                WSP_GGML_ASSERT(ne02 == 1);
-                                WSP_GGML_ASSERT(ne12 == 1);
-
                                 nth0 = 8;
                                 nth1 = 8;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
                             } break;
                         case WSP_GGML_TYPE_Q4_1:
                             {
-                                WSP_GGML_ASSERT(ne02 == 1);
-                                WSP_GGML_ASSERT(ne12 == 1);
-
                                 nth0 = 8;
                                 nth1 = 8;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
+                            } break;
+                        case WSP_GGML_TYPE_Q5_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+                            } break;
+                        case WSP_GGML_TYPE_Q5_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
                             } break;
                         case WSP_GGML_TYPE_Q8_0:
                             {
-                                WSP_GGML_ASSERT(ne02 == 1);
-                                WSP_GGML_ASSERT(ne12 == 1);
-
                                 nth0 = 8;
                                 nth1 = 8;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
                             } break;
                         case WSP_GGML_TYPE_Q2_K:
                             {
-                                WSP_GGML_ASSERT(ne02 == 1);
-                                WSP_GGML_ASSERT(ne12 == 1);
-
                                 nth0 = 2;
                                 nth1 = 32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
                             } break;
                         case WSP_GGML_TYPE_Q3_K:
                             {
-                                WSP_GGML_ASSERT(ne02 == 1);
-                                WSP_GGML_ASSERT(ne12 == 1);
-
                                 nth0 = 2;
                                 nth1 = 32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
                             } break;
                         case WSP_GGML_TYPE_Q4_K:
                             {
-                                WSP_GGML_ASSERT(ne02 == 1);
-                                WSP_GGML_ASSERT(ne12 == 1);
-
                                 nth0 = 4; //1;
                                 nth1 = 8; //32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
                             } break;
                         case WSP_GGML_TYPE_Q5_K:
                             {
-                                WSP_GGML_ASSERT(ne02 == 1);
-                                WSP_GGML_ASSERT(ne12 == 1);
-
                                 nth0 = 2;
                                 nth1 = 32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
                             } break;
                         case WSP_GGML_TYPE_Q6_K:
                             {
-                                WSP_GGML_ASSERT(ne02 == 1);
-                                WSP_GGML_ASSERT(ne12 == 1);
-
                                 nth0 = 2;
                                 nth1 = 32;
-                                [encoder setComputePipelineState:ctx->
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
                             } break;
                         default:
                             {
-
+                                WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
                                 WSP_GGML_ASSERT(false && "not implemented");
                             }
                     };
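The per-type selection above feeds the matrix-vector path; the choice between it and the matrix-matrix kernel is made at the top of this MUL_MAT block. A condensed C sketch of that decision, with the device-family and transposition checks reduced to booleans (in the diff `ne11_mm_min` defaults to 1 and the `#if 0` table is device-specific tuning):

    #include <stdbool.h>
    #include <stdint.h>

    // Condensed form of the mul_mat kernel choice: prefer the matrix-matrix
    // kernel only on Apple7+ GPUs, for untransposed operands with f32 src1,
    // rows that are multiples of 32 and at least 64 wide, and a batch large
    // enough to amortize it.
    static bool use_mm_kernel(bool apple7_or_newer,
                              bool src0_transposed, bool src1_transposed,
                              bool src1_is_f32, bool src0_is_quantized,
                              int64_t ne00, int64_t ne11, int64_t ne12,
                              int ne11_mm_min) {
        return apple7_or_newer &&
               !src0_transposed && !src1_transposed &&
               src1_is_f32 &&
               ne00 % 32 == 0 && ne00 >= 64 &&
               (ne11 > ne11_mm_min || (src0_is_quantized && ne12 > 1));
    }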
@@ -1029,31 +1412,125 @@ void wsp_ggml_metal_graph_compute(
                     [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
                     [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
                     [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                    [encoder setBytes:&
+                    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
+                    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];

-                    if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
-                        src0t ==
-
+                    if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 ||
+                        src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 ||
+                        src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == WSP_GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == WSP_GGML_TYPE_Q3_K) {
 #ifdef WSP_GGML_QKK_64
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                     }
                     else if (src0t == WSP_GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == WSP_GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else {
                         int64_t ny = (ne11 + nrows - 1)/nrows;
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                }
+            } break;
+        case WSP_GGML_OP_MUL_MAT_ID:
+            {
+                //WSP_GGML_ASSERT(ne00 == ne10);
+                //WSP_GGML_ASSERT(ne03 == ne13);
+
+                WSP_GGML_ASSERT(src0t == WSP_GGML_TYPE_I32);
+
+                const int n_as = ne00;
+
+                // TODO: make this more general
+                WSP_GGML_ASSERT(n_as <= 8);
+
+                struct wsp_ggml_tensor * src2 = gf->nodes[i]->src[2];
+
+                const int64_t  ne20 = src2 ? src2->ne[0] : 0;
+                const int64_t  ne21 = src2 ? src2->ne[1] : 0;
+                const int64_t  ne22 = src2 ? src2->ne[2] : 0;
+                const int64_t  ne23 = src2 ? src2->ne[3] : 0; WSP_GGML_UNUSED(ne23);
+
+                const uint64_t nb20 = src2 ? src2->nb[0] : 0; WSP_GGML_UNUSED(nb20);
+                const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+                const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+                const uint64_t nb23 = src2 ? src2->nb[3] : 0; WSP_GGML_UNUSED(nb23);
+
+                const enum wsp_ggml_type src2t = src2 ? src2->type : WSP_GGML_TYPE_COUNT; WSP_GGML_UNUSED(src2t);
+
+                WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src2));
+                WSP_GGML_ASSERT(!wsp_ggml_is_transposed(src1));
+
+                WSP_GGML_ASSERT(ne20 % 32 == 0);
+                // !!!!!!!!! TODO: this assert is probably required but not sure!
+                //WSP_GGML_ASSERT(ne20 >= 64);
+                WSP_GGML_ASSERT(src1t == WSP_GGML_TYPE_F32);
+
+                const uint r2 = ne12/ne22;
+                const uint r3 = ne13/ne23;
+
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel
+                int ne11_mm_min = 0;
+
+                const int idx = ((int32_t *) dst->op_params)[0];
+
+                // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                    ne11 > ne11_mm_min) {
+                    switch (src2->type) {
+                        case WSP_GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
+                        case WSP_GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
+                        case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
+                        case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
+                        case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
+                        case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
+                        case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
+                        case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
+                        case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
+                        case WSP_GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
+                        case WSP_GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
+                        case WSP_GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
+                        default: WSP_GGML_ASSERT(false && "MUL_MAT_ID not implemented");
+                    }
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
+                    [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
+                    [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:11];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:12];
+                    [encoder setBytes:&r2   length:sizeof(r2)   atIndex:13];
+                    [encoder setBytes:&r3   length:sizeof(r3)   atIndex:14];
+                    [encoder setBytes:&idx  length:sizeof(idx)  atIndex:15];
+                    // TODO: how to make this an array? read Metal docs
+                    for (int j = 0; j < n_as; ++j) {
+                        struct wsp_ggml_tensor * src_cur = dst->src[2 + j];
+
+                        size_t offs_src_cur = 0;
+                        id<MTLBuffer> id_src_cur = wsp_ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+                        [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
                     }
+
+                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 }
             } break;
         case WSP_GGML_OP_GET_ROWS:
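The dispatch changes above route the whole broadcast batch through the third grid dimension (`ne12*ne13`) and round the row count up to the rows each threadgroup covers. A small C sketch of that grid math, assuming the 8-rows-per-threadgroup rounding used for the Q4_0/Q4_1/Q5_x/Q8_0/Q2_K path:

    #include <stdint.h>
    #include <stdio.h>

    // Threadgroup grid for the quantized matrix-vector path in the diff:
    // ceil(ne01/8) row groups, ne11 columns, and the full ne12*ne13 batch.
    static void mv_grid(int64_t ne01, int64_t ne11, int64_t ne12, int64_t ne13,
                        int64_t grid[3]) {
        grid[0] = (ne01 + 7)/8;
        grid[1] = ne11;
        grid[2] = ne12*ne13;
    }

    int main(void) {
        int64_t g[3];
        mv_grid(4096, 1, 2, 1, g);
        printf("%lld %lld %lld\n", (long long) g[0], (long long) g[1], (long long) g[2]); // 512 1 2
        return 0;
    }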
@@ -1063,6 +1540,8 @@ void wsp_ggml_metal_graph_compute(
                     case WSP_GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                     case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                     case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                    case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+                    case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
                     case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                     case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                     case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1085,18 +1564,24 @@ void wsp_ggml_metal_graph_compute(
             } break;
         case WSP_GGML_OP_RMS_NORM:
             {
+                WSP_GGML_ASSERT(ne00 % 4 == 0);
+
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));

-
+                int nth = 32; // SIMD width
+
+                while (nth < ne00/4 && nth < 1024) {
+                    nth *= 2;
+                }

                 [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-                [encoder setBuffer:id_src0 offset:offs_src0
-                [encoder setBuffer:id_dst offset:offs_dst
-                [encoder setBytes:&ne00
-                [encoder setBytes:&nb01
-                [encoder setBytes:&eps
-                [encoder setThreadgroupMemoryLength:
+                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

                 const int64_t nrows = wsp_ggml_nrows(src0);

@@ -1107,7 +1592,7 @@ void wsp_ggml_metal_graph_compute(
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));

-                const int nth = 256;
+                const int nth = MIN(256, ne00);

                 [encoder setComputePipelineState:ctx->pipeline_norm];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1115,7 +1600,7 @@ void wsp_ggml_metal_graph_compute(
                 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                 [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                 [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0];

                 const int64_t nrows = wsp_ggml_nrows(src0);

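The NORM case now pads its threadgroup memory request to a 16-byte boundary via `WSP_GGML_PAD`. A one-line C sketch of that rounding (the macro itself lives in the ggml headers; this is just the usual round-up-to-multiple form):

    #include <stddef.h>
    #include <stdio.h>

    // Round x up to the next multiple of n, as WSP_GGML_PAD is used above
    // to keep the threadgroup scratch size 16-byte aligned.
    static size_t pad_to(size_t x, size_t n) {
        return ((x + n - 1)/n)*n;
    }

    int main(void) {
        printf("%zu\n", pad_to(3*sizeof(float), 16)); // 12 bytes of scratch -> 16
        return 0;
    }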
@@ -1125,17 +1610,16 @@ void wsp_ggml_metal_graph_compute(
             {
                 WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32));

-                const int
+                const int nth = MIN(1024, ne00);
+
+                //const int n_past = ((int32_t *) dst->op_params)[0];
                 const int n_head = ((int32_t *) dst->op_params)[1];
                 float max_bias;
                 memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

-                if (__builtin_popcount(n_head) != 1) {
-                    WSP_GGML_ASSERT(false && "only power-of-two n_head implemented");
-                }
-
                 const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
                 const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

                 [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
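The ALIBI case above drops the power-of-two head-count restriction: it now computes a second slope base `m1` and passes `n_heads_log2_floor` to the kernel. A short C sketch of the two slope bases, mirroring the expressions in the diff:

    #include <math.h>
    #include <stdio.h>

    // ALiBi slope bases as computed in the diff: m0 covers heads up to the
    // largest power of two, m1 covers the remaining heads.
    static void alibi_slopes(int n_head, float max_bias, float * m0, float * m1) {
        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
        *m0 = powf(2.0f, -(max_bias)        / n_heads_log2_floor);
        *m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
    }

    int main(void) {
        float m0, m1;
        alibi_slopes(12, 8.0f, &m0, &m1); // n_heads_log2_floor = 8
        printf("%f %f\n", m0, m1);        // 0.5 and ~0.707
        return 0;
    }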
@@ -1156,62 +1640,164 @@ void wsp_ggml_metal_graph_compute(
                 [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
                 [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                 [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&m0
-
-
+                [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
+                [encoder setBytes:&m1   length:sizeof(   float) atIndex:19];
+                [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];

                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case WSP_GGML_OP_ROPE:
             {
-
-
-                const int
+                WSP_GGML_ASSERT(ne10 == ne02);
+
+                const int nth = MIN(1024, ne00);
+
+                const int n_past     = ((int32_t *) dst->op_params)[0];
+                const int n_dims     = ((int32_t *) dst->op_params)[1];
+                const int mode       = ((int32_t *) dst->op_params)[2];
+                // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+                const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
+
+                float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+                memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+                memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+                memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+                memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+                memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+                switch (src0->type) {
+                    case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
+                    case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+                    default: WSP_GGML_ASSERT(false);
+                };
+
+                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:2];
+                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:3];
+                [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:4];
+                [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:5];
+                [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:6];
+                [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:7];
+                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:8];
+                [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:9];
+                [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:10];
+                [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:11];
+                [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:12];
+                [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:13];
+                [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:14];
+                [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:15];
+                [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:16];
+                [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:17];
+                [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:18];
+                [encoder setBytes:&n_past  length:sizeof(     int) atIndex:19];
+                [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:20];
+                [encoder setBytes:&mode    length:sizeof(     int) atIndex:21];
+                [encoder setBytes:&n_orig_ctx length:sizeof( int)  atIndex:22];
+                [encoder setBytes:&freq_base   length:sizeof(float) atIndex:23];
+                [encoder setBytes:&freq_scale  length:sizeof(float) atIndex:24];
+                [encoder setBytes:&ext_factor  length:sizeof(float) atIndex:25];
+                [encoder setBytes:&attn_factor length:sizeof(float) atIndex:26];
+                [encoder setBytes:&beta_fast   length:sizeof(float) atIndex:27];
+                [encoder setBytes:&beta_slow   length:sizeof(float) atIndex:28];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case WSP_GGML_OP_IM2COL:
+            {
+                WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16);
+                WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32);
+                WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16);
+
+                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+                const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                const int32_t IH = is_2D ? src1->ne[1] : 1;
+                const int32_t IW =         src1->ne[0];
+
+                const int32_t KH = is_2D ? src0->ne[1] : 1;
+                const int32_t KW =         src0->ne[0];
+
+                const int32_t OH = is_2D ? dst->ne[2] : 1;
+                const int32_t OW =         dst->ne[1];
+
+                const int32_t CHW = IC * KH * KW;
+
+                const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                switch (src0->type) {
+                    case WSP_GGML_TYPE_F32: WSP_GGML_ASSERT(false && "not implemented"); break;
+                    case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                    default: WSP_GGML_ASSERT(false);
+                };
+
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
+                [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
+                [encoder setBytes:&IW   length:sizeof( int32_t) atIndex:4];
+                [encoder setBytes:&IH   length:sizeof( int32_t) atIndex:5];
+                [encoder setBytes:&CHW  length:sizeof( int32_t) atIndex:6];
+                [encoder setBytes:&s0   length:sizeof( int32_t) atIndex:7];
+                [encoder setBytes:&s1   length:sizeof( int32_t) atIndex:8];
+                [encoder setBytes:&p0   length:sizeof( int32_t) atIndex:9];
+                [encoder setBytes:&p1   length:sizeof( int32_t) atIndex:10];
+                [encoder setBytes:&d0   length:sizeof( int32_t) atIndex:11];
+                [encoder setBytes:&d1   length:sizeof( int32_t) atIndex:12];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+            } break;
+        case WSP_GGML_OP_ARGSORT:
+            {
+                WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F32);
+                WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_I32);

-
-
-
-
+                const int nrows = wsp_ggml_nrows(src0);
+
+                enum wsp_ggml_sort_order order = (enum wsp_ggml_sort_order) dst->op_params[0];
+
+                switch (order) {
+                    case WSP_GGML_SORT_ASC:  [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc];  break;
+                    case WSP_GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+                    default: WSP_GGML_ASSERT(false);
+                };

-                [encoder setComputePipelineState:ctx->pipeline_rope];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
-                [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
-                [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
-                [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
-                [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

-                [encoder dispatchThreadgroups:MTLSizeMake(
+                [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
             } break;
         case WSP_GGML_OP_DUP:
         case WSP_GGML_OP_CPY:
         case WSP_GGML_OP_CONT:
             {
-
+                WSP_GGML_ASSERT(ne00 % wsp_ggml_blck_size(src0->type) == 0);
+
+                int nth = MIN(1024, ne00/wsp_ggml_blck_size(src0->type));

                 switch (src0t) {
                     case WSP_GGML_TYPE_F32:
                         {
+                            WSP_GGML_ASSERT(ne0 % wsp_ggml_blck_size(dst->type) == 0);
+
                             switch (dstt) {
-                                case WSP_GGML_TYPE_F16:
-                                case WSP_GGML_TYPE_F32:
+                                case WSP_GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+                                case WSP_GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
+                                case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
+                                case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
+                                case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
+                                //case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
+                                //case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
                                 default: WSP_GGML_ASSERT(false && "not implemented");
                             };
                         } break;
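The new IM2COL case derives its geometry on the CPU side: `CHW = IC*KH*KW` is the length of one unfolded output column, and `ofs0`/`ofs1` are the float strides of one batch and one input channel of `src1` (its byte strides divided by 4). A small C sketch of those derived values for the 2-D case, with `ne`/`nb` standing in for the input tensor's dims and byte strides:

    #include <stdint.h>
    #include <stdio.h>

    // Derived im2col geometry for the 2-D case, as in the diff:
    // ne = {IW, IH, IC, N} logical dims, nb = byte strides of the f32 input.
    static void im2col_geom_2d(const int64_t ne[4], const uint64_t nb[4],
                               int64_t KW, int64_t KH,
                               int32_t * CHW, int32_t * ofs0, int32_t * ofs1) {
        const int32_t IC = (int32_t) ne[2];
        *CHW  = (int32_t) (IC * KH * KW);
        *ofs0 = (int32_t) (nb[3] / 4); // floats per batch
        *ofs1 = (int32_t) (nb[2] / 4); // floats per input channel
    }

    int main(void) {
        const int64_t  ne[4] = {8, 8, 3, 1};             // IW, IH, IC, N
        const uint64_t nb[4] = {4, 4*8, 4*8*8, 4*8*8*3}; // contiguous f32 strides
        int32_t CHW, ofs0, ofs1;
        im2col_geom_2d(ne, nb, 3, 3, &CHW, &ofs0, &ofs1);
        printf("%d %d %d\n", CHW, ofs0, ofs1); // 27 192 64
        return 0;
    }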
@@ -1249,7 +1835,7 @@ void wsp_ggml_metal_graph_compute(
             } break;
         default:
             {
-
+                WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op));
                 WSP_GGML_ASSERT(false);
             }
         }
@@ -1274,10 +1860,244 @@ void wsp_ggml_metal_graph_compute(

             MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
             if (status != MTLCommandBufferStatusCompleted) {
-
+                WSP_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
                 WSP_GGML_ASSERT(false);
             }
         }

     }
 }
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend interface
+
+static id<MTLDevice> g_backend_device = nil;
+static int g_backend_device_ref_count = 0;
+
+static id<MTLDevice> wsp_ggml_backend_metal_get_device(void) {
+    if (g_backend_device == nil) {
+        g_backend_device = MTLCreateSystemDefaultDevice();
+    }
+
+    g_backend_device_ref_count++;
+
+    return g_backend_device;
+}
+
+static void wsp_ggml_backend_metal_free_device(void) {
+    assert(g_backend_device_ref_count > 0);
+
+    g_backend_device_ref_count--;
+
+    if (g_backend_device_ref_count == 0) {
+        g_backend_device = nil;
+    }
+}
+
+static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) {
+    struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
+
+    return ctx->data;
+}
+
+static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) {
+    struct wsp_ggml_backend_metal_buffer_context * ctx = (struct wsp_ggml_backend_metal_buffer_context *)buffer->context;
+
+    wsp_ggml_backend_metal_free_device();
+
+    free(ctx->data);
+    free(ctx);
+
+    UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_metal_buffer_set_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds");
+    WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy((char *)tensor->data + offset, data, size);
+
+    UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_metal_buffer_get_tensor(wsp_ggml_backend_buffer_t buffer, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds");
+    WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_metal_buffer_cpy_tensor_from(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+    wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src));
+
+    UNUSED(buffer);
+}
+
+static void wsp_ggml_backend_metal_buffer_cpy_tensor_to(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) {
+    wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src));
+
+    UNUSED(buffer);
+}
+
+static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = {
+    /* .free_buffer     = */ wsp_ggml_backend_metal_buffer_free_buffer,
+    /* .get_base        = */ wsp_ggml_backend_metal_buffer_get_base,
+    /* .init_tensor     = */ NULL,
+    /* .set_tensor      = */ wsp_ggml_backend_metal_buffer_set_tensor,
+    /* .get_tensor      = */ wsp_ggml_backend_metal_buffer_get_tensor,
+    /* .cpy_tensor_from = */ wsp_ggml_backend_metal_buffer_cpy_tensor_from,
+    /* .cpy_tensor_to   = */ wsp_ggml_backend_metal_buffer_cpy_tensor_to,
+};
+
+static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_buffer_type_alloc_buffer(wsp_ggml_backend_buffer_type_t buft, size_t size) {
+    struct wsp_ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct wsp_ggml_backend_metal_buffer_context));
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
+
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    ctx->data  = wsp_ggml_metal_host_malloc(size);
+    ctx->metal = [wsp_ggml_backend_metal_get_device() newBufferWithBytesNoCopy:ctx->data
+                    length:size_aligned
+                    options:MTLResourceStorageModeShared
+                    deallocator:nil];
+
+    return wsp_ggml_backend_buffer_init(buft, metal_backend_buffer_i, ctx, size);
+}
+
+static size_t wsp_ggml_backend_metal_buffer_type_get_alignment(wsp_ggml_backend_buffer_type_t buft) {
+    return 32;
+    UNUSED(buft);
+}
+
+static bool wsp_ggml_backend_metal_buffer_type_supports_backend(wsp_ggml_backend_buffer_type_t buft, wsp_ggml_backend_t backend) {
+    return wsp_ggml_backend_is_metal(backend) || wsp_ggml_backend_is_cpu(backend);
+
+    WSP_GGML_UNUSED(buft);
+}
+
+wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_buffer_type(void) {
+    static struct wsp_ggml_backend_buffer_type wsp_ggml_backend_buffer_type_metal = {
+        /* .iface = */ {
+            /* .alloc_buffer     = */ wsp_ggml_backend_metal_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ wsp_ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_alloc_size   = */ NULL, // defaults to wsp_ggml_nbytes
+            /* .supports_backend = */ wsp_ggml_backend_metal_buffer_type_supports_backend,
+        },
+        /* .context = */ NULL,
+    };
+
+    return &wsp_ggml_backend_buffer_type_metal;
+}
+
+static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) {
+    return "Metal";
+
+    UNUSED(backend);
+}
+
+static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) {
+    struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+    wsp_ggml_metal_free(ctx);
+    free(backend);
+}
+
+static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) {
+    UNUSED(backend);
+}
+
+static wsp_ggml_backend_buffer_type_t wsp_ggml_backend_metal_get_default_buffer_type(wsp_ggml_backend_t backend) {
+    return wsp_ggml_backend_metal_buffer_type();
+
+    UNUSED(backend);
+}
+
+static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) {
+    struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+    wsp_ggml_metal_graph_compute(metal_ctx, cgraph);
+}
+
+static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) {
+    return wsp_ggml_metal_supports_op(op);
+
+    UNUSED(backend);
+}
+
+static struct wsp_ggml_backend_i metal_backend_i = {
+    /* .get_name                = */ wsp_ggml_backend_metal_name,
+    /* .free                    = */ wsp_ggml_backend_metal_free,
+    /* .get_default_buffer_type = */ wsp_ggml_backend_metal_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_from_async   = */ NULL,
+    /* .cpy_tensor_to_async     = */ NULL,
+    /* .synchronize             = */ wsp_ggml_backend_metal_synchronize,
+    /* .graph_plan_create       = */ NULL, // the metal implementation does not require creating graph plans atm
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ wsp_ggml_backend_metal_graph_compute,
+    /* .supports_op             = */ wsp_ggml_backend_metal_supports_op,
+};
+
+// TODO: make a common log callback for all backends in ggml-backend
+static void wsp_ggml_backend_log_callback(enum wsp_ggml_log_level level, const char * msg, void * user_data) {
+    fprintf(stderr, "%s", msg);
+
+    UNUSED(level);
+    UNUSED(user_data);
+}
+
+wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) {
+    wsp_ggml_metal_log_set_callback(wsp_ggml_backend_log_callback, NULL);
+
+    struct wsp_ggml_metal_context * ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS);
+
+    if (ctx == NULL) {
+        return NULL;
+    }
+
+    wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend));
+
+    *metal_backend = (struct wsp_ggml_backend) {
+        /* .interface = */ metal_backend_i,
+        /* .context   = */ ctx,
+    };
+
+    return metal_backend;
+}
+
+bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) {
+    return backend->iface.get_name == wsp_ggml_backend_metal_name;
+}
+
+void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) {
+    WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
+
+    struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+    wsp_ggml_metal_set_n_cb(ctx, n_cb);
+}
+
+bool wsp_ggml_backend_metal_supports_family(wsp_ggml_backend_t backend, int family) {
+    WSP_GGML_ASSERT(wsp_ggml_backend_is_metal(backend));
+
+    struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context;
+
+    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+}
+
+wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+
+wsp_ggml_backend_t wsp_ggml_backend_reg_metal_init(const char * params, void * user_data) {
+    return wsp_ggml_backend_metal_init();
+
+    WSP_GGML_UNUSED(params);
+    WSP_GGML_UNUSED(user_data);
+}
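The new backend buffer path above rounds the requested size up to the system page size before wrapping host memory with `newBufferWithBytesNoCopy`, which Metal requires to be page-aligned. A minimal C sketch of that rounding, assuming only `sysconf(_SC_PAGESIZE)` as used in the diff:

    #include <stdio.h>
    #include <unistd.h>

    // Round a requested buffer size up to a whole number of pages,
    // mirroring the size_aligned computation in the diff.
    static size_t page_align(size_t size) {
        const size_t page = (size_t) sysconf(_SC_PAGESIZE);
        if (size % page != 0) {
            size += page - (size % page);
        }
        return size;
    }

    int main(void) {
        printf("%zu\n", page_align(100000)); // e.g. 102400 on a system with 4 KiB pages
        return 0;
    }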